mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)

Ref. + AVX code & generic tests + benches (#85)
commit 56dbd29c59, parent 99b9e3e10e, committed by GitHub

Cargo.lock (generated): 14 changed lines
@@ -47,6 +47,12 @@ version = "3.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.23.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
@@ -307,9 +313,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.20.2"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
@@ -353,6 +359,7 @@ dependencies = [
|
||||
"cmake",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"once_cell",
|
||||
"poulpy-hal",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
@@ -368,6 +375,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"once_cell",
|
||||
"poulpy-backend",
|
||||
"poulpy-hal",
|
||||
"rug",
|
||||
@@ -377,10 +385,12 @@ dependencies = [
|
||||
name = "poulpy-hal"
|
||||
version = "0.1.2"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"cmake",
|
||||
"criterion",
|
||||
"itertools 0.14.0",
|
||||
"once_cell",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
|
||||
@@ -12,3 +12,4 @@ itertools = "0.14.0"
|
||||
criterion = "0.7.0"
|
||||
byteorder = "1.5.0"
|
||||
zstd = "0.13.3"
|
||||
once_cell = "1.21.3"
|
||||
@@ -18,6 +18,7 @@ rand = {workspace = true}
|
||||
rand_distr = {workspace = true}
|
||||
rand_core = {workspace = true}
|
||||
byteorder = {workspace = true}
|
||||
once_cell = {workspace = true}
|
||||
rand_chacha = "0.9.0"
|
||||
|
||||
[build-dependencies]
|
||||
@@ -26,3 +27,8 @@ cmake = "0.1.54"
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "vmp"
|
||||
harness = false
|
||||
poulpy-backend/benches/fft.rs (new file, 224 lines)
@@ -0,0 +1,224 @@
|
||||
use std::{ffi::c_void, hint::black_box};
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use poulpy_backend::cpu_spqlios::reim;
|
||||
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTRef, ReimFFTTable, ReimIFFTRef, ReimIFFTTable};
|
||||
|
||||
pub fn bench_fft_ref(c: &mut Criterion) {
|
||||
let group_name: String = "fft_ref".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
|
||||
move || {
|
||||
ReimFFTRef::reim_dft_execute(&table, &mut values);
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_fft_avx2_fma(c: &mut Criterion) {
|
||||
let group_name: String = "fft_avx2_fma".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
|
||||
let scale = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
|
||||
move || {
|
||||
use poulpy_backend::cpu_fft64_avx::ReimFFTAvx;
|
||||
|
||||
ReimFFTAvx::reim_dft_execute(&table, &mut values);
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
unsafe {
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2/fma");
|
||||
return;
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_fft_spqlios(c: &mut Criterion) {
|
||||
let group_name: String = "fft_spqlios".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
|
||||
let scale = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
unsafe {
|
||||
reim::reim_fft_simple(m as u32, values.as_mut_ptr() as *mut c_void);
|
||||
}
|
||||
|
||||
move || {
|
||||
unsafe {
|
||||
reim::reim_fft_simple(m as u32, values.as_mut_ptr() as *mut c_void);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_ifft_ref(c: &mut Criterion) {
|
||||
let group_name: String = "ifft_ref".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
|
||||
move || {
|
||||
ReimIFFTRef::reim_dft_execute(&table, &mut values);
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_ifft_avx2_fma(c: &mut Criterion) {
|
||||
let group_name: String = "ifft_avx2_fma".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
|
||||
let scale = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
|
||||
move || {
|
||||
use poulpy_backend::cpu_fft64_avx::ReimIFFTAvx;
|
||||
|
||||
ReimIFFTAvx::reim_dft_execute(&table, &mut values);
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
unsafe {
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2/fma");
|
||||
return;
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn bench_ifft_spqlios(c: &mut Criterion) {
|
||||
let group_name: String = "ifft_spqlios".to_string();
|
||||
|
||||
let mut group = c.benchmark_group(group_name);
|
||||
|
||||
fn runner(m: usize) -> impl FnMut() {
|
||||
let mut values: Vec<f64> = vec![0f64; m << 1];
|
||||
|
||||
let scale = 1.0f64 / (2 * m) as f64;
|
||||
values
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
unsafe {
|
||||
reim::reim_ifft_simple(m as u32, values.as_mut_ptr() as *mut c_void);
|
||||
}
|
||||
|
||||
move || {
|
||||
unsafe {
|
||||
reim::reim_ifft_simple(m as u32, values.as_mut_ptr() as *mut c_void);
|
||||
}
|
||||
black_box(());
|
||||
}
|
||||
}
|
||||
|
||||
for log_m in [9, 10, 11, 12, 13, 14, 15] {
|
||||
let id: BenchmarkId = BenchmarkId::from_parameter(format!("n: {}", 2 << log_m));
|
||||
let mut runner = runner(1 << log_m);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_fft_ref,
|
||||
bench_fft_avx2_fma,
|
||||
bench_fft_spqlios,
|
||||
bench_ifft_ref,
|
||||
bench_ifft_avx2_fma,
|
||||
bench_ifft_spqlios
|
||||
);
|
||||
criterion_main!(benches);
|
||||
poulpy-backend/benches/vec_znx.rs (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
// poulpy-backend/benches/vec_znx.rs
use criterion::{Criterion, criterion_group, criterion_main};
use poulpy_backend::{cpu_fft64_ref, cpu_spqlios};
use poulpy_hal::reference::vec_znx::{bench_vec_znx_add, bench_vec_znx_automorphism, bench_vec_znx_normalize_inplace};

#[allow(dead_code)]
fn bench_vec_znx_add_cpu_spqlios_fft64(c: &mut Criterion) {
    bench_vec_znx_add::<cpu_spqlios::FFT64Spqlios>(c, "cpu_spqlios::fft64");
}

#[allow(dead_code)]
fn bench_vec_znx_add_cpu_ref_fft64(c: &mut Criterion) {
    bench_vec_znx_add::<cpu_fft64_ref::FFT64Ref>(c, "cpu_ref::fft64");
}

#[allow(dead_code)]
fn bench_vec_znx_normalize_inplace_cpu_ref_fft64(c: &mut Criterion) {
    bench_vec_znx_normalize_inplace::<cpu_fft64_ref::FFT64Ref>(c, "cpu_ref::fft64");
}

#[allow(dead_code)]
fn bench_vec_znx_normalize_inplace_cpu_spqlios_fft64(c: &mut Criterion) {
    bench_vec_znx_normalize_inplace::<cpu_spqlios::FFT64Spqlios>(c, "cpu_spqlios::fft64");
}

fn bench_vec_znx_automorphism_cpu_ref_fft64(c: &mut Criterion) {
    bench_vec_znx_automorphism::<cpu_fft64_ref::FFT64Ref>(c, "cpu_ref::fft64");
}

fn bench_vec_znx_automorphism_cpu_spqlios_fft64(c: &mut Criterion) {
    bench_vec_znx_automorphism::<cpu_spqlios::FFT64Spqlios>(c, "cpu_spqlios::fft64");
}

criterion_group!(
    benches,
    // bench_vec_znx_add_cpu_spqlios_fft64,
    // bench_vec_znx_add_cpu_ref_fft64,
    // bench_vec_znx_normalize_inplace_cpu_ref_fft64,
    // bench_vec_znx_normalize_inplace_cpu_spqlios_fft64,
    bench_vec_znx_automorphism_cpu_ref_fft64,
    bench_vec_znx_automorphism_cpu_spqlios_fft64,
);
criterion_main!(benches);
|
||||
poulpy-backend/benches/vmp.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
// poulpy-backend/benches/vmp.rs
use criterion::{Criterion, criterion_group, criterion_main};
use poulpy_backend::{FFT64Avx, FFT64Ref, FFT64Spqlios};
use poulpy_hal::bench_suite::vmp::bench_vmp_apply_dft_to_dft;

fn bench_vmp_apply_dft_to_dft_cpu_spqlios_fft64(c: &mut Criterion) {
    bench_vmp_apply_dft_to_dft::<FFT64Spqlios>(c, "cpu_spqlios::fft64");
}

fn bench_vmp_apply_dft_to_dft_cpu_ref_fft64(c: &mut Criterion) {
    bench_vmp_apply_dft_to_dft::<FFT64Ref>(c, "cpu_ref::fft64");
}

fn bench_vmp_apply_dft_to_dft_cpu_avx_fft64(c: &mut Criterion) {
    bench_vmp_apply_dft_to_dft::<FFT64Avx>(c, "cpu_avx::fft64");
}

criterion_group!(
    benches,
    bench_vmp_apply_dft_to_dft_cpu_spqlios_fft64,
    bench_vmp_apply_dft_to_dft_cpu_ref_fft64,
    bench_vmp_apply_dft_to_dft_cpu_avx_fft64,
);
criterion_main!(benches);
|
||||
@@ -1,10 +1,10 @@
|
||||
use itertools::izip;
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTTmpA, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
|
||||
VecZnxAddNormal, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxBigSubSmallBInplace, VecZnxDftAlloc, VecZnxFillUniform, VecZnxNormalizeInplace,
|
||||
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
|
||||
VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
|
||||
},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, ZnxInfos},
|
||||
source::Source,
|
||||
@@ -16,9 +16,9 @@ fn main() {
|
||||
let ct_size: usize = 3;
|
||||
let msg_size: usize = 2;
|
||||
let log_scale: usize = msg_size * basek - 5;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(n as u64);
|
||||
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::<FFT64>::alloc(module.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::<FFT64Spqlios>::alloc(module.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let seed: [u8; 32] = [0; 32];
|
||||
let mut source: Source = Source::new(seed);
|
||||
@@ -28,7 +28,7 @@ fn main() {
|
||||
s.fill_ternary_prob(0, 0.5, &mut source);
|
||||
|
||||
// Buffer to store s in the DFT domain
|
||||
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
|
||||
let mut s_dft: SvpPPol<Vec<u8>, FFT64Spqlios> = module.svp_ppol_alloc(s.cols());
|
||||
|
||||
// s_dft <- DFT(s)
|
||||
module.svp_prepare(&mut s_dft, 0, &s, 0);
|
||||
@@ -41,14 +41,14 @@ fn main() {
|
||||
);
|
||||
|
||||
// Fill the second column with random values: ct = (0, a)
|
||||
module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
|
||||
module.vec_znx_fill_uniform(basek, &mut ct, 1, &mut source);
|
||||
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size);
|
||||
|
||||
module.dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_inplace(
|
||||
module.svp_apply_dft_to_dft_inplace(
|
||||
&mut buf_dft, // DFT(ct[1] * s)
|
||||
0, // Selects the first column of res
|
||||
&s_dft, // DFT(s)
|
||||
@@ -58,8 +58,8 @@ fn main() {
|
||||
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
|
||||
|
||||
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
|
||||
module.idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size);
|
||||
module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// Creates a plaintext: VecZnx with 1 column
|
||||
let mut m = VecZnx::alloc(
|
||||
@@ -109,8 +109,8 @@ fn main() {
|
||||
// Decryption
|
||||
|
||||
// DFT(ct[1] * s)
|
||||
module.dft(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_inplace(
|
||||
module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_dft_to_dft_inplace(
|
||||
&mut buf_dft,
|
||||
0, // Selects the first column of res.
|
||||
&s_dft,
|
||||
@@ -118,7 +118,7 @@ fn main() {
|
||||
);
|
||||
|
||||
// BIG(c1 * s) = IDFT(DFT(c1 * s))
|
||||
module.idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// BIG(c1 * s) + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
|
||||
|
||||
poulpy-backend/src/cpu_fft64_avx/mod.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
|
||||
mod module;
mod reim;
mod reim4;
mod scratch;
mod svp;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp;
mod zn;
mod znx_avx;

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
pub struct FFT64Avx {}
pub use reim::*;

#[cfg(test)]
pub mod tests;
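Usage sketch (editorial addition, not part of this commit): constructing a module over the new backend mirrors the FFT64Spqlios example further down in this diff. `ModuleNew` and `Module` are the existing poulpy-hal items used there, the ring degree 1024 is an arbitrary illustration value, and `new_impl` above panics at runtime if the CPU lacks avx, avx2 or fma.

use poulpy_backend::FFT64Avx;
use poulpy_hal::{api::ModuleNew, layouts::Module};

fn demo_fft64_avx_module() {
    // Builds the FFT/IFFT tables for ring degree n = 1024 (m = n/2 = 512).
    let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(1024);
    let _ = &module;
}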
|
||||
poulpy-backend/src/cpu_fft64_avx/module.rs (new file, 478 lines)
@@ -0,0 +1,478 @@
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
reference::{
|
||||
fft64::{
|
||||
reim::{
|
||||
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
|
||||
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx,
|
||||
ReimToZnxInplace, ReimZero, reim_copy_ref, reim_zero_ref,
|
||||
},
|
||||
reim4::{
|
||||
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
|
||||
},
|
||||
},
|
||||
znx::{
|
||||
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
|
||||
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
|
||||
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub,
|
||||
ZnxSubABInplace, ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_rotate, znx_zero_ref,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::{
|
||||
FFT64Avx,
|
||||
reim::{
|
||||
ReimFFTAvx, ReimIFFTAvx, reim_add_avx2_fma, reim_add_inplace_avx2_fma, reim_addmul_avx2_fma, reim_from_znx_i64_bnd50_fma,
|
||||
reim_mul_avx2_fma, reim_mul_inplace_avx2_fma, reim_negate_avx2_fma, reim_negate_inplace_avx2_fma,
|
||||
reim_sub_ab_inplace_avx2_fma, reim_sub_avx2_fma, reim_sub_ba_inplace_avx2_fma, reim_to_znx_i64_inplace_bnd63_avx2_fma,
|
||||
},
|
||||
reim_to_znx_i64_bnd63_avx2_fma,
|
||||
reim4::{
|
||||
reim4_extract_1blk_from_reim_avx, reim4_save_1blk_to_reim_avx, reim4_save_2blk_to_reim_avx,
|
||||
reim4_vec_mat1col_product_avx, reim4_vec_mat2cols_2ndcol_product_avx, reim4_vec_mat2cols_product_avx,
|
||||
},
|
||||
znx_avx::{
|
||||
znx_add_avx, znx_add_inplace_avx, znx_automorphism_avx, znx_negate_avx, znx_negate_inplace_avx,
|
||||
znx_normalize_final_step_avx, znx_normalize_final_step_inplace_avx, znx_normalize_first_step_avx,
|
||||
znx_normalize_first_step_carry_only_avx, znx_normalize_first_step_inplace_avx, znx_normalize_middle_step_avx,
|
||||
znx_normalize_middle_step_carry_only_avx, znx_normalize_middle_step_inplace_avx, znx_sub_ab_inplace_avx, znx_sub_avx,
|
||||
znx_sub_ba_inplace_avx, znx_switch_ring_avx,
|
||||
},
|
||||
};
|
||||
|
||||
#[repr(C)]
|
||||
pub struct FFT64AvxHandle {
|
||||
table_fft: ReimFFTTable<f64>,
|
||||
table_ifft: ReimIFFTTable<f64>,
|
||||
}
|
||||
|
||||
impl Backend for FFT64Avx {
|
||||
type ScalarPrep = f64;
|
||||
type ScalarBig = i64;
|
||||
type Handle = FFT64AvxHandle;
|
||||
unsafe fn destroy(handle: NonNull<Self::Handle>) {
|
||||
unsafe {
|
||||
drop(Box::from_raw(handle.as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
fn layout_big_word_count() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn layout_prep_word_count() -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64Avx {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
if !std::arch::is_x86_feature_detected!("avx")
|
||||
|| !std::arch::is_x86_feature_detected!("avx2")
|
||||
|| !std::arch::is_x86_feature_detected!("fma")
|
||||
{
|
||||
panic!("arch must support avx2, avx and fma")
|
||||
}
|
||||
|
||||
let handle: FFT64AvxHandle = FFT64AvxHandle {
|
||||
table_fft: ReimFFTTable::new(n as usize >> 1),
|
||||
table_ifft: ReimIFFTTable::new(n as usize >> 1),
|
||||
};
|
||||
// Leak Box to get a stable NonNull pointer
|
||||
let ptr: NonNull<FFT64AvxHandle> = NonNull::from(Box::leak(Box::new(handle)));
|
||||
unsafe { Module::from_nonnull(ptr, n) }
|
||||
}
|
||||
}
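The handle lifecycle above uses a leak-then-reclaim pattern: `new_impl` leaks a boxed `FFT64AvxHandle` to obtain a stable `NonNull` pointer for `Module::from_nonnull`, and `Backend::destroy` later rebuilds the `Box` with `Box::from_raw` so the allocation is freed exactly once. A minimal standalone sketch of the same pattern, using only a hypothetical `Handle` type and no poulpy items:

use std::ptr::NonNull;

struct Handle {
    tables: Vec<f64>,
}

fn create(n: usize) -> NonNull<Handle> {
    // Box::leak hands back a &'static mut; its address stays valid until reclaimed.
    NonNull::from(Box::leak(Box::new(Handle { tables: vec![0.0; n] })))
}

/// # Safety
/// `h` must come from `create` and must not be used after this call.
unsafe fn destroy(h: NonNull<Handle>) {
    // Re-own the allocation so it is dropped exactly once.
    drop(unsafe { Box::from_raw(h.as_ptr()) });
}

fn handle_lifecycle_demo() {
    let h = create(8);
    unsafe { destroy(h) };
}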
|
||||
|
||||
pub trait FFT64ModuleHandle {
|
||||
fn get_fft_table(&self) -> &ReimFFTTable<f64>;
|
||||
fn get_ifft_table(&self) -> &ReimIFFTTable<f64>;
|
||||
}
|
||||
|
||||
impl FFT64ModuleHandle for Module<FFT64Avx> {
|
||||
fn get_fft_table(&self) -> &ReimFFTTable<f64> {
|
||||
let h: &FFT64AvxHandle = unsafe { &*self.ptr() };
|
||||
&h.table_fft
|
||||
}
|
||||
fn get_ifft_table(&self) -> &ReimIFFTTable<f64> {
|
||||
let h: &FFT64AvxHandle = unsafe { &*self.ptr() };
|
||||
&h.table_ifft
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxAdd for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
unsafe {
|
||||
znx_add_avx(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxAddInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
|
||||
unsafe {
|
||||
znx_add_inplace_avx(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSub for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
unsafe {
|
||||
znx_sub_avx(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSubABInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
|
||||
unsafe {
|
||||
znx_sub_ab_inplace_avx(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSubBAInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
|
||||
unsafe {
|
||||
znx_sub_ba_inplace_avx(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxAutomorphism for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
unsafe {
|
||||
znx_automorphism_avx(p, res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxCopy for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_copy(res: &mut [i64], a: &[i64]) {
|
||||
znx_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNegate for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_negate(res: &mut [i64], src: &[i64]) {
|
||||
unsafe {
|
||||
znx_negate_avx(res, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNegateInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_negate_inplace(res: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_negate_inplace_avx(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxRotate for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
|
||||
znx_rotate::<Self>(p, res, src);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxZero for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_zero(res: &mut [i64]) {
|
||||
znx_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSwitchRing for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
|
||||
unsafe {
|
||||
znx_switch_ring_avx(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStep for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_final_step_avx(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStepInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_final_step_inplace_avx(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStep for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_first_step_avx(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepCarryOnly for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_first_step_carry_only_avx(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_first_step_inplace_avx(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStep for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_middle_step_avx(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_middle_step_carry_only_avx(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
unsafe {
|
||||
znx_normalize_middle_step_inplace_avx(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
|
||||
ReimFFTAvx::reim_dft_execute(table, data);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
|
||||
ReimIFFTAvx::reim_dft_execute(table, data);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimFromZnx for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_from_znx(res: &mut [f64], a: &[i64]) {
|
||||
unsafe {
|
||||
reim_from_znx_i64_bnd50_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimToZnx for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
unsafe {
|
||||
reim_to_znx_i64_bnd63_avx2_fma(res, divisor, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimToZnxInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64) {
|
||||
unsafe {
|
||||
reim_to_znx_i64_inplace_bnd63_avx2_fma(res, divisor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAdd for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
unsafe {
|
||||
reim_add_avx2_fma(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAddInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_add_inplace(res: &mut [f64], a: &[f64]) {
|
||||
unsafe {
|
||||
reim_add_inplace_avx2_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSub for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
unsafe {
|
||||
reim_sub_avx2_fma(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubABInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
|
||||
unsafe {
|
||||
reim_sub_ab_inplace_avx2_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubBAInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
|
||||
unsafe {
|
||||
reim_sub_ba_inplace_avx2_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimNegate for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_negate(res: &mut [f64], a: &[f64]) {
|
||||
unsafe {
|
||||
reim_negate_avx2_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimNegateInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_negate_inplace(res: &mut [f64]) {
|
||||
unsafe {
|
||||
reim_negate_inplace_avx2_fma(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimMul for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
unsafe {
|
||||
reim_mul_avx2_fma(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimMulInplace for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_mul_inplace(res: &mut [f64], a: &[f64]) {
|
||||
unsafe {
|
||||
reim_mul_inplace_avx2_fma(res, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAddMul for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
unsafe {
|
||||
reim_addmul_avx2_fma(res, a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimCopy for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_copy(res: &mut [f64], a: &[f64]) {
|
||||
reim_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimZero for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim_zero(res: &mut [f64]) {
|
||||
reim_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Extract1Blk for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
unsafe {
|
||||
reim4_extract_1blk_from_reim_avx(m, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Save1Blk for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
unsafe {
|
||||
reim4_save_1blk_to_reim_avx::<OVERWRITE>(m, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Save2Blks for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
unsafe {
|
||||
reim4_save_2blk_to_reim_avx::<OVERWRITE>(m, blk, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat1ColProd for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
unsafe {
|
||||
reim4_vec_mat1col_product_avx(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat2ColsProd for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
unsafe {
|
||||
reim4_vec_mat2cols_product_avx(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat2Cols2ndColProd for FFT64Avx {
|
||||
#[inline(always)]
|
||||
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
unsafe {
|
||||
reim4_vec_mat2cols_2ndcol_product_avx(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/conversion.rs (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
/// # Correctness
/// Ensured for inputs with absolute value bounded by 2^50-1.
/// # Safety
/// Caller must ensure the CPU supports FMA (e.g., via `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "fma")]
|
||||
pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
let n: usize = res.len();
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd, _mm256_set1_epi64x,
|
||||
_mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd,
|
||||
};
|
||||
|
||||
let expo: f64 = (1i64 << 52) as f64;
|
||||
let add_cst: i64 = 1i64 << 51;
|
||||
let sub_cst: f64 = (3i64 << 51) as f64;
|
||||
|
||||
let expo_256: __m256d = _mm256_set1_pd(expo);
|
||||
let add_cst_256: __m256i = _mm256_set1_epi64x(add_cst);
|
||||
let sub_cst_256: __m256d = _mm256_set1_pd(sub_cst);
|
||||
|
||||
let mut res_ptr: *mut f64 = res.as_mut_ptr();
|
||||
let mut a_ptr: *const __m256i = a.as_ptr() as *const __m256i;
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
let mut ai64_256: __m256i = _mm256_loadu_si256(a_ptr);
|
||||
|
||||
ai64_256 = _mm256_add_epi64(ai64_256, add_cst_256);
|
||||
|
||||
let mut af64_256: __m256d = _mm256_castsi256_pd(ai64_256);
|
||||
af64_256 = _mm256_or_pd(af64_256, expo_256);
|
||||
af64_256 = _mm256_sub_pd(af64_256, sub_cst_256);
|
||||
|
||||
_mm256_storeu_pd(res_ptr, af64_256);
|
||||
|
||||
res_ptr = res_ptr.add(4);
|
||||
a_ptr = a_ptr.add(1);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref;
|
||||
reim_from_znx_i64_ref(&mut res[span << 2..], &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
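The constants above implement the usual integer-to-double bit trick rather than a hardware conversion: for |a| < 2^50, `a + 2^51` fits in the 52 mantissa bits, OR-ing in the bit pattern of the f64 value 2^52 yields the encoding of `3*2^51 + a`, and subtracting `3*2^51` recovers `a` exactly. A scalar model of what each AVX lane computes (editorial sketch, not part of the commit):

fn from_znx_i64_bnd50_scalar(a: i64) -> f64 {
    debug_assert!(a.unsigned_abs() < 1 << 50);
    // Bias so the value is non-negative and fits in the 52 mantissa bits.
    let biased: u64 = (a + (1i64 << 51)) as u64;
    // OR in the exponent bits of 2^52: the result encodes the f64 value 2^52 + 2^51 + a.
    let bits: u64 = biased | 2f64.powi(52).to_bits();
    // Remove the 3 * 2^51 offset to recover a exactly.
    f64::from_bits(bits) - (3i64 << 51) as f64
}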
|
||||
|
||||
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1.
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[allow(dead_code)]
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
let sign_mask: u64 = 0x8000000000000000u64;
|
||||
let expo_mask: u64 = 0x7FF0000000000000u64;
|
||||
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
|
||||
let mantissa_msb: u64 = 0x0010000000000000u64;
|
||||
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
|
||||
let offset: f64 = divisor / 2.;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
|
||||
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
|
||||
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
|
||||
};
|
||||
|
||||
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
|
||||
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
|
||||
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
|
||||
let offset_256 = _mm256_set1_pd(offset);
|
||||
let divi_bits_256 = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
|
||||
|
||||
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut a_ptr: *const f64 = a.as_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
use std::arch::x86_64::_mm256_storeu_si256;
|
||||
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
|
||||
// a += sign(a) * m/2
|
||||
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
|
||||
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
|
||||
|
||||
// sign: either 0 or -1
|
||||
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
|
||||
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
|
||||
|
||||
// compute the exponents
|
||||
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
|
||||
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
|
||||
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
|
||||
a0lsh = _mm256_srli_epi64(a0lsh, 52);
|
||||
a0rsh = _mm256_srli_epi64(a0rsh, 52);
|
||||
|
||||
// compute the new mantissa
|
||||
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
|
||||
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
|
||||
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
|
||||
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
|
||||
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
|
||||
|
||||
// negate if the sign was negative
|
||||
out = _mm256_xor_si256(out, sign_mask);
|
||||
out = _mm256_sub_epi64(out, sign_mask);
|
||||
|
||||
// stores
|
||||
_mm256_storeu_si256(res_ptr, out);
|
||||
|
||||
res_ptr = res_ptr.add(1);
|
||||
a_ptr = a_ptr.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
|
||||
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
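The exponent arithmetic above divides by `divisor` and rounds to nearest without a float-to-int instruction: after adding `±divisor/2`, the shift amounts `a0lsh`/`a0rsh` are the difference between the exponent field of `a` and that of `divisor * 2^52`, so shifting the implicit-bit-restored mantissa yields floor(|a|/divisor), and the final xor/sub pair reapplies the sign. A scalar model of one lane, under the assumption that `divisor` is a power of two (editorial sketch, not part of the commit):

fn to_znx_i64_bnd63_scalar(a: f64, divisor: f64) -> i64 {
    // Round half away from zero: bias |a| by divisor/2 before the truncating division.
    let a = a + if a.is_sign_negative() { -divisor / 2.0 } else { divisor / 2.0 };
    let bits = a.to_bits();
    let sign = if bits >> 63 != 0 { -1i64 } else { 1i64 };
    // Exponent difference against divisor * 2^52 gives the mantissa shift amount.
    let exp_a = ((bits >> 52) & 0x7FF) as i64;
    let exp_d = ((divisor * (1i64 << 52) as f64).to_bits() >> 52) as i64;
    // Restore the implicit leading 1 of the 53-bit mantissa.
    let mant = (bits & ((1u64 << 52) - 1)) | (1u64 << 52);
    let shift = exp_a - exp_d;
    let mag = if shift >= 0 {
        (mant as i64) << shift
    } else {
        (mant >> (-shift) as u32) as i64
    };
    sign * mag
}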
|
||||
|
||||
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^63-1.
/// # Safety
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
|
||||
let sign_mask: u64 = 0x8000000000000000u64;
|
||||
let expo_mask: u64 = 0x7FF0000000000000u64;
|
||||
let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
|
||||
let mantissa_msb: u64 = 0x0010000000000000u64;
|
||||
let divi_bits: f64 = divisor * (1i64 << 52) as f64;
|
||||
let offset: f64 = divisor / 2.;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_castsi256_pd,
|
||||
_mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x, _mm256_set1_pd, _mm256_sllv_epi64,
|
||||
_mm256_srli_epi64, _mm256_srlv_epi64, _mm256_sub_epi64, _mm256_xor_si256,
|
||||
};
|
||||
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;
|
||||
|
||||
let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
|
||||
let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
|
||||
let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
|
||||
let offset_256: __m256d = _mm256_set1_pd(offset);
|
||||
let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));
|
||||
|
||||
let mut res_ptr_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut res_ptr_1xf64: *mut f64 = res.as_mut_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
use std::arch::x86_64::_mm256_storeu_si256;
|
||||
let mut a: __m256d = _mm256_loadu_pd(res_ptr_1xf64);
|
||||
|
||||
// a += sign(a) * m/2
|
||||
let asign: __m256d = _mm256_and_pd(a, sign_mask_256);
|
||||
a = _mm256_add_pd(a, _mm256_or_pd(asign, offset_256));
|
||||
|
||||
// sign: either 0 or -1
|
||||
let mut sign_mask: __m256i = _mm256_castpd_si256(asign);
|
||||
sign_mask = _mm256_sub_epi64(_mm256_set1_epi64x(0), _mm256_srli_epi64(sign_mask, 63));
|
||||
|
||||
// compute the exponents
|
||||
let a0exp: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), expo_mask_256);
|
||||
let mut a0lsh: __m256i = _mm256_sub_epi64(a0exp, divi_bits_256);
|
||||
let mut a0rsh: __m256i = _mm256_sub_epi64(divi_bits_256, a0exp);
|
||||
a0lsh = _mm256_srli_epi64(a0lsh, 52);
|
||||
a0rsh = _mm256_srli_epi64(a0rsh, 52);
|
||||
|
||||
// compute the new mantissa
|
||||
let mut a0pos: __m256i = _mm256_and_si256(_mm256_castpd_si256(a), mantissa_mask_256);
|
||||
a0pos = _mm256_or_si256(a0pos, mantissa_msb_256);
|
||||
a0lsh = _mm256_sllv_epi64(a0pos, a0lsh);
|
||||
a0rsh = _mm256_srlv_epi64(a0pos, a0rsh);
|
||||
let mut out: __m256i = _mm256_or_si256(a0lsh, a0rsh);
|
||||
|
||||
// negate if the sign was negative
|
||||
out = _mm256_xor_si256(out, sign_mask);
|
||||
out = _mm256_sub_epi64(out, sign_mask);
|
||||
|
||||
// stores
|
||||
_mm256_storeu_si256(res_ptr_4xi64, out);
|
||||
|
||||
res_ptr_4xi64 = res_ptr_4xi64.add(1);
|
||||
res_ptr_1xf64 = res_ptr_1xf64.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
reim_to_znx_i64_inplace_ref(&mut res[span << 2..], divisor)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// # Correctness
/// Only ensured for inputs with absolute value bounded by 2^50-1.
/// # Safety
/// Caller must ensure the CPU supports FMA (e.g., via `is_x86_feature_detected!("fma")`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "fma")]
|
||||
#[allow(dead_code)]
|
||||
pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len())
|
||||
}
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{
|
||||
__m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd, _mm256_set1_epi64x,
|
||||
_mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64,
|
||||
};
|
||||
|
||||
let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64;
|
||||
let sub_cst: i64 = 1i64 << 51;
|
||||
let add_cst: f64 = divisor * (3i64 << 51) as f64;
|
||||
|
||||
let sub_cst_4: __m256i = _mm256_set1_epi64x(sub_cst);
|
||||
let add_cst_4: std::arch::x86_64::__m256d = _mm256_set1_pd(add_cst);
|
||||
let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
|
||||
|
||||
let mut res_ptr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut a_ptr = a.as_ptr();
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
for _ in 0..span {
|
||||
// read the next value
|
||||
let mut a: __m256d = _mm256_loadu_pd(a_ptr);
|
||||
a = _mm256_add_pd(a, add_cst_4);
|
||||
let mut ai: __m256i = _mm256_castpd_si256(a);
|
||||
ai = _mm256_and_si256(ai, mantissa_mask_4);
|
||||
ai = _mm256_sub_epi64(ai, sub_cst_4);
|
||||
// store the next value
|
||||
_mm256_storeu_si256(res_ptr, ai);
|
||||
|
||||
res_ptr = res_ptr.add(1);
|
||||
a_ptr = a_ptr.add(4);
|
||||
}
|
||||
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
|
||||
reim_to_znx_i64_ref(&mut res[span << 2..], divisor, &a[span << 2..])
|
||||
}
|
||||
}
|
||||
}
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/fft16_avx2_fma.s (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
# ----------------------------------------------------------------------
|
||||
# This kernel is a direct port of the FFT16 routine from spqlios-arithmetic
|
||||
# (https://github.com/tfhe/spqlios-arithmetic)
|
||||
# ----------------------------------------------------------------------
|
||||
#
|
||||
|
||||
.text
|
||||
.globl fft16_avx2_fma_asm
|
||||
.hidden fft16_avx2_fma_asm
|
||||
.p2align 4, 0x90
|
||||
.type fft16_avx2_fma_asm,@function
|
||||
fft16_avx2_fma_asm:
|
||||
.att_syntax prefix
|
||||
|
||||
# SysV args: %rdi = re*, %rsi = im*, %rdx = omg*
|
||||
# stage 0: load inputs
|
||||
vmovupd (%rdi),%ymm0 # ra0
|
||||
vmovupd 0x20(%rdi),%ymm1 # ra4
|
||||
vmovupd 0x40(%rdi),%ymm2 # ra8
|
||||
vmovupd 0x60(%rdi),%ymm3 # ra12
|
||||
vmovupd (%rsi),%ymm4 # ia0
|
||||
vmovupd 0x20(%rsi),%ymm5 # ia4
|
||||
vmovupd 0x40(%rsi),%ymm6 # ia8
|
||||
vmovupd 0x60(%rsi),%ymm7 # ia12
|
||||
|
||||
# stage 1
|
||||
vmovupd (%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
vmulpd %ymm6,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm13,%ymm9
|
||||
vmulpd %ymm2,%ymm13,%ymm10
|
||||
vmulpd %ymm3,%ymm13,%ymm11
|
||||
vfmsub231pd %ymm2,%ymm12,%ymm8
|
||||
vfmsub231pd %ymm3,%ymm12,%ymm9
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10
|
||||
vfmadd231pd %ymm7,%ymm12,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm2
|
||||
vsubpd %ymm9,%ymm1,%ymm3
|
||||
vsubpd %ymm10,%ymm4,%ymm6
|
||||
vsubpd %ymm11,%ymm5,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vaddpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm4,%ymm4
|
||||
vaddpd %ymm11,%ymm5,%ymm5
|
||||
|
||||
# stage 2
|
||||
vmovupd 16(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
vmulpd %ymm5,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm12,%ymm9
|
||||
vmulpd %ymm1,%ymm13,%ymm10
|
||||
vmulpd %ymm3,%ymm12,%ymm11
|
||||
vfmsub231pd %ymm1,%ymm12,%ymm8
|
||||
vfmadd231pd %ymm3,%ymm13,%ymm9
|
||||
vfmadd231pd %ymm5,%ymm12,%ymm10
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm1
|
||||
vaddpd %ymm9,%ymm2,%ymm3
|
||||
vsubpd %ymm10,%ymm4,%ymm5
|
||||
vaddpd %ymm11,%ymm6,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vsubpd %ymm9,%ymm2,%ymm2
|
||||
vaddpd %ymm10,%ymm4,%ymm4
|
||||
vsubpd %ymm11,%ymm6,%ymm6
|
||||
|
||||
# stage 3
|
||||
vmovupd 0x20(%rdx),%ymm12
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # omar
|
||||
|
||||
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8
|
||||
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9
|
||||
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10
|
||||
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11
|
||||
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1
|
||||
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2
|
||||
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3
|
||||
|
||||
vmulpd %ymm10,%ymm13,%ymm4
|
||||
vmulpd %ymm11,%ymm12,%ymm5
|
||||
vmulpd %ymm8,%ymm13,%ymm6
|
||||
vmulpd %ymm9,%ymm12,%ymm7
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm4
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm5
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7
|
||||
vsubpd %ymm4,%ymm0,%ymm8
|
||||
vaddpd %ymm5,%ymm1,%ymm9
|
||||
vsubpd %ymm6,%ymm2,%ymm10
|
||||
vaddpd %ymm7,%ymm3,%ymm11
|
||||
vaddpd %ymm4,%ymm0,%ymm0
|
||||
vsubpd %ymm5,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm2,%ymm2
|
||||
vsubpd %ymm7,%ymm3,%ymm3
|
||||
|
||||
# stage 4
|
||||
vmovupd 0x40(%rdx),%ymm12
|
||||
vmovupd 0x60(%rdx),%ymm13
|
||||
|
||||
vunpckhpd %ymm1,%ymm0,%ymm4
|
||||
vunpckhpd %ymm3,%ymm2,%ymm6
|
||||
vunpckhpd %ymm9,%ymm8,%ymm5
|
||||
vunpckhpd %ymm11,%ymm10,%ymm7
|
||||
vunpcklpd %ymm1,%ymm0,%ymm0
|
||||
vunpcklpd %ymm3,%ymm2,%ymm2
|
||||
vunpcklpd %ymm9,%ymm8,%ymm1
|
||||
vunpcklpd %ymm11,%ymm10,%ymm3
|
||||
|
||||
vmulpd %ymm6,%ymm13,%ymm8
|
||||
vmulpd %ymm7,%ymm12,%ymm9
|
||||
vmulpd %ymm4,%ymm13,%ymm10
|
||||
vmulpd %ymm5,%ymm12,%ymm11
|
||||
vfmsub231pd %ymm4,%ymm12,%ymm8
|
||||
vfmadd231pd %ymm5,%ymm13,%ymm9
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11
|
||||
vsubpd %ymm8,%ymm0,%ymm4
|
||||
vaddpd %ymm9,%ymm1,%ymm5
|
||||
vsubpd %ymm10,%ymm2,%ymm6
|
||||
vaddpd %ymm11,%ymm3,%ymm7
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vsubpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm2,%ymm2
|
||||
vsubpd %ymm11,%ymm3,%ymm3
|
||||
|
||||
vunpckhpd %ymm7,%ymm3,%ymm11
|
||||
vunpckhpd %ymm5,%ymm1,%ymm9
|
||||
vunpcklpd %ymm7,%ymm3,%ymm10
|
||||
vunpcklpd %ymm5,%ymm1,%ymm8
|
||||
vunpckhpd %ymm6,%ymm2,%ymm3
|
||||
vunpckhpd %ymm4,%ymm0,%ymm1
|
||||
vunpcklpd %ymm6,%ymm2,%ymm2
|
||||
vunpcklpd %ymm4,%ymm0,%ymm0
|
||||
|
||||
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
|
||||
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
|
||||
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
|
||||
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
|
||||
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
|
||||
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
|
||||
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
|
||||
|
||||
# stores
|
||||
vmovupd %ymm0,(%rdi) # ra0
|
||||
vmovupd %ymm1,0x20(%rdi) # ra4
|
||||
vmovupd %ymm2,0x40(%rdi) # ra8
|
||||
vmovupd %ymm3,0x60(%rdi) # ra12
|
||||
vmovupd %ymm4,(%rsi) # ia0
|
||||
vmovupd %ymm5,0x20(%rsi) # ia4
|
||||
vmovupd %ymm6,0x40(%rsi) # ia8
|
||||
vmovupd %ymm7,0x60(%rsi) # ia12
|
||||
vzeroupper
|
||||
ret
|
||||
|
||||
.size fft16_avx2_fma_asm, .-fft16_avx2_fma_asm
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
poulpy-backend/src/cpu_fft64_avx/reim/fft_avx2_fma.rs (new file, 278 lines)
@@ -0,0 +1,278 @@
|
||||
use std::arch::x86_64::{
|
||||
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
|
||||
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub(crate) fn fft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
if m < 16 {
|
||||
use poulpy_hal::reference::fft64::reim::fft_ref;
|
||||
|
||||
fft_ref(m, omg, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
fft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
} else if m <= 2048 {
|
||||
fft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
fft_rec_16_avx2_fma(m, re, im, omg, 0);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "sysv64" {
|
||||
unsafe fn fft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
|
||||
unsafe {
|
||||
fft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return fft_bfs_16_avx2_fma(m, re, im, omg, pos);
|
||||
};
|
||||
|
||||
let h: usize = m >> 1;
|
||||
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos = fft_rec_16_avx2_fma(h, re, im, omg, pos);
|
||||
pos = fft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn fft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
let mut mm: usize = m;
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
let h: usize = mm >> 1;
|
||||
twiddle_fft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
mm = h
|
||||
}
|
||||
|
||||
while mm > 16 {
|
||||
let h: usize = mm >> 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
bitwiddle_fft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
|
||||
pos += 4;
|
||||
}
|
||||
mm = h
|
||||
}
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
fft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(&mut re[off..]),
|
||||
as_arr_mut::<16, f64>(&mut im[off..]),
|
||||
as_arr::<16, f64>(&omg[pos..]),
|
||||
);
|
||||
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn twiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
|
||||
unsafe {
|
||||
let omx: __m128d = _mm_load_pd(omg.as_ptr());
|
||||
let omra: __m256d = _mm256_set_m128d(omx, omx);
|
||||
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
|
||||
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut tra: __m256d = _mm256_mul_pd(omi, ui1);
|
||||
let mut tia: __m256d = _mm256_mul_pd(omi, ur1);
|
||||
|
||||
tra = _mm256_fmsub_pd(omr, ur1, tra);
|
||||
tia = _mm256_fmadd_pd(omr, ui1, tia);
|
||||
ur1 = _mm256_sub_pd(ur0, tra);
|
||||
ui1 = _mm256_sub_pd(ui0, tia);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
}
|
||||
}
|
||||
}
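Each 4-lane iteration above is the standard Cooley-Tukey butterfly applied to the split re/im layout: the second half is multiplied by the twiddle omega = om_re + i*om_im, then added to and subtracted from the first half. A per-element scalar model of the same arithmetic (editorial sketch, not part of the commit):

fn twiddle_butterfly_scalar(x0: (f64, f64), x1: (f64, f64), om: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    // t = omega * x1, with each complex number stored as (re, im)
    let t_re = om.0 * x1.0 - om.1 * x1.1;
    let t_im = om.0 * x1.1 + om.1 * x1.0;
    // Butterfly: outputs (x0 + t, x0 - t)
    ((x0.0 + t_re, x0.1 + t_im), (x0.0 - t_re, x0.1 - t_im))
}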
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn bitwiddle_fft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
|
||||
unsafe {
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
|
||||
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
|
||||
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
|
||||
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
|
||||
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
|
||||
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
|
||||
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
|
||||
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
|
||||
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
|
||||
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ur2: __m256d = _mm256_loadu_pd(r2);
|
||||
let mut ur3: __m256d = _mm256_loadu_pd(r3);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut ui2: __m256d = _mm256_loadu_pd(i2);
|
||||
let mut ui3: __m256d = _mm256_loadu_pd(i3);
|
||||
|
||||
let mut tra: __m256d = _mm256_mul_pd(omai, ui2);
|
||||
let mut trb: __m256d = _mm256_mul_pd(omai, ui3);
|
||||
let mut tia: __m256d = _mm256_mul_pd(omai, ur2);
|
||||
let mut tib: __m256d = _mm256_mul_pd(omai, ur3);
|
||||
tra = _mm256_fmsub_pd(omar, ur2, tra);
|
||||
trb = _mm256_fmsub_pd(omar, ur3, trb);
|
||||
tia = _mm256_fmadd_pd(omar, ui2, tia);
|
||||
tib = _mm256_fmadd_pd(omar, ui3, tib);
|
||||
ur2 = _mm256_sub_pd(ur0, tra);
|
||||
ur3 = _mm256_sub_pd(ur1, trb);
|
||||
ui2 = _mm256_sub_pd(ui0, tia);
|
||||
ui3 = _mm256_sub_pd(ui1, tib);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ur1 = _mm256_add_pd(ur1, trb);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
ui1 = _mm256_add_pd(ui1, tib);
|
||||
|
||||
tra = _mm256_mul_pd(ombi, ui1);
|
||||
trb = _mm256_mul_pd(ombr, ui3);
|
||||
tia = _mm256_mul_pd(ombi, ur1);
|
||||
tib = _mm256_mul_pd(ombr, ur3);
|
||||
tra = _mm256_fmsub_pd(ombr, ur1, tra);
|
||||
trb = _mm256_fmadd_pd(ombi, ur3, trb);
|
||||
tia = _mm256_fmadd_pd(ombr, ui1, tia);
|
||||
tib = _mm256_fmsub_pd(ombi, ui3, tib);
|
||||
ur1 = _mm256_sub_pd(ur0, tra);
|
||||
ur3 = _mm256_add_pd(ur2, trb);
|
||||
ui1 = _mm256_sub_pd(ui0, tia);
|
||||
ui3 = _mm256_add_pd(ui2, tib);
|
||||
ur0 = _mm256_add_pd(ur0, tra);
|
||||
ur2 = _mm256_sub_pd(ur2, trb);
|
||||
ui0 = _mm256_add_pd(ui0, tia);
|
||||
ui2 = _mm256_sub_pd(ui2, tib);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(r2, ur2);
|
||||
_mm256_storeu_pd(r3, ur3);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
_mm256_storeu_pd(i2, ui2);
|
||||
_mm256_storeu_pd(i3, ui3);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
r2 = r2.add(4);
|
||||
r3 = r3.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
i2 = i2.add(4);
|
||||
i3 = i3.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_avx2_fma() {
|
||||
use super::*;
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn internal(log_m: usize) {
|
||||
use poulpy_hal::reference::fft64::reim::ReimFFTRef;
|
||||
|
||||
let m = 1 << log_m;
|
||||
|
||||
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
|
||||
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
|
||||
for log_m in 0..16 {
|
||||
unsafe { internal(log_m) }
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2/fma");
|
||||
}
|
||||
}
|
||||
350
poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs
Normal file
350
poulpy-backend/src/cpu_fft64_avx/reim/fft_vec_avx2_fma.rs
Normal file
@@ -0,0 +1,350 @@
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_add_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
let mut bb: *const f64 = b.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let b_256: __m256d = _mm256_loadu_pd(bb);
|
||||
_mm256_storeu_pd(rr, _mm256_add_pd(a_256, b_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
bb = bb.add(4);
|
||||
}
|
||||
}
|
||||
}
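// Illustrative sketch (not part of the ported code): one way a caller can uphold the
// safety contract stated above. The dispatch helper and the scalar fallback are
// hypothetical; the crate itself routes these kernels through its backend traits.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[allow(dead_code)]
pub fn reim_add_dispatch(res: &mut [f64], a: &[f64], b: &[f64]) {
    if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
        // SAFETY: the required CPU features were verified at runtime just above.
        unsafe { reim_add_avx2_fma(res, a, b) }
    } else {
        // Portable element-wise fallback with the same semantics.
        res.iter_mut()
            .zip(a.iter().zip(b.iter()))
            .for_each(|(r, (x, y))| *r = x + y);
    }
}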
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_add_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_add_pd(r_256, a_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
let mut bb: *const f64 = b.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let b_256: __m256d = _mm256_loadu_pd(bb);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, b_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
bb = bb.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_ab_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(r_256, a_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_sub_ba_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_sub_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_sub_pd(a_256, r_256));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_negate_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::_mm256_set1_pd;
|
||||
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let mut aa: *const f64 = a.as_ptr();
|
||||
|
||||
let neg0: __m256d = _mm256_set1_pd(-0.0);
|
||||
|
||||
for _ in 0..span {
|
||||
let a_256: __m256d = _mm256_loadu_pd(aa);
|
||||
_mm256_storeu_pd(rr, _mm256_xor_pd(a_256, neg0));
|
||||
rr = rr.add(4);
|
||||
aa = aa.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_negate_inplace_avx2_fma(res: &mut [f64]) {
|
||||
use std::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd, _mm256_xor_pd};
|
||||
|
||||
let span: usize = res.len() >> 2;
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::_mm256_set1_pd;
|
||||
|
||||
let mut rr: *mut f64 = res.as_mut_ptr();
|
||||
let neg0: __m256d = _mm256_set1_pd(-0.0);
|
||||
|
||||
for _ in 0..span {
|
||||
let r_256: __m256d = _mm256_loadu_pd(rr);
|
||||
_mm256_storeu_pd(rr, _mm256_xor_pd(r_256, neg0));
|
||||
rr = rr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_addmul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
let mut br_ptr: *const f64 = br.as_ptr();
|
||||
let mut bi_ptr: *const f64 = bi.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let mut rr: __m256d = _mm256_loadu_pd(rr_ptr);
|
||||
let mut ri: __m256d = _mm256_loadu_pd(ri_ptr);
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(br_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
|
||||
|
||||
rr = _mm256_fmsub_pd(ai, bi, rr);
|
||||
rr = _mm256_fmsub_pd(ar, br, rr);
|
||||
ri = _mm256_fmadd_pd(ar, bi, ri);
|
||||
ri = _mm256_fmadd_pd(ai, br, ri);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
br_ptr = br_ptr.add(4);
|
||||
bi_ptr = bi_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_mul_avx2_fma(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
assert_eq!(b.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
let (br, bi) = b.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
let mut br_ptr: *const f64 = br.as_ptr();
|
||||
let mut bi_ptr: *const f64 = bi.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(br_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(bi_ptr);
|
||||
|
||||
let t1: __m256d = _mm256_mul_pd(ai, bi);
|
||||
let t2: __m256d = _mm256_mul_pd(ar, bi);
|
||||
|
||||
let rr: __m256d = _mm256_fmsub_pd(ar, br, t1);
|
||||
let ri: __m256d = _mm256_fmadd_pd(ai, br, t2);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
br_ptr = br_ptr.add(4);
|
||||
bi_ptr = bi_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim_mul_inplace_avx2_fma(res: &mut [f64], a: &[f64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), res.len());
|
||||
}
|
||||
|
||||
let m: usize = res.len() >> 1;
|
||||
|
||||
let (rr, ri) = res.split_at_mut(m);
|
||||
let (ar, ai) = a.split_at(m);
|
||||
|
||||
unsafe {
|
||||
let mut rr_ptr: *mut f64 = rr.as_mut_ptr();
|
||||
let mut ri_ptr: *mut f64 = ri.as_mut_ptr();
|
||||
let mut ar_ptr: *const f64 = ar.as_ptr();
|
||||
let mut ai_ptr: *const f64 = ai.as_ptr();
|
||||
|
||||
use std::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_storeu_pd};
|
||||
|
||||
for _ in 0..(m >> 2) {
|
||||
let ar: __m256d = _mm256_loadu_pd(ar_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(ai_ptr);
|
||||
let br: __m256d = _mm256_loadu_pd(rr_ptr);
|
||||
let bi: __m256d = _mm256_loadu_pd(ri_ptr);
|
||||
|
||||
let t1: __m256d = _mm256_mul_pd(ai, bi);
|
||||
let t2: __m256d = _mm256_mul_pd(ar, bi);
|
||||
|
||||
let rr = _mm256_fmsub_pd(ar, br, t1);
|
||||
let ri = _mm256_fmadd_pd(ai, br, t2);
|
||||
|
||||
_mm256_storeu_pd(rr_ptr, rr);
|
||||
_mm256_storeu_pd(ri_ptr, ri);
|
||||
|
||||
rr_ptr = rr_ptr.add(4);
|
||||
ri_ptr = ri_ptr.add(4);
|
||||
ar_ptr = ar_ptr.add(4);
|
||||
ai_ptr = ai_ptr.add(4);
|
||||
}
|
||||
}
|
||||
}
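// Illustrative sketch (not part of the ported code): a scalar reference for the split
// layout used throughout this file, where a slice of length 2*m stores [re(m) | im(m)].
// Handy as a cross-check for reim_mul_avx2_fma on small inputs.
#[allow(dead_code)]
pub fn reim_mul_scalar_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    let m = res.len() / 2;
    for i in 0..m {
        let (ar, ai) = (a[i], a[m + i]);
        let (br, bi) = (b[i], b[m + i]);
        // (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
        res[i] = ar * br - ai * bi;
        res[m + i] = ar * bi + ai * br;
    }
}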
|
||||
181
poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s
Normal file
181
poulpy-backend/src/cpu_fft64_avx/reim/ifft16_avx2_fma.s
Normal file
@@ -0,0 +1,181 @@
|
||||
# ----------------------------------------------------------------------
|
||||
# This kernel is a direct port of the IFFT16 routine from spqlios-arithmetic
|
||||
# (https://github.com/tfhe/spqlios-arithmetic)
|
||||
# ----------------------------------------------------------------------
|
||||
#
|
||||
|
||||
.text
|
||||
.globl ifft16_avx2_fma_asm
|
||||
.hidden ifft16_avx2_fma_asm
|
||||
.p2align 4, 0x90
|
||||
.type ifft16_avx2_fma_asm,@function
|
||||
ifft16_avx2_fma_asm:
|
||||
.att_syntax prefix
|
||||
|
||||
vmovupd (%rdi),%ymm0 # ra0
|
||||
vmovupd 0x20(%rdi),%ymm1 # ra4
|
||||
vmovupd 0x40(%rdi),%ymm2 # ra8
|
||||
vmovupd 0x60(%rdi),%ymm3 # ra12
|
||||
vmovupd (%rsi),%ymm4 # ia0
|
||||
vmovupd 0x20(%rsi),%ymm5 # ia4
|
||||
vmovupd 0x40(%rsi),%ymm6 # ia8
|
||||
vmovupd 0x60(%rsi),%ymm7 # ia12
|
||||
|
||||
1:
|
||||
vmovupd 0x00(%rdx),%ymm12
|
||||
vmovupd 0x20(%rdx),%ymm13
|
||||
|
||||
vperm2f128 $0x31,%ymm2,%ymm0,%ymm8 # ymm8 contains re to mul (tw)
|
||||
vperm2f128 $0x31,%ymm3,%ymm1,%ymm9 # ymm9 contains re to mul (itw)
|
||||
vperm2f128 $0x31,%ymm6,%ymm4,%ymm10 # ymm10 contains im to mul (tw)
|
||||
vperm2f128 $0x31,%ymm7,%ymm5,%ymm11 # ymm11 contains im to mul (itw)
|
||||
vperm2f128 $0x20,%ymm2,%ymm0,%ymm0 # ymm0 contains re to add (tw)
|
||||
vperm2f128 $0x20,%ymm3,%ymm1,%ymm1 # ymm1 contains re to add (itw)
|
||||
vperm2f128 $0x20,%ymm6,%ymm4,%ymm2 # ymm2 contains im to add (tw)
|
||||
vperm2f128 $0x20,%ymm7,%ymm5,%ymm3 # ymm3 contains im to add (itw)
|
||||
|
||||
vunpckhpd %ymm1,%ymm0,%ymm4 # (0,1) -> (0,4)
|
||||
vunpckhpd %ymm3,%ymm2,%ymm6 # (2,3) -> (2,6)
|
||||
vunpckhpd %ymm9,%ymm8,%ymm5 # (8,9) -> (1,5)
|
||||
vunpckhpd %ymm11,%ymm10,%ymm7 # (10,11) -> (3,7)
|
||||
vunpcklpd %ymm1,%ymm0,%ymm0
|
||||
vunpcklpd %ymm3,%ymm2,%ymm2
|
||||
vunpcklpd %ymm9,%ymm8,%ymm1
|
||||
vunpcklpd %ymm11,%ymm10,%ymm3
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm4) and Im:(ymm2,ymm6) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm1,ymm5) and Im:(ymm3,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm4,%ymm0,%ymm8 # retw
|
||||
vsubpd %ymm5,%ymm1,%ymm9 # reitw
|
||||
vsubpd %ymm6,%ymm2,%ymm10 # imtw
|
||||
vsubpd %ymm7,%ymm3,%ymm11 # imitw
|
||||
vaddpd %ymm4,%ymm0,%ymm0
|
||||
vaddpd %ymm5,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm2,%ymm2
|
||||
vaddpd %ymm7,%ymm3,%ymm3
|
||||
# multiply 8,9,10,11 by 12,13, result to: 4,5,6,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm10,%ymm13,%ymm4 # imtw.omai (tw)
|
||||
vmulpd %ymm11,%ymm12,%ymm5 # imitw.omar (itw)
|
||||
vmulpd %ymm8,%ymm13,%ymm6 # retw.omai (tw)
|
||||
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm4 # rprod0 (tw)
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm5 # rprod4 (itw)
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0 (tw)
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
|
||||
|
||||
vunpckhpd %ymm7,%ymm3,%ymm11 # (0,4) -> (0,1)
|
||||
vunpckhpd %ymm5,%ymm1,%ymm9 # (2,6) -> (2,3)
|
||||
vunpcklpd %ymm7,%ymm3,%ymm10
|
||||
vunpcklpd %ymm5,%ymm1,%ymm8
|
||||
vunpckhpd %ymm6,%ymm2,%ymm3 # (1,5) -> (8,9)
|
||||
vunpckhpd %ymm4,%ymm0,%ymm1 # (3,7) -> (10,11)
|
||||
vunpcklpd %ymm6,%ymm2,%ymm2
|
||||
vunpcklpd %ymm4,%ymm0,%ymm0
|
||||
|
||||
2:
|
||||
vmovupd 0x40(%rdx),%ymm12
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omaiii'i'
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omarrr'r'
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm8) and Im:(ymm2,ymm10) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm1,ymm9) and Im:(ymm3,ymm11) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm8,%ymm0,%ymm4 # retw
|
||||
vsubpd %ymm9,%ymm1,%ymm5 # reitw
|
||||
vsubpd %ymm10,%ymm2,%ymm6 # imtw
|
||||
vsubpd %ymm11,%ymm3,%ymm7 # imitw
|
||||
vaddpd %ymm8,%ymm0,%ymm0
|
||||
vaddpd %ymm9,%ymm1,%ymm1
|
||||
vaddpd %ymm10,%ymm2,%ymm2
|
||||
vaddpd %ymm11,%ymm3,%ymm3
|
||||
# multiply 4,5,6,7 by 12,13, result to 8,9,10,11
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm6,%ymm13,%ymm8 # imtw.omai (tw)
|
||||
vmulpd %ymm7,%ymm12,%ymm9 # imitw.omar (itw)
|
||||
vmulpd %ymm4,%ymm13,%ymm10 # retw.omai (tw)
|
||||
vmulpd %ymm5,%ymm12,%ymm11 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm4,%ymm12,%ymm8 # rprod0 (tw)
|
||||
vfmadd231pd %ymm5,%ymm13,%ymm9 # rprod4 (itw)
|
||||
vfmadd231pd %ymm6,%ymm12,%ymm10 # iprod0 (tw)
|
||||
vfmsub231pd %ymm7,%ymm13,%ymm11 # iprod4 (itw)
|
||||
|
||||
vperm2f128 $0x31,%ymm10,%ymm2,%ymm6
|
||||
vperm2f128 $0x31,%ymm11,%ymm3,%ymm7
|
||||
vperm2f128 $0x20,%ymm10,%ymm2,%ymm4
|
||||
vperm2f128 $0x20,%ymm11,%ymm3,%ymm5
|
||||
vperm2f128 $0x31,%ymm8,%ymm0,%ymm2
|
||||
vperm2f128 $0x31,%ymm9,%ymm1,%ymm3
|
||||
vperm2f128 $0x20,%ymm8,%ymm0,%ymm0
|
||||
vperm2f128 $0x20,%ymm9,%ymm1,%ymm1
|
||||
|
||||
3:
|
||||
vmovupd 0x60(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm1) and Im:(ymm4,ymm5) with omega=(ymm12,ymm13)
|
||||
# invcitwiddle Re:(ymm2,ymm3) and Im:(ymm6,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm1,%ymm0,%ymm8 # retw
|
||||
vsubpd %ymm3,%ymm2,%ymm9 # reitw
|
||||
vsubpd %ymm5,%ymm4,%ymm10 # imtw
|
||||
vsubpd %ymm7,%ymm6,%ymm11 # imitw
|
||||
vaddpd %ymm1,%ymm0,%ymm0
|
||||
vaddpd %ymm3,%ymm2,%ymm2
|
||||
vaddpd %ymm5,%ymm4,%ymm4
|
||||
vaddpd %ymm7,%ymm6,%ymm6
|
||||
# multiply 8,9,10,11 by 12,13, result to 1,3,5,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
# invtwiddles use reom=ymm13, imom=-ymm12
|
||||
vmulpd %ymm10,%ymm13,%ymm1 # imtw.omai (tw)
|
||||
vmulpd %ymm11,%ymm12,%ymm3 # imitw.omar (itw)
|
||||
vmulpd %ymm8,%ymm13,%ymm5 # retw.omai (tw)
|
||||
vmulpd %ymm9,%ymm12,%ymm7 # reitw.omar (itw)
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm1 # rprod0 (tw)
|
||||
vfmadd231pd %ymm9,%ymm13,%ymm3 # rprod4 (itw)
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm5 # iprod0 (tw)
|
||||
vfmsub231pd %ymm11,%ymm13,%ymm7 # iprod4 (itw)
|
||||
|
||||
4:
|
||||
vmovupd 0x70(%rdx),%xmm12
|
||||
vinsertf128 $1, %xmm12, %ymm12, %ymm12 # omriri
|
||||
vshufpd $15, %ymm12, %ymm12, %ymm13 # ymm13: omai
|
||||
vshufpd $0, %ymm12, %ymm12, %ymm12 # ymm12: omar
|
||||
|
||||
# invctwiddle Re:(ymm0,ymm2) and Im:(ymm4,ymm6) with omega=(ymm12,ymm13)
|
||||
# invctwiddle Re:(ymm1,ymm3) and Im:(ymm5,ymm7) with omega=(ymm12,ymm13)
|
||||
vsubpd %ymm2,%ymm0,%ymm8 # retw1
|
||||
vsubpd %ymm3,%ymm1,%ymm9 # retw2
|
||||
vsubpd %ymm6,%ymm4,%ymm10 # imtw1
|
||||
vsubpd %ymm7,%ymm5,%ymm11 # imtw2
|
||||
vaddpd %ymm2,%ymm0,%ymm0
|
||||
vaddpd %ymm3,%ymm1,%ymm1
|
||||
vaddpd %ymm6,%ymm4,%ymm4
|
||||
vaddpd %ymm7,%ymm5,%ymm5
|
||||
# multiply 8,9,10,11 by 12,13, result to 2,3,6,7
|
||||
# twiddles use reom=ymm12, imom=ymm13
|
||||
vmulpd %ymm10,%ymm13,%ymm2 # imtw1.omai
|
||||
vmulpd %ymm11,%ymm13,%ymm3 # imtw2.omai
|
||||
vmulpd %ymm8,%ymm13,%ymm6 # retw1.omai
|
||||
vmulpd %ymm9,%ymm13,%ymm7 # retw2.omai
|
||||
vfmsub231pd %ymm8,%ymm12,%ymm2 # rprod0
|
||||
vfmsub231pd %ymm9,%ymm12,%ymm3 # rprod4
|
||||
vfmadd231pd %ymm10,%ymm12,%ymm6 # iprod0
|
||||
vfmadd231pd %ymm11,%ymm12,%ymm7 # iprod4
|
||||
|
||||
5:
|
||||
vmovupd %ymm0,(%rdi) # ra0
|
||||
vmovupd %ymm1,0x20(%rdi) # ra4
|
||||
vmovupd %ymm2,0x40(%rdi) # ra8
|
||||
vmovupd %ymm3,0x60(%rdi) # ra12
|
||||
vmovupd %ymm4,(%rsi) # ia0
|
||||
vmovupd %ymm5,0x20(%rsi) # ia4
|
||||
vmovupd %ymm6,0x40(%rsi) # ia8
|
||||
vmovupd %ymm7,0x60(%rsi) # ia12
|
||||
vzeroupper
|
||||
ret
|
||||
|
||||
.size ifft16_avx2_fma_asm, .-ifft16_avx2_fma_asm
|
||||
.section .note.GNU-stack,"",@progbits
|
||||
271
poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs
Normal file
271
poulpy-backend/src/cpu_fft64_avx/reim/ifft_avx2_fma.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
use std::arch::x86_64::{
|
||||
__m128d, __m256d, _mm_load_pd, _mm256_add_pd, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_mul_pd,
|
||||
_mm256_permute2f128_pd, _mm256_set_m128d, _mm256_storeu_pd, _mm256_sub_pd, _mm256_unpackhi_pd, _mm256_unpacklo_pd,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{as_arr, as_arr_mut};
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub(crate) fn ifft_avx2_fma(m: usize, omg: &[f64], data: &mut [f64]) {
|
||||
if m < 16 {
|
||||
use poulpy_hal::reference::fft64::reim::ifft_ref;
|
||||
ifft_ref(m, omg, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert!(data.len() == 2 * m);
|
||||
let (re, im) = data.split_at_mut(m);
|
||||
|
||||
if m == 16 {
|
||||
ifft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(re),
|
||||
as_arr_mut::<16, f64>(im),
|
||||
as_arr::<16, f64>(omg),
|
||||
)
|
||||
} else if m <= 2048 {
|
||||
ifft_bfs_16_avx2_fma(m, re, im, omg, 0);
|
||||
} else {
|
||||
ifft_rec_16_avx2_fma(m, re, im, omg, 0);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "sysv64" {
|
||||
unsafe fn ifft16_avx2_fma_asm(re: *mut f64, im: *mut f64, omg: *const f64);
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft16_avx2_fma(re: &mut [f64; 16], im: &mut [f64; 16], omg: &[f64; 16]) {
|
||||
unsafe {
|
||||
ifft16_avx2_fma_asm(re.as_mut_ptr(), im.as_mut_ptr(), omg.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft_rec_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
if m <= 2048 {
|
||||
return ifft_bfs_16_avx2_fma(m, re, im, omg, pos);
|
||||
};
|
||||
let h: usize = m >> 1;
|
||||
pos = ifft_rec_16_avx2_fma(h, re, im, omg, pos);
|
||||
pos = ifft_rec_16_avx2_fma(h, &mut re[h..], &mut im[h..], omg, pos);
|
||||
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn ifft_bfs_16_avx2_fma(m: usize, re: &mut [f64], im: &mut [f64], omg: &[f64], mut pos: usize) -> usize {
|
||||
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
|
||||
|
||||
for off in (0..m).step_by(16) {
|
||||
ifft16_avx2_fma(
|
||||
as_arr_mut::<16, f64>(&mut re[off..]),
|
||||
as_arr_mut::<16, f64>(&mut im[off..]),
|
||||
as_arr::<16, f64>(&omg[pos..]),
|
||||
);
|
||||
pos += 16;
|
||||
}
|
||||
|
||||
let mut h: usize = 16;
|
||||
let m_half: usize = m >> 1;
|
||||
|
||||
while h < m_half {
|
||||
let mm: usize = h << 2;
|
||||
for off in (0..m).step_by(mm) {
|
||||
inv_bitwiddle_ifft_avx2_fma(
|
||||
h,
|
||||
&mut re[off..],
|
||||
&mut im[off..],
|
||||
as_arr::<4, f64>(&omg[pos..]),
|
||||
);
|
||||
pos += 4;
|
||||
}
|
||||
h = mm;
|
||||
}
|
||||
|
||||
if !log_m.is_multiple_of(2) {
|
||||
inv_twiddle_ifft_avx2_fma(h, re, im, *as_arr::<2, f64>(&omg[pos..]));
|
||||
pos += 2;
|
||||
}
|
||||
|
||||
pos
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn inv_twiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
|
||||
unsafe {
|
||||
let omx: __m128d = _mm_load_pd(omg.as_ptr());
|
||||
let omra: __m256d = _mm256_set_m128d(omx, omx);
|
||||
let omi: __m256d = _mm256_unpackhi_pd(omra, omra);
|
||||
let omr: __m256d = _mm256_unpacklo_pd(omra, omra);
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let tra = _mm256_sub_pd(ur0, ur1);
|
||||
let tia = _mm256_sub_pd(ui0, ui1);
|
||||
ur0 = _mm256_add_pd(ur0, ur1);
|
||||
ui0 = _mm256_add_pd(ui0, ui1);
|
||||
ur1 = _mm256_mul_pd(omi, tia);
|
||||
ui1 = _mm256_mul_pd(omi, tra);
|
||||
ur1 = _mm256_fmsub_pd(omr, tra, ur1);
|
||||
ui1 = _mm256_fmadd_pd(omr, tia, ui1);
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
}
|
||||
}
|
||||
}
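// Illustrative sketch (not part of the ported code): the scalar butterfly that the
// AVX2 loop above vectorises, with omg = [om_re, om_im].
#[allow(dead_code)]
fn inv_twiddle_ifft_scalar_ref(h: usize, re: &mut [f64], im: &mut [f64], omg: [f64; 2]) {
    let (om_re, om_im) = (omg[0], omg[1]);
    for j in 0..h {
        let (ar, ai) = (re[j], im[j]);
        let (br, bi) = (re[h + j], im[h + j]);
        // a' = a + b
        re[j] = ar + br;
        im[j] = ai + bi;
        // b' = (a - b) * omega
        let (dr, di) = (ar - br, ai - bi);
        re[h + j] = om_re * dr - om_im * di;
        im[h + j] = om_re * di + om_im * dr;
    }
}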
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn inv_bitwiddle_ifft_avx2_fma(h: usize, re: &mut [f64], im: &mut [f64], omg: &[f64; 4]) {
|
||||
unsafe {
|
||||
let mut r0: *mut f64 = re.as_mut_ptr();
|
||||
let mut r1: *mut f64 = re.as_mut_ptr().add(h);
|
||||
let mut r2: *mut f64 = re.as_mut_ptr().add(2 * h);
|
||||
let mut r3: *mut f64 = re.as_mut_ptr().add(3 * h);
|
||||
let mut i0: *mut f64 = im.as_mut_ptr();
|
||||
let mut i1: *mut f64 = im.as_mut_ptr().add(h);
|
||||
let mut i2: *mut f64 = im.as_mut_ptr().add(2 * h);
|
||||
let mut i3: *mut f64 = im.as_mut_ptr().add(3 * h);
|
||||
let om0: __m256d = _mm256_loadu_pd(omg.as_ptr());
|
||||
let omb: __m256d = _mm256_permute2f128_pd(om0, om0, 0x11);
|
||||
let oma: __m256d = _mm256_permute2f128_pd(om0, om0, 0x00);
|
||||
let omai: __m256d = _mm256_unpackhi_pd(oma, oma);
|
||||
let omar: __m256d = _mm256_unpacklo_pd(oma, oma);
|
||||
let ombi: __m256d = _mm256_unpackhi_pd(omb, omb);
|
||||
let ombr: __m256d = _mm256_unpacklo_pd(omb, omb);
|
||||
for _ in (0..h).step_by(4) {
|
||||
let mut ur0: __m256d = _mm256_loadu_pd(r0);
|
||||
let mut ur1: __m256d = _mm256_loadu_pd(r1);
|
||||
let mut ur2: __m256d = _mm256_loadu_pd(r2);
|
||||
let mut ur3: __m256d = _mm256_loadu_pd(r3);
|
||||
let mut ui0: __m256d = _mm256_loadu_pd(i0);
|
||||
let mut ui1: __m256d = _mm256_loadu_pd(i1);
|
||||
let mut ui2: __m256d = _mm256_loadu_pd(i2);
|
||||
let mut ui3: __m256d = _mm256_loadu_pd(i3);
|
||||
|
||||
let mut tra: __m256d = _mm256_sub_pd(ur0, ur1);
|
||||
let mut trb: __m256d = _mm256_sub_pd(ur2, ur3);
|
||||
let mut tia: __m256d = _mm256_sub_pd(ui0, ui1);
|
||||
let mut tib: __m256d = _mm256_sub_pd(ui2, ui3);
|
||||
ur0 = _mm256_add_pd(ur0, ur1);
|
||||
ur2 = _mm256_add_pd(ur2, ur3);
|
||||
ui0 = _mm256_add_pd(ui0, ui1);
|
||||
ui2 = _mm256_add_pd(ui2, ui3);
|
||||
ur1 = _mm256_mul_pd(omai, tia);
|
||||
ur3 = _mm256_mul_pd(omar, tib);
|
||||
ui1 = _mm256_mul_pd(omai, tra);
|
||||
ui3 = _mm256_mul_pd(omar, trb);
|
||||
ur1 = _mm256_fmsub_pd(omar, tra, ur1);
|
||||
ur3 = _mm256_fmadd_pd(omai, trb, ur3);
|
||||
ui1 = _mm256_fmadd_pd(omar, tia, ui1);
|
||||
ui3 = _mm256_fmsub_pd(omai, tib, ui3);
|
||||
|
||||
tra = _mm256_sub_pd(ur0, ur2);
|
||||
trb = _mm256_sub_pd(ur1, ur3);
|
||||
tia = _mm256_sub_pd(ui0, ui2);
|
||||
tib = _mm256_sub_pd(ui1, ui3);
|
||||
ur0 = _mm256_add_pd(ur0, ur2);
|
||||
ur1 = _mm256_add_pd(ur1, ur3);
|
||||
ui0 = _mm256_add_pd(ui0, ui2);
|
||||
ui1 = _mm256_add_pd(ui1, ui3);
|
||||
ur2 = _mm256_mul_pd(ombi, tia);
|
||||
ur3 = _mm256_mul_pd(ombi, tib);
|
||||
ui2 = _mm256_mul_pd(ombi, tra);
|
||||
ui3 = _mm256_mul_pd(ombi, trb);
|
||||
ur2 = _mm256_fmsub_pd(ombr, tra, ur2);
|
||||
ur3 = _mm256_fmsub_pd(ombr, trb, ur3);
|
||||
ui2 = _mm256_fmadd_pd(ombr, tia, ui2);
|
||||
ui3 = _mm256_fmadd_pd(ombr, tib, ui3);
|
||||
|
||||
_mm256_storeu_pd(r0, ur0);
|
||||
_mm256_storeu_pd(r1, ur1);
|
||||
_mm256_storeu_pd(r2, ur2);
|
||||
_mm256_storeu_pd(r3, ur3);
|
||||
_mm256_storeu_pd(i0, ui0);
|
||||
_mm256_storeu_pd(i1, ui1);
|
||||
_mm256_storeu_pd(i2, ui2);
|
||||
_mm256_storeu_pd(i3, ui3);
|
||||
|
||||
r0 = r0.add(4);
|
||||
r1 = r1.add(4);
|
||||
r2 = r2.add(4);
|
||||
r3 = r3.add(4);
|
||||
i0 = i0.add(4);
|
||||
i1 = i1.add(4);
|
||||
i2 = i2.add(4);
|
||||
i3 = i3.add(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_avx2_fma() {
|
||||
use super::*;
|
||||
|
||||
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
fn internal(log_m: usize) {
|
||||
use poulpy_hal::reference::fft64::reim::ReimIFFTRef;
|
||||
|
||||
let m: usize = 1 << log_m;
|
||||
|
||||
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
|
||||
|
||||
let mut values_0: Vec<f64> = vec![0f64; m << 1];
|
||||
let scale: f64 = 1.0f64 / m as f64;
|
||||
values_0
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
|
||||
|
||||
let mut values_1: Vec<f64> = vec![0f64; m << 1];
|
||||
values_1
|
||||
.iter_mut()
|
||||
.zip(values_0.iter())
|
||||
.for_each(|(y, x)| *y = *x);
|
||||
|
||||
ReimIFFTAvx::reim_dft_execute(&table, &mut values_0);
|
||||
ReimIFFTRef::reim_dft_execute(&table, &mut values_1);
|
||||
|
||||
let max_diff: f64 = 1.0 / ((1u64 << (53 - log_m - 1)) as f64);
|
||||
|
||||
for i in 0..m * 2 {
|
||||
let diff: f64 = (values_0[i] - values_1[i]).abs();
|
||||
assert!(
|
||||
diff <= max_diff,
|
||||
"{} -> {}-{} = {}",
|
||||
i,
|
||||
values_0[i],
|
||||
values_1[i],
|
||||
diff
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
|
||||
for log_m in 0..16 {
|
||||
unsafe { internal(log_m) }
|
||||
}
|
||||
} else {
|
||||
eprintln!("skipping: CPU lacks avx2/fma");
|
||||
}
|
||||
}
|
||||
72
poulpy-backend/src/cpu_fft64_avx/reim/mod.rs
Normal file
72
poulpy-backend/src/cpu_fft64_avx/reim/mod.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
// ----------------------------------------------------------------------
|
||||
// DISCLAIMER
|
||||
//
|
||||
// This module contains code that has been directly ported from the
|
||||
// spqlios-arithmetic library
|
||||
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
|
||||
// under the Apache License, Version 2.0.
|
||||
//
|
||||
// The porting process from C to Rust was done with minimal changes
|
||||
// in order to preserve the semantics and performance characteristics
|
||||
// of the original implementation.
|
||||
//
|
||||
// Both Poulpy and spqlios-arithmetic are distributed under the terms
|
||||
// of the Apache License, Version 2.0. See the LICENSE file for details.
|
||||
//
|
||||
// ----------------------------------------------------------------------
|
||||
|
||||
#![allow(bad_asm_style)]
|
||||
|
||||
mod conversion;
|
||||
mod fft_avx2_fma;
|
||||
mod fft_vec_avx2_fma;
|
||||
mod ifft_avx2_fma;
|
||||
|
||||
use std::arch::global_asm;
|
||||
|
||||
pub(crate) use conversion::*;
|
||||
pub(crate) use fft_vec_avx2_fma::*;
|
||||
|
||||
use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTTable, ReimIFFTTable};
|
||||
use rand_distr::num_traits::{Float, FloatConst};
|
||||
|
||||
use crate::cpu_fft64_avx::reim::{fft_avx2_fma::fft_avx2_fma, ifft_avx2_fma::ifft_avx2_fma};
|
||||
|
||||
global_asm!(
|
||||
include_str!("fft16_avx2_fma.s"),
|
||||
include_str!("ifft16_avx2_fma.s")
|
||||
);
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
|
||||
debug_assert!(x.len() >= SIZE);
|
||||
unsafe { &*(x.as_ptr() as *const [R; SIZE]) }
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn as_arr_mut<const SIZE: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; SIZE] {
|
||||
debug_assert!(x.len() >= SIZE);
|
||||
unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) }
|
||||
}
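// Note (illustrative, not part of the ported code): the unchecked casts above rely on the
// debug_assert!-ed length precondition. A fully checked alternative already provided by
// the standard library is `try_into`, at the cost of a runtime length check:
#[allow(dead_code)]
fn as_arr_checked<const SIZE: usize>(x: &[f64]) -> &[f64; SIZE] {
    // Panics (rather than causing UB) if x holds fewer than SIZE elements.
    x[..SIZE].try_into().expect("slice shorter than SIZE")
}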
|
||||
|
||||
pub struct ReimFFTAvx;
|
||||
|
||||
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTAvx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
|
||||
unsafe {
|
||||
fft_avx2_fma(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReimIFFTAvx;
|
||||
|
||||
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTAvx {
|
||||
#[inline(always)]
|
||||
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
|
||||
unsafe {
|
||||
ifft_avx2_fma(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
}
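// Illustrative sketch (not part of the ported code): driving the two executors above from
// user code, relying on the imports at the top of this module. The caller is expected to
// have verified AVX2/FMA support before using this backend, and the forward/inverse
// normalisation follows whatever convention the reference tables use, so this is only a
// usage example, not a numerical identity.
#[allow(dead_code)]
fn fft_then_ifft(m: usize, data: &mut [f64]) {
    assert_eq!(data.len(), 2 * m);
    let fft_table: ReimFFTTable<f64> = ReimFFTTable::new(m);
    let ifft_table: ReimIFFTTable<f64> = ReimIFFTTable::new(m);
    ReimFFTAvx::reim_dft_execute(&fft_table, data);
    ReimIFFTAvx::reim_dft_execute(&ifft_table, data);
}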
|
||||
264
poulpy-backend/src/cpu_fft64_avx/reim4/arithmetic_avx.rs
Normal file
264
poulpy-backend/src/cpu_fft64_avx/reim4/arithmetic_avx.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx")]
|
||||
pub fn reim4_extract_1blk_from_reim_avx(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
|
||||
unsafe {
|
||||
let mut src_ptr: *const __m256d = src.as_ptr().add(blk << 2) as *const __m256d; // src + 4*blk
|
||||
let mut dst_ptr: *mut __m256d = dst.as_mut_ptr() as *mut __m256d;
|
||||
|
||||
let step: usize = m >> 2;
|
||||
|
||||
// Each iteration copies 4 doubles; advance src by m doubles each row
|
||||
for _ in 0..2 * rows {
|
||||
let v: __m256d = _mm256_loadu_pd(src_ptr as *const f64);
|
||||
_mm256_storeu_pd(dst_ptr as *mut f64, v);
|
||||
dst_ptr = dst_ptr.add(1);
|
||||
src_ptr = src_ptr.add(step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim4_save_1blk_to_reim_avx<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
unsafe {
|
||||
let off: usize = blk * 4;
|
||||
let src_ptr: *const f64 = src.as_ptr();
|
||||
|
||||
let s0: __m256d = _mm256_loadu_pd(src_ptr);
|
||||
let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4));
|
||||
|
||||
let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off);
|
||||
let d1_ptr: *mut f64 = d0_ptr.add(m);
|
||||
|
||||
if OVERWRITE {
|
||||
_mm256_storeu_pd(d0_ptr, s0);
|
||||
_mm256_storeu_pd(d1_ptr, s1);
|
||||
} else {
|
||||
let d0: __m256d = _mm256_loadu_pd(d0_ptr);
|
||||
let d1: __m256d = _mm256_loadu_pd(d1_ptr);
|
||||
_mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0));
|
||||
_mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2,fma")]
|
||||
pub fn reim4_save_2blk_to_reim_avx<const OVERWRITE: bool>(
|
||||
m: usize, //
|
||||
blk: usize, // block index
|
||||
dst: &mut [f64], //
|
||||
src: &[f64], // 16 doubles [re1(4), im1(4), re2(4), im2(4)]
|
||||
) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_add_pd, _mm256_loadu_pd, _mm256_storeu_pd};
|
||||
unsafe {
|
||||
let off: usize = blk * 4;
|
||||
let src_ptr: *const f64 = src.as_ptr();
|
||||
|
||||
let d0_ptr: *mut f64 = dst.as_mut_ptr().add(off);
|
||||
let d1_ptr: *mut f64 = d0_ptr.add(m);
|
||||
let d2_ptr: *mut f64 = d1_ptr.add(m);
|
||||
let d3_ptr: *mut f64 = d2_ptr.add(m);
|
||||
|
||||
let s0: __m256d = _mm256_loadu_pd(src_ptr);
|
||||
let s1: __m256d = _mm256_loadu_pd(src_ptr.add(4));
|
||||
let s2: __m256d = _mm256_loadu_pd(src_ptr.add(8));
|
||||
let s3: __m256d = _mm256_loadu_pd(src_ptr.add(12));
|
||||
|
||||
if OVERWRITE {
|
||||
_mm256_storeu_pd(d0_ptr, s0);
|
||||
_mm256_storeu_pd(d1_ptr, s1);
|
||||
_mm256_storeu_pd(d2_ptr, s2);
|
||||
_mm256_storeu_pd(d3_ptr, s3);
|
||||
} else {
|
||||
let d0: __m256d = _mm256_loadu_pd(d0_ptr);
|
||||
let d1: __m256d = _mm256_loadu_pd(d1_ptr);
|
||||
let d2: __m256d = _mm256_loadu_pd(d2_ptr);
|
||||
let d3: __m256d = _mm256_loadu_pd(d3_ptr);
|
||||
_mm256_storeu_pd(d0_ptr, _mm256_add_pd(d0, s0));
|
||||
_mm256_storeu_pd(d1_ptr, _mm256_add_pd(d1, s1));
|
||||
_mm256_storeu_pd(d2_ptr, _mm256_add_pd(d2, s2));
|
||||
_mm256_storeu_pd(d3_ptr, _mm256_add_pd(d3, s3));
|
||||
}
|
||||
}
|
||||
}
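// Illustrative sketch (not part of the ported code): the scalar equivalent of the store
// pattern above, for the 16-double source layout [re1(4), im1(4), re2(4), im2(4)].
#[allow(dead_code)]
fn reim4_save_2blk_scalar_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let off = blk * 4;
    for row in 0..4 {
        for k in 0..4 {
            let d = &mut dst[row * m + off + k];
            let s = src[row * 4 + k];
            if OVERWRITE { *d = s } else { *d += s }
        }
    }
}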
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub fn reim4_vec_mat1col_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(dst.len() >= 8, "dst must have at least 8 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
|
||||
}
|
||||
|
||||
unsafe {
|
||||
use std::arch::x86_64::{_mm256_add_pd, _mm256_sub_pd};
|
||||
|
||||
let mut re1: __m256d = _mm256_setzero_pd();
|
||||
let mut im1: __m256d = _mm256_setzero_pd();
|
||||
let mut re2: __m256d = _mm256_setzero_pd();
|
||||
let mut im2: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let mut u_ptr: *const f64 = u.as_ptr();
|
||||
let mut v_ptr: *const f64 = v.as_ptr();
|
||||
|
||||
for _ in 0..nrows {
|
||||
let ur: __m256d = _mm256_loadu_pd(u_ptr);
|
||||
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
|
||||
let vr: __m256d = _mm256_loadu_pd(v_ptr);
|
||||
let vi: __m256d = _mm256_loadu_pd(v_ptr.add(4));
|
||||
|
||||
// re1 = re1 + ur*vr;
|
||||
re1 = _mm256_fmadd_pd(ur, vr, re1);
|
||||
// im1 = im1 + ur*vi;
|
||||
im1 = _mm256_fmadd_pd(ur, vi, im1);
|
||||
// re2 = re2 + ui*vi;
|
||||
re2 = _mm256_fmadd_pd(ui, vi, re2);
|
||||
// im2 = im2 + ui*vr;
|
||||
im2 = _mm256_fmadd_pd(ui, vr, im2);
|
||||
|
||||
u_ptr = u_ptr.add(8);
|
||||
v_ptr = v_ptr.add(8);
|
||||
}
|
||||
|
||||
// re1 - re2
|
||||
_mm256_storeu_pd(dst.as_mut_ptr(), _mm256_sub_pd(re1, re2));
|
||||
|
||||
// im1 + im2
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), _mm256_add_pd(im1, im2));
|
||||
}
|
||||
}
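// Illustrative sketch (not part of the ported code): what the accumulation above computes,
// lane by lane. Each row of u and v holds 8 doubles [re(4), im(4)]; the result is the
// lane-wise complex dot product sum_r u[r] * v[r], stored in the same [re(4), im(4)] layout.
#[allow(dead_code)]
fn reim4_mat1col_scalar_ref(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
    dst[..8].fill(0.0);
    for r in 0..nrows {
        for lane in 0..4 {
            let (ur, ui) = (u[8 * r + lane], u[8 * r + 4 + lane]);
            let (vr, vi) = (v[8 * r + lane], v[8 * r + 4 + lane]);
            dst[lane] += ur * vr - ui * vi; // real part
            dst[4 + lane] += ur * vi + ui * vr; // imaginary part
        }
    }
}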
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub fn reim4_vec_mat2cols_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
dst.len() >= 8,
|
||||
"dst must be at least 8 doubles but is {}",
|
||||
dst.len()
|
||||
);
|
||||
assert!(
|
||||
u.len() >= nrows * 8,
|
||||
"u must be at least nrows={} * 8 doubles but is {}",
|
||||
nrows,
|
||||
u.len()
|
||||
);
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows={} * 16 doubles but is {}",
|
||||
nrows,
|
||||
v.len()
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut re1: __m256d = _mm256_setzero_pd();
|
||||
let mut im1: __m256d = _mm256_setzero_pd();
|
||||
let mut re2: __m256d = _mm256_setzero_pd();
|
||||
let mut im2: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let mut u_ptr: *const f64 = u.as_ptr();
|
||||
let mut v_ptr: *const f64 = v.as_ptr();
|
||||
|
||||
for _ in 0..nrows {
|
||||
let ur: __m256d = _mm256_loadu_pd(u_ptr);
|
||||
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
|
||||
|
||||
let ar: __m256d = _mm256_loadu_pd(v_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
|
||||
let br: __m256d = _mm256_loadu_pd(v_ptr.add(8));
|
||||
let bi: __m256d = _mm256_loadu_pd(v_ptr.add(12));
|
||||
|
||||
// re1 = ui*ai - re1; re2 = ui*bi - re2;
|
||||
re1 = _mm256_fmsub_pd(ui, ai, re1);
|
||||
re2 = _mm256_fmsub_pd(ui, bi, re2);
|
||||
// im1 = im1 + ur*ai; im2 = im2 + ur*bi;
|
||||
im1 = _mm256_fmadd_pd(ur, ai, im1);
|
||||
im2 = _mm256_fmadd_pd(ur, bi, im2);
|
||||
// re1 = ur*ar - re1; re2 = ur*br - re2;
|
||||
re1 = _mm256_fmsub_pd(ur, ar, re1);
|
||||
re2 = _mm256_fmsub_pd(ur, br, re2);
|
||||
// im1 = im1 + ui*ar; im2 = im2 + ui*br;
|
||||
im1 = _mm256_fmadd_pd(ui, ar, im1);
|
||||
im2 = _mm256_fmadd_pd(ui, br, im2);
|
||||
|
||||
u_ptr = u_ptr.add(8);
|
||||
v_ptr = v_ptr.add(16);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(dst.as_mut_ptr(), re1);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(8), re2);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(12), im2);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 and FMA (e.g., via `is_x86_feature_detected!`).
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub fn reim4_vec_mat2cols_2ndcol_product_avx(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
use core::arch::x86_64::{__m256d, _mm256_fmadd_pd, _mm256_fmsub_pd, _mm256_loadu_pd, _mm256_setzero_pd, _mm256_storeu_pd};
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
|
||||
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
|
||||
assert!(
|
||||
v.len() >= nrows * 16,
|
||||
"v must be at least nrows * 16 doubles"
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut re1: __m256d = _mm256_setzero_pd();
|
||||
let mut im1: __m256d = _mm256_setzero_pd();
|
||||
|
||||
let mut u_ptr: *const f64 = u.as_ptr();
|
||||
let mut v_ptr: *const f64 = v.as_ptr().add(8); // Offset to 2nd column
|
||||
|
||||
for _ in 0..nrows {
|
||||
let ur: __m256d = _mm256_loadu_pd(u_ptr);
|
||||
let ui: __m256d = _mm256_loadu_pd(u_ptr.add(4));
|
||||
|
||||
let ar: __m256d = _mm256_loadu_pd(v_ptr);
|
||||
let ai: __m256d = _mm256_loadu_pd(v_ptr.add(4));
|
||||
|
||||
// re1 = ui*ai - re1;
|
||||
re1 = _mm256_fmsub_pd(ui, ai, re1);
|
||||
// im1 = im1 + ur*ai;
|
||||
im1 = _mm256_fmadd_pd(ur, ai, im1);
|
||||
// re1 = ur*ar - re1;
|
||||
re1 = _mm256_fmsub_pd(ur, ar, re1);
|
||||
// im1 = im1 + ui*ar;
|
||||
im1 = _mm256_fmadd_pd(ui, ar, im1);
|
||||
|
||||
u_ptr = u_ptr.add(8);
|
||||
v_ptr = v_ptr.add(16);
|
||||
}
|
||||
|
||||
_mm256_storeu_pd(dst.as_mut_ptr(), re1);
|
||||
_mm256_storeu_pd(dst.as_mut_ptr().add(4), im1);
|
||||
}
|
||||
}
|
||||
3
poulpy-backend/src/cpu_fft64_avx/reim4/mod.rs
Normal file
3
poulpy-backend/src/cpu_fft64_avx/reim4/mod.rs
Normal file
@@ -0,0 +1,3 @@
mod arithmetic_avx;

pub(crate) use arithmetic_avx::*;
261
poulpy-backend/src/cpu_fft64_avx/scratch.rs
Normal file
261
poulpy-backend/src/cpu_fft64_avx/scratch.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use poulpy_hal::{
|
||||
DEFAULTALIGN, alloc_aligned,
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
|
||||
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
|
||||
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::FFT64Avx;
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64Avx {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
|
||||
Scratch::from_bytes(&mut scratch.data)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64Avx {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64Avx {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64Avx
|
||||
where
|
||||
B: SvpPPolAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
SvpPPol::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64Avx
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64Avx
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64Avx
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B> + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B> + TakeVecZnxImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64Avx
|
||||
where
|
||||
B: VmpPMatAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64Avx
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
(take_slice, rem_slice)
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len, aligned_len,
|
||||
);
|
||||
}
|
||||
}
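// Illustrative sketch (not part of the ported code): exercising the alignment bookkeeping
// above, using the DEFAULTALIGN import at the top of this file. The buffer size and take
// length are arbitrary; the test only checks that the aligned prefix is skipped and that
// the remainder accounts for it.
#[test]
fn take_slice_aligned_skips_to_alignment() {
    let mut data = vec![0u8; 256];
    let aligned_offset = data.as_ptr().align_offset(DEFAULTALIGN);
    let (taken, rest) = take_slice_aligned(&mut data, 32);
    assert_eq!(taken.len(), 32);
    assert_eq!(rest.len(), 256 - aligned_offset - 32);
}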
|
||||
66
poulpy-backend/src/cpu_fft64_avx/svp.rs
Normal file
66
poulpy-backend/src/cpu_fft64_avx/svp.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
|
||||
oep::{
|
||||
SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl,
|
||||
SvpPrepareImpl,
|
||||
},
|
||||
reference::fft64::svp::{svp_apply_dft_to_dft, svp_apply_dft_to_dft_inplace, svp_prepare},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle};
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64Avx {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64Avx {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64Avx {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64Avx {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
svp_prepare(module.get_fft_table(), res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyDftToDftImpl<Self> for FFT64Avx {
|
||||
fn svp_apply_dft_to_dft_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
svp_apply_dft_to_dft(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Avx {
|
||||
fn svp_apply_dft_to_dft_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
{
|
||||
svp_apply_dft_to_dft_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
117
poulpy-backend/src/cpu_fft64_avx/tests.rs
Normal file
117
poulpy-backend/src/cpu_fft64_avx/tests.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
|
||||
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
|
||||
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
|
||||
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
|
||||
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
|
||||
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
|
||||
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
|
||||
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
|
||||
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
|
||||
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
|
||||
test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace,
|
||||
test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh,
|
||||
test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace,
|
||||
test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate,
|
||||
test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace,
|
||||
test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate,
|
||||
test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace,
|
||||
test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism,
|
||||
test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace,
|
||||
test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one,
|
||||
test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace,
|
||||
test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize,
|
||||
test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace,
|
||||
test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring,
|
||||
test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring,
|
||||
test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod svp,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
|
||||
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx_big,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
|
||||
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
|
||||
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
|
||||
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
|
||||
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
|
||||
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
|
||||
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
|
||||
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
|
||||
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
|
||||
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
|
||||
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
|
||||
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
|
||||
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
|
||||
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
|
||||
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
|
||||
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx_dft,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
|
||||
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
|
||||
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
|
||||
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
|
||||
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
|
||||
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
|
||||
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
|
||||
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vmp,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
|
||||
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,
|
||||
}
|
||||
}
|
||||
|
||||
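// The sampling suite uses a larger ring (size = 1 << 12), presumably so the empirical noise statistics are stable.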
backend_test_suite! {
|
||||
mod sampling,
|
||||
backend = crate::cpu_fft64_avx::FFT64Avx,
|
||||
size = 1 << 12,
|
||||
tests = {
|
||||
test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform,
|
||||
test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal,
|
||||
test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal,
|
||||
test_vec_znx_big_add_normal => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
|
||||
}
|
||||
}
|
||||
538
poulpy-backend/src/cpu_fft64_avx/vec_znx.rs
Normal file
@@ -0,0 +1,538 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
|
||||
},
|
||||
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
|
||||
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
reference::vec_znx::{
|
||||
vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace,
|
||||
vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy,
|
||||
vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes,
|
||||
vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one, vec_znx_mul_xp_minus_one_inplace,
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
|
||||
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
|
||||
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
|
||||
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
|
||||
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::FFT64Avx;
|
||||
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_normalize_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64Avx
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
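// The scratch requirement is returned in bytes; convert it to an i64 element count for take_slice.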
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_add_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
vec_znx_add_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_add_scalar_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_sub_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_sub_scalar_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
vec_znx_sub_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_negate::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_negate_inplace::<R, Self>(res, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_lsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_rotate::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceTmpBytesImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rotate_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxRotateInplaceTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_rotate_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rotate_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_automorphism::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_automorphism_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxAutomorphismInplaceTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_automorphism_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_automorphism_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_automorphism_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_mul_xp_minus_one::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxMulXpMinusOneImpl<Self>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_mul_xp_minus_one_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitRingTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_split_ring_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitRingImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxSplitRingTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_split_ring_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_split_ring::<R, A, Self>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMergeRingsTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_merge_rings_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMergeRingsImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxMergeRingsTmpBytes,
|
||||
{
|
||||
fn vec_znx_merge_rings_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &[A],
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_merge_rings::<R, A, Self>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSwitchRingImpl<Self> for FFT64Avx
|
||||
where
|
||||
Self: VecZnxCopyImpl<Self>,
|
||||
{
|
||||
fn vec_znx_switch_ring_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_switch_ring::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxCopyImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_copy::<R, A, Self>(res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_uniform_ref(basek, res, res_col, source)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
332
poulpy-backend/src/cpu_fft64_avx/vec_znx_big.rs
Normal file
@@ -0,0 +1,332 @@
|
||||
use crate::cpu_fft64_avx::FFT64Avx;
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNormalizeTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
|
||||
ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
reference::{
|
||||
fft64::vec_znx_big::{
|
||||
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
|
||||
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
|
||||
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
|
||||
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
|
||||
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
|
||||
},
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64Avx> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let min_size: usize = res_size.min(a_size);
|
||||
|
||||
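// Copy the limbs present in both `res` and `a`, then zero any remaining limbs of `res`.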
for j in 0..min_size {
|
||||
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Avx {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64Avx {
|
||||
/// Adds `a` to `b` and stores the result on `res`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_add(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Avx {
|
||||
/// Adds `a` to `res` and stores the result on `res`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_add_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Avx {
|
||||
/// Adds `a` to `b` and stores the result on `res`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_add_small(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Avx {
|
||||
/// Adds `a` to `res` and stores the result on `res`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_add_small_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `b` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `b` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_sub_small_a(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_sub_small_a_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `b` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_sub_small_b(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Avx {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_sub_small_b_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_negate(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
{
|
||||
vec_znx_big_negate_inplace(res, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_big_normalize_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Avx
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
fn vec_znx_big_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Avx {
|
||||
/// Applies the automorphism X^i -> X^(i*p) on `a` and stores the result on `res`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_automorphism(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_big_automorphism_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Module<Self>: VecZnxBigAutomorphismInplaceTmpBytes,
|
||||
{
|
||||
/// Applies the automorphism X^i -> X^(i*p) on `res` and stores the result on `res`.
|
||||
fn vec_znx_big_automorphism_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_big_automorphism_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
186
poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use poulpy_hal::{
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxToRef,
|
||||
},
|
||||
oep::{
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||
},
|
||||
reference::fft64::vec_znx_dft::{
|
||||
vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
|
||||
vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
|
||||
vec_znx_idft_apply_tmpa,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle};
|
||||
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDft::<Vec<u8>, Self>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Avx as Backend>::ScalarPrep>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_idft_apply_tmp_bytes_impl(_module: &Module<Self>) -> usize {
|
||||
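// The reference iDFT used by this backend requires no extra scratch space.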
0
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIdftApplyImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_idft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
_scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_idft_apply(module.get_ifft_table(), res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIdftApplyTmpAImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToMut<Self>,
|
||||
{
|
||||
vec_znx_idft_apply_tmpa(module.get_ifft_table(), res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIdftApplyConsumeImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<Self>, res: VecZnxDft<D, FFT64Avx>) -> VecZnxBig<D, FFT64Avx>
|
||||
where
|
||||
VecZnxDft<D, FFT64Avx>: VecZnxDftToMut<Self>,
|
||||
{
|
||||
vec_znx_idft_apply_consume(module.get_ifft_table(), res)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftApplyImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_dft_apply(module.get_fft_table(), step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_add_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_add(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_add_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_sub_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_sub(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
vec_znx_dft_copy(step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Avx {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
{
|
||||
vec_znx_dft_zero(res);
|
||||
}
|
||||
}
|
||||
143
poulpy-backend/src/cpu_fft64_avx/vmp.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VmpPrepareTmpBytes},
|
||||
layouts::{
|
||||
Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
|
||||
VmpPMatToMut, VmpPMatToRef, ZnxInfos,
|
||||
},
|
||||
oep::{
|
||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
reference::fft64::vmp::{
|
||||
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::{FFT64Avx, module::FFT64ModuleHandle};
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<Self> for FFT64Avx {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<Self> for FFT64Avx {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<Self> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
FFT64Avx: VmpApplyDftToDftTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<Self>, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();
|
||||
|
||||
let (tmp, _) = scratch.take_slice(
|
||||
Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
module,
|
||||
res.size(),
|
||||
a.size(),
|
||||
pmat.rows(),
|
||||
pmat.cols_in(),
|
||||
pmat.cols_out(),
|
||||
pmat.size(),
|
||||
) / size_of::<f64>(),
|
||||
);
|
||||
vmp_apply_dft_to_dft(&mut res, &a, &pmat, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Avx
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
FFT64Avx: VmpApplyDftToDftTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
a: &A,
|
||||
pmat: &C,
|
||||
limb_offset: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();
|
||||
|
||||
let (tmp, _) = scratch.take_slice(
|
||||
Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
module,
|
||||
res.size(),
|
||||
a.size(),
|
||||
pmat.rows(),
|
||||
pmat.cols_in(),
|
||||
pmat.cols_out(),
|
||||
pmat.size(),
|
||||
) / size_of::<f64>(),
|
||||
);
|
||||
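// Note: the limb offset is scaled by cols_out() before being passed down; the prepared matrix appears to interleave cols_out output columns per limb.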
vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, limb_offset * pmat.cols_out(), tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<Self>, _rows: usize, _cols_in: usize, _cols_out: usize, _size: usize) -> usize {
|
||||
vmp_prepare_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareImpl<Self> for FFT64Avx {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VmpPMatToMut<Self>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
|
||||
let mut res: VmpPMat<&mut [u8], Self> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
let (tmp, _) =
|
||||
scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()) / size_of::<f64>());
|
||||
vmp_prepare(module.get_fft_table(), &mut res, &a, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
_module: &Module<Self>,
|
||||
_res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
_b_cols_out: usize,
|
||||
_b_size: usize,
|
||||
) -> usize {
|
||||
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
_module: &Module<Self>,
|
||||
_res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
_b_cols_out: usize,
|
||||
_b_size: usize,
|
||||
) -> usize {
|
||||
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
||||
}
|
||||
}
|
||||
73
poulpy-backend/src/cpu_fft64_avx/zn.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use poulpy_hal::{
|
||||
api::TakeSlice,
|
||||
layouts::{Scratch, ZnToMut},
|
||||
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
|
||||
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_avx::FFT64Avx;
|
||||
|
||||
unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Avx {
|
||||
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
|
||||
zn_normalize_tmp_bytes(n)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Avx
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(n);
|
||||
zn_normalize_inplace::<R, FFT64Avx>(n, basek, res, res_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillUniformImpl<Self> for FFT64Avx {
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
zn_fill_uniform(n, basek, res, res_col, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_fill_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_add_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
76
poulpy-backend/src/cpu_fft64_avx/znx_avx/add.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_add_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
assert_eq!(res.len(), b.len());
|
||||
}
|
||||
|
||||
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
|
||||
|
||||
let n: usize = res.len();
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
|
||||
let mut bb: *const __m256i = b.as_ptr() as *const __m256i;
|
||||
|
||||
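// Process four i64 lanes per iteration with unaligned loads/stores; _mm256_add_epi64 wraps on overflow.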
unsafe {
|
||||
for _ in 0..span {
|
||||
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
|
||||
_mm256_storeu_si256(rr, sum);
|
||||
rr = rr.add(1);
|
||||
aa = aa.add(1);
|
||||
bb = bb.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_add_ref;
|
||||
|
||||
znx_add_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
pub fn znx_add_inplace_avx(res: &mut [i64], a: &[i64]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.len(), a.len());
|
||||
}
|
||||
|
||||
use core::arch::x86_64::{__m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_storeu_si256};
|
||||
|
||||
let n: usize = res.len();
|
||||
|
||||
let span: usize = n >> 2;
|
||||
|
||||
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
|
||||
|
||||
unsafe {
|
||||
for _ in 0..span {
|
||||
let sum: __m256i = _mm256_add_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
|
||||
_mm256_storeu_si256(rr, sum);
|
||||
rr = rr.add(1);
|
||||
aa = aa.add(1);
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
if !res.len().is_multiple_of(4) {
|
||||
use poulpy_hal::reference::znx::znx_add_inplace_ref;
|
||||
|
||||
znx_add_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
|
||||
}
|
||||
}
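// Sketch (not part of the original diff): a cross-check of the AVX add paths against the
// scalar reference, in the same style as the test in automorphism.rs. It assumes
// `znx_add_ref` and `znx_add_inplace_ref` keep the signatures used above.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use poulpy_hal::reference::znx::{znx_add_inplace_ref, znx_add_ref};

    #[test]
    fn test_znx_add_avx_matches_ref() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: CPU lacks avx2");
            return;
        }
        // 17 elements exercises both the 4-lane AVX body and the scalar tail.
        let a: Vec<i64> = (0..17).map(|i| i as i64 - 8).collect();
        let b: Vec<i64> = (0..17).map(|i| (3 * i) as i64).collect();

        let mut r0: Vec<i64> = vec![0i64; a.len()];
        let mut r1: Vec<i64> = vec![0i64; a.len()];
        znx_add_ref(&mut r0, &a, &b);
        unsafe { znx_add_avx(&mut r1, &a, &b) };
        assert_eq!(r0, r1);

        let mut s0: Vec<i64> = a.clone();
        let mut s1: Vec<i64> = a.clone();
        znx_add_inplace_ref(&mut s0, &b);
        unsafe { znx_add_inplace_avx(&mut s1, &b) };
        assert_eq!(s0, s1);
    }
}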
|
||||
133
poulpy-backend/src/cpu_fft64_avx/znx_avx/automorphism.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
#[inline]
|
||||
fn inv_mod_pow2(p: usize, bits: u32) -> usize {
|
||||
// Compute p^{-1} mod 2^bits (p must be odd) through Hensel lifting.
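// Each step doubles the number of correct low bits: if p*x == 1 (mod 2^k), then x' = x*(2 - p*x) satisfies p*x' == 1 (mod 2^(2k)).
// Example: inv_mod_pow2(3, 5) = 11, since 3 * 11 = 33 == 1 (mod 32).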
|
||||
debug_assert!(p % 2 == 1);
|
||||
let mut x: usize = 1usize; // inverse mod 2
|
||||
let mut i: u32 = 1;
|
||||
while i < bits {
|
||||
// x <- x * (2 - p*x) mod 2^(2^i) (wrapping arithmetic)
|
||||
x = x.wrapping_mul(2usize.wrapping_sub(p.wrapping_mul(x)));
|
||||
i <<= 1;
|
||||
}
|
||||
x & ((1usize << bits) - 1)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
|
||||
/// all inputs must have the same length and must not alias.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
pub fn znx_automorphism_avx(p: i64, res: &mut [i64], a: &[i64]) {
|
||||
debug_assert_eq!(res.len(), a.len());
|
||||
let n: usize = res.len();
|
||||
if n == 0 {
|
||||
return;
|
||||
}
|
||||
debug_assert!(n.is_power_of_two(), "n must be power of two");
|
||||
debug_assert!(p & 1 == 1, "p must be odd (invertible mod 2n)");
|
||||
|
||||
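// With t = (j * p^{-1}) mod 2n, res[j] = a[t mod n], negated when t >= n (since X^n = -1 in the negacyclic ring).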
if n < 4 {
|
||||
use poulpy_hal::reference::znx::znx_automorphism_ref;
|
||||
|
||||
znx_automorphism_ref(p, res, a);
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let two_n: usize = n << 1;
|
||||
let span: usize = n >> 2;
|
||||
let bits: u32 = (two_n as u64).trailing_zeros();
|
||||
let mask_2n: usize = two_n - 1;
|
||||
let mask_1n: usize = n - 1;
|
||||
|
||||
// p mod 2n (positive)
|
||||
let p_2n: usize = (((p & mask_2n as i64) + two_n as i64) as usize) & mask_2n;
|
||||
|
||||
// p^-1 mod 2n
|
||||
let inv: usize = inv_mod_pow2(p_2n, bits);
|
||||
|
||||
// Broadcast constants
|
||||
let n_minus1_vec: __m256i = _mm256_set1_epi64x((n as i64) - 1);
|
||||
let mask_2n_vec: __m256i = _mm256_set1_epi64x(mask_2n as i64);
|
||||
let mask_1n_vec: __m256i = _mm256_set1_epi64x(mask_1n as i64);
|
||||
|
||||
// Lane offsets [0, inv, 2*inv, 3*inv] (mod 2n)
|
||||
let lane_offsets: __m256i = _mm256_set_epi64x(
|
||||
((inv * 3) & mask_2n) as i64,
|
||||
((inv * 2) & mask_2n) as i64,
|
||||
inv as i64,
|
||||
0i64,
|
||||
);
|
||||
|
||||
// t_base = (j * inv) mod 2n.
|
||||
let mut t_base: usize = 0;
|
||||
let step: usize = (inv << 2) & mask_2n;
|
||||
|
||||
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
|
||||
let aa: *const i64 = a.as_ptr();
|
||||
|
||||
for _ in 0..span {
|
||||
// t_vec = (t_base + [0, inv, 2*inv, 3*inv]) & (2n-1)
|
||||
let t_base_vec: __m256i = _mm256_set1_epi64x(t_base as i64);
|
||||
let t_vec: __m256i = _mm256_and_si256(_mm256_add_epi64(t_base_vec, lane_offsets), mask_2n_vec);
|
||||
|
||||
// idx = t_vec & (n-1)
|
||||
let idx_vec: __m256i = _mm256_and_si256(t_vec, mask_1n_vec);
|
||||
|
||||
// sign = t >= n ? -1 : 0 (mask of all-ones where negate)
|
||||
let sign_mask: __m256i = _mm256_cmpgt_epi64(t_vec, n_minus1_vec);
|
||||
|
||||
// gather a[idx] (scale = 8 bytes per i64)
|
||||
let vals: __m256i = _mm256_i64gather_epi64(aa, idx_vec, 8);
|
||||
|
||||
// Conditional negate: (vals ^ sign_mask) - sign_mask
|
||||
let vals_x: __m256i = _mm256_xor_si256(vals, sign_mask);
|
||||
let out: __m256i = _mm256_sub_epi64(vals_x, sign_mask);
|
||||
|
||||
// store to res[j..j+4]
|
||||
_mm256_storeu_si256(rr, out);
|
||||
|
||||
// advance
|
||||
rr = rr.add(1);
|
||||
t_base = (t_base + step) & mask_2n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(all(test, any(target_arch = "x86_64", target_arch = "x86")))]
|
||||
mod tests {
|
||||
use poulpy_hal::reference::znx::znx_automorphism_ref;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[target_feature(enable = "avx2", enable = "fma")]
|
||||
fn test_znx_automorphism_internal() {
|
||||
let a: [i64; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
let mut r0: Vec<i64> = vec![0i64; a.len()];
|
||||
let mut r1: Vec<i64> = vec![0i64; a.len()];
|
||||
|
||||
znx_automorphism_ref(p, &mut r0, &a);
|
||||
znx_automorphism_avx(p, &mut r1, &a);
|
||||
|
||||
assert_eq!(r0, r1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_znx_automorphism_avx() {
|
||||
if !std::is_x86_feature_detected!("avx2") {
|
||||
eprintln!("skipping: CPU lacks avx2");
|
||||
return;
|
||||
};
|
||||
unsafe {
|
||||
test_znx_automorphism_internal();
|
||||
}
|
||||
}
|
||||
}
|
||||
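The kernel above computes res[j] = sign(t) * a[t mod n] with t = (j * p^-1) mod 2n, negating whenever t >= n. A minimal scalar sketch of that same mapping, for reference only (it assumes the same contract as the AVX kernel and reuses the `inv_mod_pow2` helper defined earlier in this file):

    // Scalar transcription of the index mapping vectorized above; illustration only.
    fn znx_automorphism_scalar_sketch(p: i64, res: &mut [i64], a: &[i64]) {
        let n = res.len();
        let two_n = n << 1;
        let p_2n = (((p & (two_n as i64 - 1)) + two_n as i64) as usize) & (two_n - 1);
        let inv = inv_mod_pow2(p_2n, (two_n as u64).trailing_zeros()); // p^-1 mod 2n
        for (j, r) in res.iter_mut().enumerate() {
            let t = (j * inv) & (two_n - 1);
            let v = a[t & (n - 1)];
            *r = if t >= n { -v } else { v };
        }
    }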
13
poulpy-backend/src/cpu_fft64_avx/znx_avx/mod.rs
Normal file
@@ -0,0 +1,13 @@
mod add;
mod automorphism;
mod neg;
mod normalization;
mod sub;
mod switch_ring;

pub(crate) use add::*;
pub(crate) use automorphism::*;
pub(crate) use neg::*;
pub(crate) use normalization::*;
pub(crate) use sub::*;
pub(crate) use switch_ring::*;
64
poulpy-backend/src/cpu_fft64_avx/znx_avx/neg.rs
Normal file
@@ -0,0 +1,64 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_negate_avx(res: &mut [i64], src: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), src.len())
    }

    let n: usize = res.len();

    use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
    let span: usize = n >> 2;

    unsafe {
        let mut aa: *const __m256i = src.as_ptr() as *const __m256i;
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let zero: __m256i = _mm256_setzero_si256();
        for _ in 0..span {
            let v: __m256i = _mm256_loadu_si256(aa);
            let neg: __m256i = _mm256_sub_epi64(zero, v);
            _mm256_storeu_si256(rr, neg);
            aa = aa.add(1);
            rr = rr.add(1);
        }
    }

    // Scalar tail for lengths that are not a multiple of 4.
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_negate_ref;

        znx_negate_ref(&mut res[span << 2..], &src[span << 2..])
    }
}

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_negate_inplace_avx(res: &mut [i64]) {
    let n: usize = res.len();

    use std::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_setzero_si256, _mm256_storeu_si256, _mm256_sub_epi64};
    let span: usize = n >> 2;

    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let zero: __m256i = _mm256_setzero_si256();
        for _ in 0..span {
            let v: __m256i = _mm256_loadu_si256(rr);
            let neg: __m256i = _mm256_sub_epi64(zero, v);
            _mm256_storeu_si256(rr, neg);
            rr = rr.add(1);
        }
    }

    // Scalar tail for lengths that are not a multiple of 4.
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_negate_inplace_ref;

        znx_negate_inplace_ref(&mut res[span << 2..])
    }
}
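One way a caller might pair the AVX kernels above with a runtime feature check and fall back to the portable reference otherwise, sketched for illustration only (the crate's backend traits do the real wiring):

    // Hypothetical dispatch helper, not part of this commit.
    fn znx_negate_dispatch(res: &mut [i64], src: &[i64]) {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx2") {
                // Safety: the runtime check above satisfies the kernel's safety contract.
                unsafe { znx_negate_avx(res, src) };
                return;
            }
        }
        poulpy_hal::reference::znx::znx_negate_ref(res, src);
    }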
1023
poulpy-backend/src/cpu_fft64_avx/znx_avx/normalization.rs
Normal file
File diff suppressed because it is too large
113
poulpy-backend/src/cpu_fft64_avx/znx_avx/sub.rs
Normal file
@@ -0,0 +1,113 @@
/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_avx(res: &mut [i64], a: &[i64], b: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
        assert_eq!(res.len(), b.len());
    }

    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};

    let n: usize = res.len();

    let span: usize = n >> 2;

    let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
    let mut aa: *const __m256i = a.as_ptr() as *const __m256i;
    let mut bb: *const __m256i = b.as_ptr() as *const __m256i;

    unsafe {
        for _ in 0..span {
            // res = a - b, four lanes at a time
            let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(bb));
            _mm256_storeu_si256(rr, diff);
            rr = rr.add(1);
            aa = aa.add(1);
            bb = bb.add(1);
        }
    }

    // tail
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_sub_ref;

        znx_sub_ref(&mut res[span << 2..], &a[span << 2..], &b[span << 2..]);
    }
}

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_ab_inplace_avx(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};

    let n: usize = res.len();

    let span: usize = n >> 2;

    let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
    let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

    unsafe {
        for _ in 0..span {
            // res = res - a
            let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(rr), _mm256_loadu_si256(aa));
            _mm256_storeu_si256(rr, diff);
            rr = rr.add(1);
            aa = aa.add(1);
        }
    }

    // tail
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_sub_ab_inplace_ref;

        znx_sub_ab_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
    }
}

/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// all inputs must have the same length and must not alias.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub fn znx_sub_ba_inplace_avx(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256, _mm256_sub_epi64};

    let n: usize = res.len();

    let span: usize = n >> 2;

    let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
    let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

    unsafe {
        for _ in 0..span {
            // res = a - res
            let diff: __m256i = _mm256_sub_epi64(_mm256_loadu_si256(aa), _mm256_loadu_si256(rr));
            _mm256_storeu_si256(rr, diff);
            rr = rr.add(1);
            aa = aa.add(1);
        }
    }

    // tail
    if !res.len().is_multiple_of(4) {
        use poulpy_hal::reference::znx::znx_sub_ba_inplace_ref;

        znx_sub_ba_inplace_ref(&mut res[span << 2..], &a[span << 2..]);
    }
}
87
poulpy-backend/src/cpu_fft64_avx/znx_avx/switch_ring.rs
Normal file
@@ -0,0 +1,87 @@
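/// # Safety
/// Caller must ensure the CPU supports AVX2 (e.g., via `is_x86_feature_detected!("avx2")`);
/// `a.len()` must be a power of two, the larger of the two lengths must be a multiple of the
/// smaller one, and both are assumed to be processed in full 4-lane blocks (no scalar tail here).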
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_switch_ring_avx(res: &mut [i64], a: &[i64]) {
    unsafe {
        use core::arch::x86_64::*;

        let (n_in, n_out) = (a.len(), res.len());

        #[cfg(debug_assertions)]
        {
            assert!(n_in.is_power_of_two());
            assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
        }

        if n_in == n_out {
            use poulpy_hal::reference::znx::znx_copy_ref;

            znx_copy_ref(res, a);
            return;
        }

        if n_in > n_out {
            // Downsample: res[k] = a[k * gap_in], contiguous stores
            let gap_in: usize = n_in / n_out;

            // Index offsets for one gather: [0, gap_in, 2*gap_in, 3*gap_in]
            let step: __m256i = _mm256_setr_epi64x(0, gap_in as i64, 2 * gap_in as i64, 3 * gap_in as i64);

            let span: usize = n_out >> 2;
            let bump: __m256i = _mm256_set1_epi64x(4 * gap_in as i64);

            let mut res_4xi64: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
            let a_ptr: *const i64 = a.as_ptr();

            let mut base: __m256i = _mm256_setzero_si256(); // starts at 0*gap_in

            for _ in 0..span {
                // idx = base + step
                let idx: __m256i = _mm256_add_epi64(base, step);

                // gather 4 spaced i64 (scale = 8 bytes)
                let v: __m256i = _mm256_i64gather_epi64(a_ptr, idx, 8);

                // store contiguously
                _mm256_storeu_si256(res_4xi64, v);

                base = _mm256_add_epi64(base, bump);
                res_4xi64 = res_4xi64.add(1);
            }
        } else {
            // Upsample: res[k * gap_out] = a[k], i.e. res has holes between samples.
            use poulpy_hal::reference::znx::znx_zero_ref;
            let gap_out = n_out / n_in;

            // zero then scatter with scalar stores
            znx_zero_ref(res);

            let mut a_4xi64: *const __m256i = a.as_ptr() as *const __m256i;

            for i in (0..n_in).step_by(4) {
                // Load 4 contiguous inputs
                let v = _mm256_loadu_si256(a_4xi64);

                // extract 4 lanes (pextrq). This is still the best we can do on AVX2,
                // which has no scatter instruction.
                let x0: i64 = _mm256_extract_epi64(v, 0);
                let x1: i64 = _mm256_extract_epi64(v, 1);
                let x2: i64 = _mm256_extract_epi64(v, 2);
                let x3: i64 = _mm256_extract_epi64(v, 3);

                // starting output pointer for this group
                let mut p: *mut i64 = res.as_mut_ptr().add(i * gap_out);

                // four strided stores with pointer bumps (avoids a multiply each time)
                *p = x0;
                p = p.add(gap_out);
                *p = x1;
                p = p.add(gap_out);
                *p = x2;
                p = p.add(gap_out);
                *p = x3;

                a_4xi64 = a_4xi64.add(1)
            }
        }
    }
}
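A possible sanity check for the kernel above, sketched here for illustration only (not part of the commit): downsampling by a factor of two and then upsampling back should reproduce the kept coefficients and leave zeros in the holes.

    #[cfg(all(test, target_arch = "x86_64"))]
    mod switch_ring_sketch {
        use super::*;

        #[test]
        fn down_then_up() {
            if !std::is_x86_feature_detected!("avx2") {
                return;
            }
            let a: Vec<i64> = (1..=16).collect();
            let mut small: Vec<i64> = vec![0i64; 8];
            let mut back: Vec<i64> = vec![0i64; 16];
            unsafe {
                znx_switch_ring_avx(&mut small, &a); // small[k] = a[2 * k]
                znx_switch_ring_avx(&mut back, &small); // back[2 * k] = small[k]
            }
            for k in 0..8 {
                assert_eq!(small[k], a[2 * k]);
                assert_eq!(back[2 * k], small[k]);
                assert_eq!(back[2 * k + 1], 0);
            }
        }
    }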
12
poulpy-backend/src/cpu_fft64_ref/mod.rs
Normal file
@@ -0,0 +1,12 @@
mod module;
mod reim;
mod scratch;
mod svp;
mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp;
mod zn;
mod znx;

pub struct FFT64Ref {}
62
poulpy-backend/src/cpu_fft64_ref/module.rs
Normal file
@@ -0,0 +1,62 @@
use std::ptr::NonNull;

use poulpy_hal::{
    layouts::{Backend, Module},
    oep::ModuleNewImpl,
    reference::fft64::reim::{ReimFFTTable, ReimIFFTTable},
};

use crate::cpu_fft64_ref::FFT64Ref;

#[repr(C)]
pub struct FFT64RefHandle {
    table_fft: ReimFFTTable<f64>,
    table_ifft: ReimIFFTTable<f64>,
}

impl Backend for FFT64Ref {
    type ScalarPrep = f64;
    type ScalarBig = i64;
    type Handle = FFT64RefHandle;

    unsafe fn destroy(handle: NonNull<Self::Handle>) {
        unsafe {
            drop(Box::from_raw(handle.as_ptr()));
        }
    }

    fn layout_big_word_count() -> usize {
        1
    }

    fn layout_prep_word_count() -> usize {
        1
    }
}

unsafe impl ModuleNewImpl<Self> for FFT64Ref {
    fn new_impl(n: u64) -> Module<Self> {
        let handle: FFT64RefHandle = FFT64RefHandle {
            table_fft: ReimFFTTable::new(n as usize >> 1),
            table_ifft: ReimIFFTTable::new(n as usize >> 1),
        };
        // Leak the Box to get a stable NonNull pointer; it is reclaimed in `Backend::destroy`.
        let ptr: NonNull<FFT64RefHandle> = NonNull::from(Box::leak(Box::new(handle)));
        unsafe { Module::from_nonnull(ptr, n) }
    }
}

pub trait FFT64ModuleHandle {
    fn get_fft_table(&self) -> &ReimFFTTable<f64>;
    fn get_ifft_table(&self) -> &ReimIFFTTable<f64>;
}

impl FFT64ModuleHandle for Module<FFT64Ref> {
    fn get_fft_table(&self) -> &ReimFFTTable<f64> {
        let h: &FFT64RefHandle = unsafe { &*self.ptr() };
        &h.table_fft
    }

    fn get_ifft_table(&self) -> &ReimIFFTTable<f64> {
        let h: &FFT64RefHandle = unsafe { &*self.ptr() };
        &h.table_ifft
    }
}
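A short usage sketch of the handle accessors above, for illustration only (it assumes `ModuleNewImpl` is in scope so `new_impl` can be called directly): the tables cached in the handle are built for m = n/2, and the boxed handle leaked in `new_impl` is the one `Backend::destroy` reclaims.

    // Hypothetical snippet, not part of this commit.
    fn module_handle_sketch() {
        let module: Module<FFT64Ref> = FFT64Ref::new_impl(1024);
        let fft: &ReimFFTTable<f64> = module.get_fft_table(); // table built for m = 512
        let ifft: &ReimIFFTTable<f64> = module.get_ifft_table();
        let _ = (fft, ifft);
    }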
175
poulpy-backend/src/cpu_fft64_ref/reim.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use poulpy_hal::reference::fft64::{
|
||||
reim::{
|
||||
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
|
||||
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace,
|
||||
ReimZero, fft_ref, ifft_ref, reim_add_inplace_ref, reim_add_ref, reim_addmul_ref, reim_copy_ref, reim_from_znx_i64_ref,
|
||||
reim_mul_inplace_ref, reim_mul_ref, reim_negate_inplace_ref, reim_negate_ref, reim_sub_ab_inplace_ref,
|
||||
reim_sub_ba_inplace_ref, reim_sub_ref, reim_to_znx_i64_inplace_ref, reim_to_znx_i64_ref, reim_zero_ref,
|
||||
},
|
||||
reim4::{
|
||||
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
|
||||
reim4_extract_1blk_from_reim_ref, reim4_save_1blk_to_reim_ref, reim4_save_2blk_to_reim_ref,
|
||||
reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, reim4_vec_mat2cols_product_ref,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::FFT64Ref;
|
||||
|
||||
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for FFT64Ref {
|
||||
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
|
||||
fft_ref(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for FFT64Ref {
|
||||
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
|
||||
ifft_ref(table.m(), table.omg(), data);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimFromZnx for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_from_znx(res: &mut [f64], a: &[i64]) {
|
||||
reim_from_znx_i64_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimToZnx for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]) {
|
||||
reim_to_znx_i64_ref(res, divisor, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimToZnxInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64) {
|
||||
reim_to_znx_i64_inplace_ref(res, divisor);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAdd for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
reim_add_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAddInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_add_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_add_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSub for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
reim_sub_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubABInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_ab_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimSubBAInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_sub_ba_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimNegate for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_negate(res: &mut [f64], a: &[f64]) {
|
||||
reim_negate_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimNegateInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_negate_inplace(res: &mut [f64]) {
|
||||
reim_negate_inplace_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimMul for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
reim_mul_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimMulInplace for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_mul_inplace(res: &mut [f64], a: &[f64]) {
|
||||
reim_mul_inplace_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimAddMul for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]) {
|
||||
reim_addmul_ref(res, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimCopy for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_copy(res: &mut [f64], a: &[f64]) {
|
||||
reim_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimZero for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim_zero(res: &mut [f64]) {
|
||||
reim_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Extract1Blk for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
reim4_extract_1blk_from_reim_ref(m, rows, blk, dst, src);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Save1Blk for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
reim4_save_1blk_to_reim_ref::<OVERWRITE>(m, blk, dst, src);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Save2Blks for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
|
||||
reim4_save_2blk_to_reim_ref::<OVERWRITE>(m, blk, dst, src);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat1ColProd for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
reim4_vec_mat1col_product_ref(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat2ColsProd for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
reim4_vec_mat2cols_product_ref(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
|
||||
impl Reim4Mat2Cols2ndColProd for FFT64Ref {
|
||||
#[inline(always)]
|
||||
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
|
||||
reim4_vec_mat2cols_2ndcol_product_ref(nrows, dst, u, v);
|
||||
}
|
||||
}
|
||||
261
poulpy-backend/src/cpu_fft64_ref/scratch.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use poulpy_hal::{
|
||||
DEFAULTALIGN, alloc_aligned,
|
||||
api::ScratchFromBytes,
|
||||
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
|
||||
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
|
||||
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_ref::FFT64Ref;
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64Ref {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
data,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn scratch_owned_borrow_impl(scratch: &mut ScratchOwned<B>) -> &mut Scratch<B> {
|
||||
Scratch::from_bytes(&mut scratch.data)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64Ref {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64Ref {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_slice_impl<T>(scratch: &mut Scratch<B>, len: usize) -> (&mut [T], &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64Ref
|
||||
where
|
||||
B: SvpPPolAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_svp_ppol_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::svp_ppol_alloc_bytes_impl(n, cols));
|
||||
(
|
||||
SvpPPol::from_data(take_slice, n, cols),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64Ref
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_big_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64Ref
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size),
|
||||
);
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, n, cols, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64Ref
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B> + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_dft_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_dft_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B> + TakeVecZnxImpl<B>,
|
||||
{
|
||||
fn take_vec_znx_slice_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
len: usize,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (Vec<VecZnx<&mut [u8]>>, &mut Scratch<B>) {
|
||||
let mut scratch: &mut Scratch<B> = scratch;
|
||||
let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let (znx, new_scratch) = B::take_vec_znx_impl(scratch, n, cols, size);
|
||||
scratch = new_scratch;
|
||||
slice.push(znx);
|
||||
}
|
||||
(slice, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64Ref
|
||||
where
|
||||
B: VmpPMatAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_vmp_pmat_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (VmpPMat<&mut [u8], B>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
VmpPMat::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64Ref
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
fn take_mat_znx_impl(
|
||||
scratch: &mut Scratch<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
|
||||
let (take_slice, rem_slice) = take_slice_aligned(
|
||||
&mut scratch.data,
|
||||
MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
|
||||
);
|
||||
(
|
||||
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),
|
||||
Scratch::from_bytes(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
(take_slice, rem_slice)
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len, aligned_len,
|
||||
);
|
||||
}
|
||||
}
|
||||
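The `take_slice_aligned` helper above carves a scratch buffer in two: the taken slice starts at the first DEFAULTALIGN-aligned byte and the remainder follows it, panicking if not enough aligned bytes are left. A small sketch of its behavior, for illustration only (it assumes the buffer is large enough):

    // Hypothetical snippet, not part of this commit.
    fn scratch_split_sketch(data: &mut [u8]) {
        let (head, rest) = take_slice_aligned(data, 64);
        // `head` is 64 bytes starting at an aligned address; `rest` is whatever remains.
        assert_eq!(head.len(), 64);
        assert_eq!(head.as_ptr() as usize % poulpy_hal::DEFAULTALIGN, 0);
        let _ = rest;
    }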
66
poulpy-backend/src/cpu_fft64_ref/svp.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDftToMut, VecZnxDftToRef},
|
||||
oep::{
|
||||
SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl,
|
||||
SvpPrepareImpl,
|
||||
},
|
||||
reference::fft64::svp::{svp_apply_dft_to_dft, svp_apply_dft_to_dft_inplace, svp_prepare},
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle};
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64Ref {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64Ref {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64Ref {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
Self::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64Ref {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
svp_prepare(module.get_fft_table(), res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyDftToDftImpl<Self> for FFT64Ref {
|
||||
fn svp_apply_dft_to_dft_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
{
|
||||
svp_apply_dft_to_dft(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Ref {
|
||||
fn svp_apply_dft_to_dft_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
{
|
||||
svp_apply_dft_to_dft_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
538
poulpy-backend/src/cpu_fft64_ref/vec_znx.rs
Normal file
@@ -0,0 +1,538 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, VecZnxAutomorphismInplaceTmpBytes, VecZnxMergeRingsTmpBytes, VecZnxMulXpMinusOneInplaceTmpBytes,
|
||||
VecZnxNormalizeTmpBytes, VecZnxRotateInplaceTmpBytes, VecZnxSplitRingTmpBytes,
|
||||
},
|
||||
layouts::{Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
|
||||
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
reference::vec_znx::{
|
||||
vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace,
|
||||
vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy,
|
||||
vec_znx_fill_normal_ref, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes,
|
||||
vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one, vec_znx_mul_xp_minus_one_inplace,
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_negate, vec_znx_negate_inplace, vec_znx_normalize,
|
||||
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
|
||||
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
|
||||
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace, vec_znx_sub_scalar,
|
||||
vec_znx_sub_scalar_inplace, vec_znx_switch_ring,
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::cpu_fft64_ref::FFT64Ref;
|
||||
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_normalize_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64Ref
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize::<R, A, Self>(basek, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64Ref
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_normalize_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_normalize_inplace::<R, Self>(basek, res, res_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
vec_znx_add_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_scalar_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_add_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_impl<R, A, B>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub::<R, A, B, Self>(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ab_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_ba_inplace::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_scalar_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_sub_scalar::<R, A, B, Self>(res, res_col, a, a_col, b, b_col, b_limb);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
res_limb: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
vec_znx_sub_scalar_inplace::<R, A, Self>(res, res_col, res_limb, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_negate::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_negate_inplace::<R, Self>(res, res_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_lsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, Self>(basek, k, res, res_col, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, Self>(basek, k, a, a_col, carry);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_rotate::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceTmpBytesImpl<Self> for FFT64Ref
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rotate_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64Ref
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxRotateInplaceTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_rotate_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_rotate_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rotate_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_automorphism::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_automorphism_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Ref
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxAutomorphismInplaceTmpBytesImpl<Self>,
|
||||
{
|
||||
fn vec_znx_automorphism_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_automorphism_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_automorphism_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_mul_xp_minus_one_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_mul_xp_minus_one::<R, A, Self>(p, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl<Self> for FFT64Ref
|
||||
where
|
||||
Scratch<Self>: TakeSlice,
|
||||
Self: VecZnxMulXpMinusOneImpl<Self>,
|
||||
{
|
||||
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
|
||||
module: &Module<Self>,
|
||||
p: i64,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_mul_xp_minus_one_inplace::<R, Self>(p, res, res_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitRingTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_split_ring_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSplitRingImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxSplitRingTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_split_ring_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut [R],
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_split_ring::<R, A, Self>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMergeRingsTmpBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_merge_rings_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxMergeRingsImpl<Self> for FFT64Ref
|
||||
where
|
||||
Module<Self>: VecZnxMergeRingsTmpBytes,
|
||||
{
|
||||
fn vec_znx_merge_rings_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &[A],
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_merge_rings::<R, A, Self>(res, res_col, a, a_col, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSwitchRingImpl<Self> for FFT64Ref
|
||||
where
|
||||
Self: VecZnxCopyImpl<Self>,
|
||||
{
|
||||
fn vec_znx_switch_ring_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_switch_ring::<R, A, Self>(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxCopyImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
vec_znx_copy::<R, A, Self>(res, res_col, a, a_col)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_uniform_ref(basek, res, res_col, source)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
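The implementations above repeatedly convert a `*_tmp_bytes` budget (in bytes) into an element count before calling `take_slice`; the `/ size_of::<i64>()` division keeps the two units consistent. A trivial sketch of that conversion in isolation, for illustration only:

    // Hypothetical helper, not part of this commit.
    fn carry_len_from_tmp_bytes(tmp_bytes: usize) -> usize {
        debug_assert!(tmp_bytes % core::mem::size_of::<i64>() == 0);
        tmp_bytes / core::mem::size_of::<i64>()
    }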
332
poulpy-backend/src/cpu_fft64_ref/vec_znx_big.rs
Normal file
@@ -0,0 +1,332 @@
|
||||
use crate::cpu_fft64_ref::FFT64Ref;
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxBigAutomorphismInplaceTmpBytes, VecZnxBigNormalizeTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
|
||||
ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
reference::{
|
||||
fft64::vec_znx_big::{
|
||||
vec_znx_big_add, vec_znx_big_add_inplace, vec_znx_big_add_normal_ref, vec_znx_big_add_small,
|
||||
vec_znx_big_add_small_inplace, vec_znx_big_automorphism, vec_znx_big_automorphism_inplace,
|
||||
vec_znx_big_automorphism_inplace_tmp_bytes, vec_znx_big_negate, vec_znx_big_negate_inplace, vec_znx_big_normalize,
|
||||
vec_znx_big_normalize_tmp_bytes, vec_znx_big_sub, vec_znx_big_sub_ab_inplace, vec_znx_big_sub_ba_inplace,
|
||||
vec_znx_big_sub_small_a, vec_znx_big_sub_small_a_inplace, vec_znx_big_sub_small_b, vec_znx_big_sub_small_b_inplace,
|
||||
},
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Ref {
|
||||
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64Ref> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let min_size: usize = res_size.min(a_size);
|
||||
|
||||
for j in 0..min_size {
|
||||
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Ref {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
vec_znx_big_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64Ref {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_add(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Ref {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
vec_znx_big_add_inplace(res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Ref {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
_module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
vec_znx_big_add_small(res, res_col, a, a_col, b, b_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Ref {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        vec_znx_big_add_small_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigSubImpl<Self> for FFT64Ref {
    /// Subtracts `b` from `a` and stores the result in `res`.
    fn vec_znx_big_sub_impl<R, A, B>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
        B: VecZnxBigToRef<Self>,
    {
        vec_znx_big_sub(res, res_col, a, a_col, b, b_col);
    }
}

unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Ref {
    /// Subtracts `a` from `res` and stores the result in `res`.
    fn vec_znx_big_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
    {
        vec_znx_big_sub_ab_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Ref {
    /// Subtracts `res` from `a` and stores the result in `res`.
    fn vec_znx_big_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
    {
        vec_znx_big_sub_ba_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Ref {
    /// Subtracts `b` from the small vector `a` and stores the result in `res`.
    fn vec_znx_big_sub_small_a_impl<R, A, B>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
        B: VecZnxBigToRef<Self>,
    {
        vec_znx_big_sub_small_a(res, res_col, a, a_col, b, b_col);
    }
}

unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Ref {
    /// Subtracts `a` from `res` and stores the result in `res`.
    fn vec_znx_big_sub_small_a_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        vec_znx_big_sub_small_a_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Ref {
    /// Subtracts the small vector `b` from `a` and stores the result in `res`.
    fn vec_znx_big_sub_small_b_impl<R, A, B>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
        B: VecZnxToRef,
    {
        vec_znx_big_sub_small_b(res, res_col, a, a_col, b, b_col);
    }
}

unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Ref {
    /// Subtracts `res` from `a` and stores the result in `res`.
    fn vec_znx_big_sub_small_b_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        vec_znx_big_sub_small_b_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigNegateImpl<Self> for FFT64Ref {
    fn vec_znx_big_negate_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
    {
        vec_znx_big_negate(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Ref {
    fn vec_znx_big_negate_inplace_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
    where
        R: VecZnxBigToMut<Self>,
    {
        vec_znx_big_negate_inplace(res, res_col);
    }
}

unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Ref {
    fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
        vec_znx_big_normalize_tmp_bytes(module.n())
    }
}

unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Ref
where
    Self: TakeSliceImpl<Self>,
{
    fn vec_znx_big_normalize_impl<R, A>(
        module: &Module<Self>,
        basek: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxToMut,
        A: VecZnxBigToRef<Self>,
    {
        let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
        vec_znx_big_normalize(basek, res, res_col, a, a_col, carry);
    }
}

unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Ref {
    /// Applies the automorphism X^i -> X^(i*p) on `a` and stores the result in `res`.
    fn vec_znx_big_automorphism_impl<R, A>(_module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxBigToRef<Self>,
    {
        vec_znx_big_automorphism(p, res, res_col, a, a_col);
    }
}

unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Ref {
    fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
        vec_znx_big_automorphism_inplace_tmp_bytes(module.n())
    }
}

unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Ref
where
    Module<Self>: VecZnxBigAutomorphismInplaceTmpBytes,
{
    /// Applies the automorphism X^i -> X^(i*p) on `res` and stores the result in `res`.
    fn vec_znx_big_automorphism_inplace_impl<R>(
        module: &Module<Self>,
        p: i64,
        res: &mut R,
        res_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
    {
        let (tmp, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
        vec_znx_big_automorphism_inplace(p, res, res_col, tmp);
    }
}
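For readers unfamiliar with the normalization step used in the impls above: normalizing a base-2^basek decomposition means propagating carries so that every limb ends up in the centered range [-2^(basek-1), 2^(basek-1)). The following is a minimal, self-contained sketch of that arithmetic on a single little-endian limb vector; it is illustrative only and is not the crate's `vec_znx_big_normalize`, which works limb-wise over polynomial columns with a scratch-provided carry buffer.

// Illustrative sketch (not part of the crate): center each limb of a
// little-endian base-2^basek decomposition and push the excess into a carry.
fn normalize_limbs(basek: usize, limbs: &mut [i64]) {
    let base: i64 = 1 << basek;
    let half: i64 = base >> 1;
    let mut carry: i64 = 0;
    for limb in limbs.iter_mut() {
        let v = *limb + carry;
        // Reduce v into [-base/2, base/2); whatever remains becomes the carry.
        let mut r = v % base;
        if r >= half {
            r -= base;
        } else if r < -half {
            r += base;
        }
        carry = (v - r) >> basek; // exact: v - r is a multiple of base
        *limb = r;
    }
    // A leftover carry would overflow the most significant limb and is dropped,
    // mirroring a fixed-precision representation.
    let _ = carry;
}

For example, with basek = 4 the vector [17, 3] normalizes to [1, 4], since 17 + 3*16 = 1 + 4*16 = 65.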
186
poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs
Normal file
@@ -0,0 +1,186 @@
use poulpy_hal::{
    layouts::{
        Backend, Data, Module, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef,
        VecZnxToRef,
    },
    oep::{
        VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
        VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
        VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
    },
    reference::fft64::vec_znx_dft::{
        vec_znx_dft_add, vec_znx_dft_add_inplace, vec_znx_dft_apply, vec_znx_dft_copy, vec_znx_dft_sub,
        vec_znx_dft_sub_ab_inplace, vec_znx_dft_sub_ba_inplace, vec_znx_dft_zero, vec_znx_idft_apply, vec_znx_idft_apply_consume,
        vec_znx_idft_apply_tmpa,
    },
};

use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle};

unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64Ref {
    fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
        VecZnxDft::<Vec<u8>, Self>::from_bytes(n, cols, size, bytes)
    }
}

unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64Ref {
    fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
        Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Ref as Backend>::ScalarPrep>()
    }
}

unsafe impl VecZnxDftAllocImpl<Self> for FFT64Ref {
    fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
        VecZnxDftOwned::alloc(n, cols, size)
    }
}

unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for FFT64Ref {
    fn vec_znx_idft_apply_tmp_bytes_impl(_module: &Module<Self>) -> usize {
        0
    }
}

unsafe impl VecZnxIdftApplyImpl<Self> for FFT64Ref {
    fn vec_znx_idft_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        _scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        vec_znx_idft_apply(module.get_ifft_table(), res, res_col, a, a_col);
    }
}

unsafe impl VecZnxIdftApplyTmpAImpl<Self> for FFT64Ref {
    fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
    where
        R: VecZnxBigToMut<Self>,
        A: VecZnxDftToMut<Self>,
    {
        vec_znx_idft_apply_tmpa(module.get_ifft_table(), res, res_col, a, a_col);
    }
}

unsafe impl VecZnxIdftApplyConsumeImpl<Self> for FFT64Ref {
    fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<Self>, res: VecZnxDft<D, FFT64Ref>) -> VecZnxBig<D, FFT64Ref>
    where
        VecZnxDft<D, FFT64Ref>: VecZnxDftToMut<Self>,
    {
        vec_znx_idft_apply_consume(module.get_ifft_table(), res)
    }
}

unsafe impl VecZnxDftApplyImpl<Self> for FFT64Ref {
    fn vec_znx_dft_apply_impl<R, A>(
        module: &Module<Self>,
        step: usize,
        offset: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxToRef,
    {
        vec_znx_dft_apply(module.get_fft_table(), step, offset, res, res_col, a, a_col);
    }
}

unsafe impl VecZnxDftAddImpl<Self> for FFT64Ref {
    fn vec_znx_dft_add_impl<R, A, B>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        B: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_add(res, res_col, a, a_col, b, b_col);
    }
}

unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Ref {
    fn vec_znx_dft_add_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_add_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxDftSubImpl<Self> for FFT64Ref {
    fn vec_znx_dft_sub_impl<R, A, B>(
        _module: &Module<Self>,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        B: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_sub(res, res_col, a, a_col, b, b_col);
    }
}

unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Ref {
    fn vec_znx_dft_sub_ab_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_sub_ab_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Ref {
    fn vec_znx_dft_sub_ba_inplace_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_sub_ba_inplace(res, res_col, a, a_col);
    }
}

unsafe impl VecZnxDftCopyImpl<Self> for FFT64Ref {
    fn vec_znx_dft_copy_impl<R, A>(
        _module: &Module<Self>,
        step: usize,
        offset: usize,
        res: &mut R,
        res_col: usize,
        a: &A,
        a_col: usize,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
    {
        vec_znx_dft_copy(step, offset, res, res_col, a, a_col);
    }
}

unsafe impl VecZnxDftZeroImpl<Self> for FFT64Ref {
    fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
    where
        R: VecZnxDftToMut<Self>,
    {
        vec_znx_dft_zero(res);
    }
}
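The DFT/IDFT pair above exists to accelerate multiplication in Z[X]/(X^N + 1): apply the forward transform, multiply coefficient-wise in the transformed domain, then apply the inverse transform. As a point of reference, here is the schoolbook negacyclic product those transforms replace; it is illustrative only and is not part of the crate.

// Schoolbook negacyclic convolution in Z[X]/(X^n + 1): the product the
// DFT -> pointwise multiply -> IDFT pipeline computes, without any FFT.
fn negacyclic_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    assert_eq!(n, b.len());
    let mut res = vec![0i64; n];
    for i in 0..n {
        for j in 0..n {
            let k = i + j;
            if k < n {
                res[k] += a[i] * b[j];
            } else {
                res[k - n] -= a[i] * b[j]; // wrap-around picks up a sign: X^n = -1
            }
        }
    }
    res
}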
143
poulpy-backend/src/cpu_fft64_ref/vmp.rs
Normal file
@@ -0,0 +1,143 @@
use poulpy_hal::{
    api::{TakeSlice, VmpPrepareTmpBytes},
    layouts::{
        Backend, MatZnx, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatOwned,
        VmpPMatToMut, VmpPMatToRef, ZnxInfos,
    },
    oep::{
        VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
        VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
    },
    reference::fft64::vmp::{
        vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
    },
};

use crate::cpu_fft64_ref::{FFT64Ref, module::FFT64ModuleHandle};

unsafe impl VmpPMatAllocBytesImpl<Self> for FFT64Ref {
    fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
    }
}

unsafe impl VmpPMatAllocImpl<Self> for FFT64Ref {
    fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<Self> {
        VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
    }
}

unsafe impl VmpApplyDftToDftImpl<Self> for FFT64Ref
where
    Scratch<Self>: TakeSlice,
    FFT64Ref: VmpApplyDftToDftTmpBytesImpl<Self>,
{
    fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<Self>, res: &mut R, a: &A, pmat: &C, scratch: &mut Scratch<Self>)
    where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        C: VmpPMatToRef<Self>,
    {
        let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let a: VecZnxDft<&[u8], Self> = a.to_ref();
        let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();

        let (tmp, _) = scratch.take_slice(
            Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
                module,
                res.size(),
                a.size(),
                pmat.rows(),
                pmat.cols_in(),
                pmat.cols_out(),
                pmat.size(),
            ) / size_of::<f64>(),
        );
        vmp_apply_dft_to_dft(&mut res, &a, &pmat, tmp);
    }
}

unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Ref
where
    Scratch<Self>: TakeSlice,
    FFT64Ref: VmpApplyDftToDftTmpBytesImpl<Self>,
{
    fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
        module: &Module<Self>,
        res: &mut R,
        a: &A,
        pmat: &C,
        limb_offset: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: VecZnxDftToRef<Self>,
        C: VmpPMatToRef<Self>,
    {
        let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let a: VecZnxDft<&[u8], Self> = a.to_ref();
        let pmat: VmpPMat<&[u8], Self> = pmat.to_ref();

        let (tmp, _) = scratch.take_slice(
            Self::vmp_apply_dft_to_dft_tmp_bytes_impl(
                module,
                res.size(),
                a.size(),
                pmat.rows(),
                pmat.cols_in(),
                pmat.cols_out(),
                pmat.size(),
            ) / size_of::<f64>(),
        );
        vmp_apply_dft_to_dft_add(&mut res, &a, &pmat, limb_offset * pmat.cols_out(), tmp);
    }
}

unsafe impl VmpPrepareTmpBytesImpl<Self> for FFT64Ref {
    fn vmp_prepare_tmp_bytes_impl(module: &Module<Self>, _rows: usize, _cols_in: usize, _cols_out: usize, _size: usize) -> usize {
        vmp_prepare_tmp_bytes(module.n())
    }
}

unsafe impl VmpPrepareImpl<Self> for FFT64Ref {
    fn vmp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: VmpPMatToMut<Self>,
        A: MatZnxToRef,
    {
        let mut res: VmpPMat<&mut [u8], Self> = res.to_mut();
        let a: MatZnx<&[u8]> = a.to_ref();
        let (tmp, _) =
            scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()) / size_of::<f64>());
        vmp_prepare(module.get_fft_table(), &mut res, &a, tmp);
    }
}

unsafe impl VmpApplyDftToDftTmpBytesImpl<Self> for FFT64Ref {
    fn vmp_apply_dft_to_dft_tmp_bytes_impl(
        _module: &Module<Self>,
        _res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        _b_cols_out: usize,
        _b_size: usize,
    ) -> usize {
        vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
    }
}

unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Ref {
    fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
        _module: &Module<Self>,
        _res_size: usize,
        a_size: usize,
        b_rows: usize,
        b_cols_in: usize,
        _b_cols_out: usize,
        _b_size: usize,
    ) -> usize {
        vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
    }
}
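Conceptually, the vector-matrix product driven above accumulates, for every output limb j, the coefficient-wise products of input limb i with prepared row (i, j). The sketch below shows that accumulation for the simplest shape (cols_in = cols_out = 1, real-valued placeholder coefficients, no blocking and no AVX); it is illustrative only, and the actual backend multiplies complex values stored in the split re/im layout.

// Conceptual sketch (not the crate's implementation):
// res[j][k] = sum_i a[i][k] * pmat[i][j][k], where k indexes DFT coefficients.
fn vmp_apply_naive(res: &mut [Vec<f64>], a: &[Vec<f64>], pmat: &[Vec<Vec<f64>>]) {
    for col in res.iter_mut() {
        for x in col.iter_mut() {
            *x = 0.0; // start from zero; the *_add variant would skip this
        }
    }
    for (i, a_i) in a.iter().enumerate() {
        for (j, p_ij) in pmat[i].iter().enumerate() {
            for ((r, x), p) in res[j].iter_mut().zip(a_i.iter()).zip(p_ij.iter()) {
                *r += x * p;
            }
        }
    }
}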
73
poulpy-backend/src/cpu_fft64_ref/zn.rs
Normal file
@@ -0,0 +1,73 @@
use poulpy_hal::{
    api::TakeSlice,
    layouts::{Scratch, ZnToMut},
    oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
    reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
    source::Source,
};

use crate::cpu_fft64_ref::FFT64Ref;

unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Ref {
    fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
        zn_normalize_tmp_bytes(n)
    }
}

unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
    Self: TakeSliceImpl<Self>,
{
    fn zn_normalize_inplace_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
    where
        R: ZnToMut,
    {
        let (carry, _) = scratch.take_slice(n);
        zn_normalize_inplace::<R, FFT64Ref>(n, basek, res, res_col, carry);
    }
}

unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
    fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
    where
        R: ZnToMut,
    {
        zn_fill_uniform(n, basek, res, res_col, source);
    }
}

unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
    #[allow(clippy::too_many_arguments)]
    fn zn_fill_normal_impl<R>(
        n: usize,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) where
        R: ZnToMut,
    {
        zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
    }
}

unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
    #[allow(clippy::too_many_arguments)]
    fn zn_add_normal_impl<R>(
        n: usize,
        basek: usize,
        res: &mut R,
        res_col: usize,
        k: usize,
        source: &mut Source,
        sigma: f64,
        bound: f64,
    ) where
        R: ZnToMut,
    {
        zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
    }
}
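The noise-sampling entry points above take a standard deviation `sigma` and a cutoff `bound`. As a rough illustration of what bounded Gaussian sampling looks like, here is a minimal sketch assuming rejection beyond `bound`; the exact rounding, scaling by the precision `k`, and limb placement performed by `zn_fill_normal` / `zn_add_normal` are not reproduced here, and this helper is not part of the crate.

use rand::Rng;
use rand_distr::{Distribution, Normal};

// Illustrative bounded Gaussian sampler (assumption: rejection sampling).
fn sample_bounded_gaussian<R: Rng>(rng: &mut R, sigma: f64, bound: f64) -> i64 {
    let dist = Normal::new(0.0, sigma).expect("sigma must be finite and non-negative");
    loop {
        let x = dist.sample(rng);
        if x.abs() <= bound {
            return x.round() as i64;
        }
    }
}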
152
poulpy-backend/src/cpu_fft64_ref/znx.rs
Normal file
@@ -0,0 +1,152 @@
use poulpy_hal::reference::znx::{
    ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
    ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
    ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxRotate, ZnxSub, ZnxSubABInplace,
    ZnxSubBAInplace, ZnxSwitchRing, ZnxZero, znx_add_inplace_ref, znx_add_ref, znx_automorphism_ref, znx_copy_ref,
    znx_negate_inplace_ref, znx_negate_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
    znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
    znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref, znx_rotate,
    znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref, znx_switch_ring_ref, znx_zero_ref,
};

use crate::cpu_fft64_ref::FFT64Ref;

impl ZnxAdd for FFT64Ref {
    #[inline(always)]
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
        znx_add_ref(res, a, b);
    }
}

impl ZnxAddInplace for FFT64Ref {
    #[inline(always)]
    fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
        znx_add_inplace_ref(res, a);
    }
}

impl ZnxSub for FFT64Ref {
    #[inline(always)]
    fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
        znx_sub_ref(res, a, b);
    }
}

impl ZnxSubABInplace for FFT64Ref {
    #[inline(always)]
    fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
        znx_sub_ab_inplace_ref(res, a);
    }
}

impl ZnxSubBAInplace for FFT64Ref {
    #[inline(always)]
    fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
        znx_sub_ba_inplace_ref(res, a);
    }
}

impl ZnxAutomorphism for FFT64Ref {
    #[inline(always)]
    fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
        znx_automorphism_ref(p, res, a);
    }
}

impl ZnxCopy for FFT64Ref {
    #[inline(always)]
    fn znx_copy(res: &mut [i64], a: &[i64]) {
        znx_copy_ref(res, a);
    }
}

impl ZnxNegate for FFT64Ref {
    #[inline(always)]
    fn znx_negate(res: &mut [i64], src: &[i64]) {
        znx_negate_ref(res, src);
    }
}

impl ZnxNegateInplace for FFT64Ref {
    #[inline(always)]
    fn znx_negate_inplace(res: &mut [i64]) {
        znx_negate_inplace_ref(res);
    }
}

impl ZnxRotate for FFT64Ref {
    #[inline(always)]
    fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
        znx_rotate::<Self>(p, res, src);
    }
}

impl ZnxZero for FFT64Ref {
    #[inline(always)]
    fn znx_zero(res: &mut [i64]) {
        znx_zero_ref(res);
    }
}

impl ZnxSwitchRing for FFT64Ref {
    #[inline(always)]
    fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
        znx_switch_ring_ref(res, a);
    }
}

impl ZnxNormalizeFinalStep for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_final_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeFinalStepInplace for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeFirstStep for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_first_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeFirstStepCarryOnly for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
        znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeFirstStepInplace for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeMiddleStep for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeMiddleStepCarryOnly for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
        znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeMiddleStepInplace for FFT64Ref {
    #[inline(always)]
    fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
    }
}
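The First/Middle/Final normalization traits above split carry propagation across the limbs of a column. A hypothetical driver could chain them as sketched below; this is illustrative only, and it assumes (a) limb index 0 is the most significant limb so the walk runs from the last limb upward, and (b) `lsh = 0`, i.e. no extra left shift. The crate's real drivers live in poulpy-hal's reference module.

use poulpy_hal::reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace};

// Hypothetical driver: the first step seeds the carry from the least
// significant limb, middle steps fix a limb and keep propagating, and the
// final step absorbs the carry into the most significant limb.
fn normalize_columns<B>(basek: usize, limbs: &mut [Vec<i64>], carry: &mut [i64])
where
    B: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let size = limbs.len();
    assert!(size >= 2, "need at least two limbs for first + final steps");
    B::znx_normalize_first_step_inplace(basek, 0, &mut limbs[size - 1], carry);
    for j in (1..size - 1).rev() {
        B::znx_normalize_middle_step_inplace(basek, 0, &mut limbs[j], carry);
    }
    B::znx_normalize_final_step_inplace(basek, 0, &mut limbs[0], carry);
}

With the impls above, `B` can be instantiated as `FFT64Ref` (or `FFT64Spqlios` once the analogous impls further down are in scope).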
@@ -1,6 +1,8 @@
#[allow(non_camel_case_types)]
pub mod module;
#[allow(non_camel_case_types)]
pub mod reim;
#[allow(non_camel_case_types)]
pub mod svp;
#[allow(non_camel_case_types)]
pub mod vec_znx;
172
poulpy-backend/src/cpu_spqlios/ffi/reim.rs
Normal file
@@ -0,0 +1,172 @@
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_fft_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFT_PRECOMP = reim_fft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_ifft_precomp {
    _unused: [u8; 0],
}
pub type REIM_IFFT_PRECOMP = reim_ifft_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_mul_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFTVEC_MUL_PRECOMP = reim_mul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_addmul_precomp {
    _unused: [u8; 0],
}
pub type REIM_FFTVEC_ADDMUL_PRECOMP = reim_addmul_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_ZNX32_PRECOMP = reim_from_znx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_znx64_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_ZNX64_PRECOMP = reim_from_znx64_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_from_tnx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_FROM_TNX32_PRECOMP = reim_from_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx32_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_TNX32_PRECOMP = reim_to_tnx32_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_tnx_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_TNX_PRECOMP = reim_to_tnx_precomp;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct reim_to_znx64_precomp {
    _unused: [u8; 0],
}
pub type REIM_TO_ZNX64_PRECOMP = reim_to_znx64_precomp;
unsafe extern "C" {
    pub unsafe fn new_reim_fft_precomp(m: u32, num_buffers: u32) -> *mut REIM_FFT_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fft_precomp_get_buffer(tables: *const REIM_FFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn new_reim_fft_buffer(m: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn delete_reim_fft_buffer(buffer: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn reim_fft(tables: *const REIM_FFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_ifft_precomp(m: u32, num_buffers: u32) -> *mut REIM_IFFT_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_ifft_precomp_get_buffer(tables: *const REIM_IFFT_PRECOMP, buffer_index: u32) -> *mut f64;
}
unsafe extern "C" {
    pub unsafe fn reim_ifft(tables: *const REIM_IFFT_PRECOMP, data: *mut f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_fftvec_mul_precomp(m: u32) -> *mut REIM_FFTVEC_MUL_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_mul(tables: *const REIM_FFTVEC_MUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_fftvec_addmul_precomp(m: u32) -> *mut REIM_FFTVEC_ADDMUL_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_addmul(tables: *const REIM_FFTVEC_ADDMUL_PRECOMP, r: *mut f64, a: *const f64, b: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_znx32_precomp(m: u32, log2bound: u32) -> *mut REIM_FROM_ZNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx32(tables: *const REIM_FROM_ZNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx64(tables: *const REIM_FROM_ZNX64_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_znx64_precomp(m: u32, maxbnd: u32) -> *mut REIM_FROM_ZNX64_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx64_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, a: *const i64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_from_tnx32_precomp(m: u32) -> *mut REIM_FROM_TNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_from_tnx32(tables: *const REIM_FROM_TNX32_PRECOMP, r: *mut ::std::os::raw::c_void, a: *const i32);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_tnx32_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX32_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx32(tables: *const REIM_TO_TNX32_PRECOMP, r: *mut i32, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_tnx_precomp(m: u32, divisor: f64, log2overhead: u32) -> *mut REIM_TO_TNX_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx(tables: *const REIM_TO_TNX_PRECOMP, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut f64, a: *const f64);
}
unsafe extern "C" {
    pub unsafe fn new_reim_to_znx64_precomp(m: u32, divisor: f64, log2bound: u32) -> *mut REIM_TO_ZNX64_PRECOMP;
}
unsafe extern "C" {
    pub unsafe fn reim_to_znx64(precomp: *const REIM_TO_ZNX64_PRECOMP, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_to_znx64_simple(m: u32, divisor: f64, log2bound: u32, r: *mut i64, a: *const ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_fft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_ifft_simple(m: u32, data: *mut ::std::os::raw::c_void);
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_mul_simple(
        m: u32,
        r: *mut ::std::os::raw::c_void,
        a: *const ::std::os::raw::c_void,
        b: *const ::std::os::raw::c_void,
    );
}
unsafe extern "C" {
    pub unsafe fn reim_fftvec_addmul_simple(
        m: u32,
        r: *mut ::std::os::raw::c_void,
        a: *const ::std::os::raw::c_void,
        b: *const ::std::os::raw::c_void,
    );
}
unsafe extern "C" {
    pub unsafe fn reim_from_znx32_simple(m: u32, log2bound: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_from_tnx32_simple(m: u32, r: *mut ::std::os::raw::c_void, x: *const i32);
}
unsafe extern "C" {
    pub unsafe fn reim_to_tnx32_simple(m: u32, divisor: f64, log2overhead: u32, r: *mut i32, x: *const ::std::os::raw::c_void);
}
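The `_simple` entry points above manage their precomputation tables internally, which makes them convenient for quick experiments. A hedged sketch of a safe wrapper follows; it assumes the `ffi::reim` module is reachable from the call site and that, as in the reim split layout, the buffer holds the m real parts followed by the m imaginary parts. The crate itself drives `reim_fft` through precomputed tables instead.

use std::os::raw::c_void;

use crate::cpu_spqlios::ffi::reim::reim_fft_simple; // assumed visible from the call site

// Illustrative safe wrapper: check the slice length before crossing the FFI
// boundary, then hand the data to the self-contained FFT entry point.
pub fn reim_fft_in_place(m: u32, data: &mut [f64]) {
    assert_eq!(data.len(), 2 * m as usize, "expected m real parts followed by m imaginary parts");
    unsafe { reim_fft_simple(m, data.as_mut_ptr() as *mut c_void) };
}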
@@ -7,10 +7,4 @@ mod vec_znx_dft;
mod vmp_pmat;
mod zn;

pub use module::FFT64;

/// For external documentation
pub use vec_znx::{
    vec_znx_copy_ref, vec_znx_lsh_inplace_ref, vec_znx_merge_ref, vec_znx_rsh_inplace_ref, vec_znx_split_ref,
    vec_znx_switch_degree_ref,
};
pub struct FFT64Spqlios;
@@ -3,13 +3,23 @@ use std::ptr::NonNull;
|
||||
use poulpy_hal::{
|
||||
layouts::{Backend, Module},
|
||||
oep::ModuleNewImpl,
|
||||
reference::znx::{
|
||||
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
|
||||
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
|
||||
ZnxRotate, ZnxSwitchRing, ZnxZero, znx_copy_ref, znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref,
|
||||
znx_normalize_first_step_carry_only_ref, znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref,
|
||||
znx_normalize_middle_step_carry_only_ref, znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
|
||||
znx_switch_ring_ref, znx_zero_ref,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::ffi::module::{MODULE, delete_module_info, new_module_info};
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64Spqlios,
|
||||
ffi::module::{MODULE, delete_module_info, new_module_info},
|
||||
znx::znx_rotate_i64,
|
||||
};
|
||||
|
||||
pub struct FFT64;
|
||||
|
||||
impl Backend for FFT64 {
|
||||
impl Backend for FFT64Spqlios {
|
||||
type ScalarPrep = f64;
|
||||
type ScalarBig = i64;
|
||||
type Handle = MODULE;
|
||||
@@ -26,8 +36,90 @@ impl Backend for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64 {
|
||||
unsafe impl ModuleNewImpl<Self> for FFT64Spqlios {
|
||||
fn new_impl(n: u64) -> Module<Self> {
|
||||
unsafe { Module::from_raw_parts(new_module_info(n, 0), n) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxCopy for FFT64Spqlios {
|
||||
fn znx_copy(res: &mut [i64], a: &[i64]) {
|
||||
znx_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxZero for FFT64Spqlios {
|
||||
fn znx_zero(res: &mut [i64]) {
|
||||
znx_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxSwitchRing for FFT64Spqlios {
|
||||
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
|
||||
znx_switch_ring_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxRotate for FFT64Spqlios {
|
||||
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]) {
|
||||
unsafe {
|
||||
znx_rotate_i64(res.len() as u64, p, res.as_mut_ptr(), src.as_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFinalStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepCarryOnly for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeFirstStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStep for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepCarryOnly for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
impl ZnxNormalizeMiddleStepInplace for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
|
||||
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,9 @@ use poulpy_hal::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::FFT64;
|
||||
use crate::cpu_spqlios::FFT64Spqlios;
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64Spqlios {
|
||||
fn scratch_owned_alloc_impl(size: usize) -> ScratchOwned<B> {
|
||||
let data: Vec<u8> = alloc_aligned(size);
|
||||
ScratchOwned {
|
||||
@@ -24,7 +24,7 @@ unsafe impl<B: Backend> ScratchOwnedAllocImpl<B> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> ScratchOwnedBorrowImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -33,13 +33,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchFromBytesImpl<B> for FFT64Spqlios {
|
||||
fn scratch_from_bytes_impl(data: &mut [u8]) -> &mut Scratch<B> {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Scratch<B>) }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64 {
|
||||
unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64Spqlios {
|
||||
fn scratch_available_impl(scratch: &Scratch<B>) -> usize {
|
||||
let ptr: *const u8 = scratch.data.as_ptr();
|
||||
let self_len: usize = scratch.data.len();
|
||||
@@ -48,7 +48,7 @@ unsafe impl<B: Backend> ScratchAvailableImpl<B> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -64,7 +64,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -77,7 +77,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeSvpPPolImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: SvpPPolAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -90,7 +90,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -103,7 +103,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxBigImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxBigAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -124,7 +124,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxDftImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -146,7 +146,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxDftSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VecZnxDftAllocBytesImpl<B> + ScratchFromBytesImpl<B> + TakeVecZnxDftImpl<B>,
|
||||
{
|
||||
@@ -168,7 +168,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVecZnxSliceImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B> + TakeVecZnxImpl<B>,
|
||||
{
|
||||
@@ -190,7 +190,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeVmpPMatImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: VmpPMatAllocBytesImpl<B> + ScratchFromBytesImpl<B>,
|
||||
{
|
||||
@@ -213,7 +213,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64
|
||||
unsafe impl<B: Backend> TakeMatZnxImpl<B> for FFT64Spqlios
|
||||
where
|
||||
B: ScratchFromBytesImpl<B>,
|
||||
{
|
||||
|
||||
@@ -3,33 +3,36 @@ use poulpy_hal::{
|
||||
Backend, Module, ScalarZnxToRef, SvpPPol, SvpPPolOwned, SvpPPolToMut, SvpPPolToRef, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{SvpApplyImpl, SvpApplyInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl, SvpPrepareImpl},
|
||||
oep::{
|
||||
SvpApplyDftToDftImpl, SvpApplyDftToDftInplaceImpl, SvpPPolAllocBytesImpl, SvpPPolAllocImpl, SvpPPolFromBytesImpl,
|
||||
SvpPrepareImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{svp, vec_znx_dft::vec_znx_dft_t},
|
||||
};
|
||||
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::from_bytes(n, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolAllocImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_alloc_impl(n: usize, cols: usize) -> SvpPPolOwned<Self> {
|
||||
SvpPPolOwned::alloc(n, cols)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPPolAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn svp_ppol_alloc_bytes_impl(n: usize, cols: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
FFT64Spqlios::layout_prep_word_count() * n * cols * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
unsafe impl SvpPrepareImpl<Self> for FFT64Spqlios {
|
||||
fn svp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: SvpPPolToMut<Self>,
|
||||
@@ -45,9 +48,16 @@ unsafe impl SvpPrepareImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
fn svp_apply_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
unsafe impl SvpApplyDftToDftImpl<Self> for FFT64Spqlios {
|
||||
fn svp_apply_dft_to_dft_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
B: VecZnxDftToRef<Self>,
|
||||
@@ -70,8 +80,8 @@ unsafe impl SvpApplyImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SvpApplyInplaceImpl for FFT64 {
|
||||
fn svp_apply_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
unsafe impl SvpApplyDftToDftInplaceImpl for FFT64Spqlios {
|
||||
fn svp_apply_dft_to_dft_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: SvpPPolToRef<Self>,
|
||||
|
||||
@@ -1,39 +1,44 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::Normal;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
TakeSlice, TakeVecZnx, VecZnxAddDistF64, VecZnxCopy, VecZnxFillDistF64, VecZnxNormalizeTmpBytes, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree,
|
||||
},
|
||||
api::{TakeSlice, VecZnxMergeRingsTmpBytes, VecZnxNormalizeTmpBytes, VecZnxSplitRingTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, TakeVecZnxImpl, VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl,
|
||||
VecZnxAddScalarInplaceImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxFillDistF64Impl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl,
|
||||
VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
|
||||
VecZnxSplitImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl,
|
||||
VecZnxSwithcDegreeImpl,
|
||||
TakeSliceImpl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarImpl, VecZnxAddScalarInplaceImpl,
|
||||
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxAutomorphismInplaceTmpBytesImpl, VecZnxCopyImpl,
|
||||
VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshImpl, VecZnxLshInplaceImpl, VecZnxLshTmpBytesImpl,
|
||||
VecZnxMergeRingsImpl, VecZnxMergeRingsTmpBytesImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl,
|
||||
VecZnxMulXpMinusOneInplaceTmpBytesImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
|
||||
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
|
||||
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
|
||||
VecZnxSplitRingTmpBytesImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarImpl,
|
||||
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl,
|
||||
},
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_add_normal_ref, vec_znx_automorphism_inplace_tmp_bytes, vec_znx_copy, vec_znx_fill_normal_ref,
|
||||
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
|
||||
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes,
|
||||
vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes,
|
||||
vec_znx_switch_ring,
|
||||
},
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{module::module_info_t, vec_znx, znx},
|
||||
};
|
||||
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxNormalizeImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
@@ -75,7 +80,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64
|
||||
unsafe impl VecZnxNormalizeInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self> + VecZnxNormalizeTmpBytesImpl<Self>,
|
||||
{
|
||||
@@ -108,7 +113,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -142,7 +147,7 @@ unsafe impl VecZnxAddImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -172,7 +177,7 @@ unsafe impl VecZnxAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_scalar_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
@@ -209,7 +214,60 @@ unsafe impl VecZnxAddScalarInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxAddScalarImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_scalar_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, b_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, b_limb),
|
||||
1_u64,
|
||||
b.sl() as u64,
|
||||
);
|
||||
|
||||
for j in 0..min_size {
|
||||
if j != b_limb {
|
||||
znx_copy_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_impl<R, A, C>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &C, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -243,7 +301,7 @@ unsafe impl VecZnxSubImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -272,7 +330,7 @@ unsafe impl VecZnxSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -301,7 +359,60 @@ unsafe impl VecZnxSubBAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxSubScalarImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_scalar_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
b_limb: usize,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), res.n());
|
||||
}
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
module.ptr() as *const module_info_t,
|
||||
res.at_mut_ptr(res_col, b_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
b.at_ptr(b_col, b_limb),
|
||||
1_u64,
|
||||
b.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
);
|
||||
|
||||
for j in 0..min_size {
|
||||
if j != b_limb {
|
||||
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j))
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_sub_scalar_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
@@ -327,18 +438,18 @@ unsafe impl VecZnxSubScalarInplaceImpl<Self> for FFT64 {
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1_u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNegateImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
@@ -364,7 +475,7 @@ unsafe impl VecZnxNegateImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
@@ -384,92 +495,105 @@ unsafe impl VecZnxNegateInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_lsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
unsafe impl VecZnxLshTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_lsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_lsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_rsh_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
vec_znx_rsh_tmp_bytes(module.n())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxLshInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_lsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_lsh_inplace_ref(basek, k, a)
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_lsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_lsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
unsafe impl VecZnxRshImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_left(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(size - steps..size).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.at_mut(i, j).iter_mut().for_each(|xi| {
|
||||
*xi <<= shift;
|
||||
});
|
||||
});
|
||||
});
|
||||
fn vec_znx_rsh_inplace_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh::<_, _, FFT64Spqlios>(basek, k, res, res_col, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64 {
|
||||
fn vec_znx_rsh_inplace_impl<A>(_module: &Module<Self>, basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
unsafe impl VecZnxRshInplaceImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Module<Self>: VecZnxNormalizeTmpBytes,
|
||||
Scratch<Self>: TakeSlice,
|
||||
{
|
||||
fn vec_znx_rsh_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
k: usize,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
vec_znx_rsh_inplace_ref(basek, k, a)
|
||||
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
|
||||
vec_znx_rsh_inplace::<_, FFT64Spqlios>(basek, k, a, a_col, carry)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_rsh_inplace_ref<A>(basek: usize, k: usize, a: &mut A)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
let n: usize = a.n();
|
||||
let cols: usize = a.cols();
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / basek;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % basek;
|
||||
|
||||
if k_rem != 0 {
|
||||
let mut carry: Vec<i64> = vec![0i64; n]; // ALLOC (but small so OK)
|
||||
let shift: usize = i64::BITS as usize - k_rem;
|
||||
(0..cols).for_each(|i| {
|
||||
carry.fill(0);
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << basek;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem;
|
||||
});
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxRotateImpl<Self> for FFT64 {
unsafe impl VecZnxRotateImpl<Self> for FFT64Spqlios {
fn vec_znx_rotate_impl<R, A>(_module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
@@ -482,7 +606,8 @@ unsafe impl VecZnxRotateImpl<Self> for FFT64 {
assert_eq!(res.n(), a.n());
}
unsafe {
(0..a.size()).for_each(|j| {
let min_size = res.size().min(a.size());
(0..min_size).for_each(|j| {
znx::znx_rotate_i64(
a.n() as u64,
k,
@@ -490,12 +615,28 @@ unsafe impl VecZnxRotateImpl<Self> for FFT64 {
a.at_ptr(a_col, j),
);
});

(min_size..res.size()).for_each(|j| {
znx_zero_ref(res.at_mut(res_col, j));
})
}
}
}

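// Illustrative standalone sketch (not part of the diff): the semantics a negacyclic rotation by
// X^k has on a coefficient vector of Z[X]/(X^n + 1), which is what the FFI call znx_rotate_i64
// is expected to provide. The helper name `negacyclic_rotate` is hypothetical and only meant to
// clarify the sign flip at the degree-n wrap-around.
fn negacyclic_rotate(res: &mut [i64], a: &[i64], k: i64) {
    let n = a.len() as i64;
    for (i, &ai) in a.iter().enumerate() {
        // exponent of X^(i+k) reduced modulo X^n + 1 (period 2n with a sign flip after n)
        let mut e = (i as i64 + k) % (2 * n);
        if e < 0 {
            e += 2 * n;
        }
        let (idx, sign) = if e < n { (e as usize, 1) } else { ((e - n) as usize, -1) };
        res[idx] = sign * ai;
    }
}
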
unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
unsafe impl VecZnxRotateInplaceTmpBytesImpl<Self> for FFT64Spqlios
where
Scratch<Self>: TakeSlice,
{
fn vec_znx_rotate_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_rotate_inplace_tmp_bytes(module.n())
}
}

unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64Spqlios
where
Scratch<Self>: TakeSlice,
{
fn vec_znx_rotate_inplace_impl<A>(_module: &Module<Self>, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch<Self>)
where
A: VecZnxToMut,
{
@@ -508,7 +649,7 @@ unsafe impl VecZnxRotateInplaceImpl<Self> for FFT64 {
}
}

unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
unsafe impl VecZnxAutomorphismImpl<Self> for FFT64Spqlios {
fn vec_znx_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
@@ -535,8 +676,14 @@ unsafe impl VecZnxAutomorphismImpl<Self> for FFT64 {
}
}

unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
unsafe impl VecZnxAutomorphismInplaceTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_automorphism_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_automorphism_inplace_tmp_bytes(module.n())
}
}

unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize, _scratch: &mut Scratch<Self>)
where
A: VecZnxToMut,
{
@@ -564,7 +711,7 @@ unsafe impl VecZnxAutomorphismInplaceImpl<Self> for FFT64 {
}
}

unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64Spqlios {
fn vec_znx_mul_xp_minus_one_impl<R, A>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
@@ -592,9 +739,20 @@ unsafe impl VecZnxMulXpMinusOneImpl<Self> for FFT64 {
}
}

unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(module: &Module<Self>, p: i64, res: &mut R, res_col: usize)
where
unsafe impl VecZnxMulXpMinusOneInplaceTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_mul_xp_minus_one_inplace_tmp_bytes(module.n())
}
}

unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64Spqlios {
fn vec_znx_mul_xp_minus_one_inplace_impl<R>(
module: &Module<Self>,
p: i64,
res: &mut R,
res_col: usize,
_scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
@@ -617,15 +775,18 @@ unsafe impl VecZnxMulXpMinusOneInplaceImpl<Self> for FFT64 {
}
}

unsafe impl VecZnxSplitImpl<Self> for FFT64
unsafe impl VecZnxSplitRingTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_split_ring_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_split_ring_tmp_bytes(module.n())
}
}

unsafe impl VecZnxSplitRingImpl<Self> for FFT64Spqlios
where
Self: TakeVecZnxImpl<Self>
+ TakeVecZnxImpl<Self>
+ VecZnxSwithcDegreeImpl<Self>
+ VecZnxRotateImpl<Self>
+ VecZnxRotateInplaceImpl<Self>,
Module<Self>: VecZnxSplitRingTmpBytes,
Scratch<Self>: TakeSlice,
{
fn vec_znx_split_impl<R, A>(
fn vec_znx_split_ring_impl<R, A>(
module: &Module<Self>,
res: &mut [R],
res_col: usize,
@@ -636,287 +797,72 @@ where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_split_ref(module, res, res_col, a, a_col, scratch)
let (tmp, _) = scratch.take_slice(module.vec_znx_split_ring_tmp_bytes() / size_of::<i64>());
vec_znx_split_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp);
}
}

pub fn vec_znx_split_ref<R, A, B>(
module: &Module<B>,
res: &mut [R],
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch<B>,
) where
B: Backend + TakeVecZnxImpl<B> + VecZnxSwithcDegreeImpl<B> + VecZnxRotateImpl<B> + VecZnxRotateInplaceImpl<B>,
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();

let (n_in, n_out) = (a.n(), res[0].to_mut().n());

let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());

debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
res[1..].iter_mut().for_each(|bi| {
debug_assert_eq!(
bi.to_mut().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});

res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
module.vec_znx_switch_degree(bi, res_col, &a, a_col);
module.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
} else {
module.vec_znx_switch_degree(bi, res_col, &buf, a_col);
module.vec_znx_rotate_inplace(-1, &mut buf, a_col);
}
})
unsafe impl VecZnxMergeRingsTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_merge_rings_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_merge_rings_tmp_bytes(module.n())
}
}

unsafe impl VecZnxMergeImpl<Self> for FFT64
unsafe impl VecZnxMergeRingsImpl<Self> for FFT64Spqlios
where
Self: VecZnxSwithcDegreeImpl<Self> + VecZnxRotateInplaceImpl<Self>,
Module<Self>: VecZnxMergeRingsTmpBytes,
{
fn vec_znx_merge_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
fn vec_znx_merge_rings_impl<R, A>(
module: &Module<Self>,
res: &mut R,
res_col: usize,
a: &[A],
a_col: usize,
scratch: &mut Scratch<Self>,
) where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_merge_ref(module, res, res_col, a, a_col)
let (tmp, _) = scratch.take_slice(module.vec_znx_merge_rings_tmp_bytes() / size_of::<i64>());
vec_znx_merge_rings::<_, _, FFT64Spqlios>(res, res_col, a, a_col, tmp);
}
}

pub fn vec_znx_merge_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &[A], a_col: usize)
where
B: Backend + VecZnxSwithcDegreeImpl<B> + VecZnxRotateInplaceImpl<B>,
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();

let (n_in, n_out) = (res.n(), a[0].to_ref().n());

debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});

a.iter().for_each(|ai| {
module.vec_znx_switch_degree(&mut res, res_col, ai, a_col);
module.vec_znx_rotate_inplace(-1, &mut res, res_col);
});

module.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
}

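// Illustrative standalone sketch (not part of the diff), assuming the textbook view of the
// ring split/merge used above: a polynomial in Z[X]/(X^(2m)+1) is determined by its even- and
// odd-indexed coefficients, and interleaving them back recovers the original. The helper names
// below are hypothetical and do not reproduce the exact rotation convention of
// vec_znx_split_ref / vec_znx_merge_ref; they only show that the two operations are inverses.
fn split_even_odd(a: &[i64]) -> (Vec<i64>, Vec<i64>) {
    let even: Vec<i64> = a.iter().copied().step_by(2).collect();
    let odd: Vec<i64> = a.iter().copied().skip(1).step_by(2).collect();
    (even, odd)
}

fn merge_even_odd(even: &[i64], odd: &[i64]) -> Vec<i64> {
    let mut out = Vec::with_capacity(even.len() + odd.len());
    for (&e, &o) in even.iter().zip(odd.iter()) {
        out.push(e);
        out.push(o);
    }
    out
}

fn main() {
    let a: Vec<i64> = (0..8).collect();
    let (even, odd) = split_even_odd(&a);
    assert_eq!(merge_even_odd(&even, &odd), a); // round trip recovers the input
}
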
unsafe impl VecZnxSwithcDegreeImpl<Self> for FFT64
unsafe impl VecZnxSwitchRingImpl<Self> for FFT64Spqlios
where
Self: VecZnxCopyImpl<Self>,
{
fn vec_znx_switch_degree_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
fn vec_znx_switch_ring_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_switch_degree_ref(module, res, res_col, a, a_col)
vec_znx_switch_ring::<_, _, FFT64Spqlios>(res, res_col, a, a_col);
}
}

pub fn vec_znx_switch_degree_ref<R, A, B>(module: &Module<B>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
B: Backend + VecZnxCopyImpl<B>,
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();

let (n_in, n_out) = (a.n(), res.n());

if n_in == n_out {
module.vec_znx_copy(&mut res, res_col, &a, a_col);
return;
}

let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
res.zero();
}

let size: usize = a.size().min(res.size());

(0..size).for_each(|i| {
izip!(
a.at(a_col, i).iter().step_by(gap_in),
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}

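// Illustrative standalone sketch (not part of the diff) of the gap_in/gap_out logic above:
// going from a larger ring degree to a smaller one keeps every gap_in-th coefficient, while
// the opposite direction spreads the coefficients out with stride gap_out and leaves the gaps
// at zero. It assumes the larger degree is a multiple of the smaller one, as the code above does.
fn switch_degree(res: &mut [i64], a: &[i64]) {
    let (n_in, n_out) = (a.len(), res.len());
    let (gap_in, gap_out) = if n_in >= n_out {
        (n_in / n_out, 1)
    } else {
        res.fill(0); // untouched slots stay zero when embedding into the larger ring
        (1, n_out / n_in)
    };
    for (x_out, x_in) in res.iter_mut().step_by(gap_out).zip(a.iter().step_by(gap_in)) {
        *x_out = *x_in;
    }
}
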
unsafe impl VecZnxCopyImpl<Self> for FFT64 {
unsafe impl VecZnxCopyImpl<Self> for FFT64Spqlios {
fn vec_znx_copy_impl<R, A>(_module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
vec_znx_copy_ref(res, res_col, a, a_col)
vec_znx_copy::<_, _, FFT64Spqlios>(res, res_col, a, a_col)
}
}

pub fn vec_znx_copy_ref<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();

let min_size: usize = res_mut.size().min(a_ref.size());

(0..min_size).for_each(|j| {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}

unsafe impl VecZnxFillUniformImpl<Self> for FFT64 {
fn vec_znx_fill_uniform_impl<R>(
_module: &Module<Self>,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
) where
unsafe impl VecZnxFillUniformImpl<Self> for FFT64Spqlios {
fn vec_znx_fill_uniform_impl<R>(_module: &Module<Self>, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = res.to_mut();
let base2k: u64 = 1 << basek;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
(0..k.div_ceil(basek)).for_each(|j| {
a.at_mut(res_col, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
vec_znx_fill_uniform_ref(basek, res, res_col, source)
}
}

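// Illustrative standalone sketch (not part of the diff) of the uniform limb sampling above:
// each coefficient is drawn uniformly from [0, 2^basek) and recentred to
// [-2^(basek-1), 2^(basek-1)). A plain closure stands in for the crate's `Source` PRNG,
// which is an assumption made here for illustration only.
fn fill_uniform_limb(limb: &mut [i64], basek: usize, mut next_u64: impl FnMut() -> u64) {
    let base2k: u64 = 1 << basek;
    let mask: u64 = base2k - 1;
    let half: i64 = (base2k >> 1) as i64;
    for x in limb.iter_mut() {
        // keep the low basek bits, then centre the digit around zero
        *x = ((next_u64() & mask) as i64) - half;
    }
}
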
unsafe impl VecZnxFillDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddDistF64Impl<Self> for FFT64 {
|
||||
fn vec_znx_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxFillDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl VecZnxFillNormalImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_fill_normal_impl<R>(
|
||||
module: &Module<Self>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -927,24 +873,13 @@ where
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
vec_znx_fill_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: VecZnxAddDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl VecZnxAddNormalImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_add_normal_impl<R>(
|
||||
module: &Module<Self>,
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
@@ -955,14 +890,6 @@ where
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
{
|
||||
module.vec_znx_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
vec_znx_add_normal_ref(basek, res, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,170 +1,98 @@
|
||||
use rand_distr::{Distribution, Normal};
|
||||
|
||||
use crate::cpu_spqlios::{FFT64, ffi::vec_znx};
|
||||
use crate::cpu_spqlios::{FFT64Spqlios, ffi::vec_znx};
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxBigAddDistF64, VecZnxBigFillDistF64, VecZnxBigNormalizeTmpBytes},
|
||||
api::{TakeSlice, VecZnxBigNormalizeTmpBytes},
|
||||
layouts::{
|
||||
Backend, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef,
|
||||
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
TakeSliceImpl, VecZnxBigAddDistF64Impl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl,
|
||||
VecZnxBigAddSmallImpl, VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl,
|
||||
VecZnxBigAutomorphismImpl, VecZnxBigAutomorphismInplaceImpl, VecZnxBigFillDistF64Impl, VecZnxBigFillNormalImpl,
|
||||
VecZnxBigFromBytesImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl, VecZnxBigNormalizeTmpBytesImpl,
|
||||
VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl, VecZnxBigSubSmallAImpl,
|
||||
VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
TakeSliceImpl, VecZnxBigAddImpl, VecZnxBigAddInplaceImpl, VecZnxBigAddNormalImpl, VecZnxBigAddSmallImpl,
|
||||
VecZnxBigAddSmallInplaceImpl, VecZnxBigAllocBytesImpl, VecZnxBigAllocImpl, VecZnxBigAutomorphismImpl,
|
||||
VecZnxBigAutomorphismInplaceImpl, VecZnxBigAutomorphismInplaceTmpBytesImpl, VecZnxBigFromBytesImpl,
|
||||
VecZnxBigFromSmallImpl, VecZnxBigNegateImpl, VecZnxBigNegateInplaceImpl, VecZnxBigNormalizeImpl,
|
||||
VecZnxBigNormalizeTmpBytesImpl, VecZnxBigSubABInplaceImpl, VecZnxBigSubBAInplaceImpl, VecZnxBigSubImpl,
|
||||
VecZnxBigSubSmallAImpl, VecZnxBigSubSmallAInplaceImpl, VecZnxBigSubSmallBImpl, VecZnxBigSubSmallBInplaceImpl,
|
||||
},
|
||||
reference::{
|
||||
vec_znx::vec_znx_add_normal_ref,
|
||||
znx::{znx_copy_ref, znx_zero_ref},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<Self> {
|
||||
VecZnxBig::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddDistF64Impl<Self> for FFT64 {
|
||||
fn add_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
unsafe impl VecZnxBigFromSmallImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_from_small_impl<R, A>(res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n());
|
||||
}
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x += dist_f64.round() as i64
|
||||
});
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
|
||||
let min_size: usize = res_size.min(a_size);
|
||||
|
||||
for j in 0..min_size {
|
||||
znx_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
znx_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddNormalImpl<Self> for FFT64Spqlios {
|
||||
fn add_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_add_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillDistF64Impl<Self> for FFT64 {
|
||||
fn fill_dist_f64_impl<R: VecZnxBigToMut<Self>, D: Distribution<f64>>(
|
||||
_module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
res.at_mut(res_col, limb).iter_mut().for_each(|x| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*x = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigFillNormalImpl<Self> for FFT64 {
|
||||
fn fill_normal_impl<R: VecZnxBigToMut<Self>>(
|
||||
module: &Module<Self>,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) {
|
||||
module.vec_znx_big_fill_dist_f64(
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
let res: VecZnxBig<&mut [u8], FFT64Spqlios> = res.to_mut();
|
||||
|
||||
let mut res_znx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add_normal_ref(basek, &mut res_znx, res_col, k, sigma, bound, source);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
@@ -199,7 +127,7 @@ unsafe impl VecZnxBigAddImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -230,7 +158,7 @@ unsafe impl VecZnxBigAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -272,7 +200,7 @@ unsafe impl VecZnxBigAddSmallImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -303,7 +231,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_impl<R, A, B>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
@@ -338,7 +266,7 @@ unsafe impl VecZnxBigSubImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -369,7 +297,7 @@ unsafe impl VecZnxBigSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -400,7 +328,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -442,7 +370,7 @@ unsafe impl VecZnxBigSubSmallAImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -473,7 +401,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b_impl<R, A, B>(
|
||||
module: &Module<Self>,
|
||||
@@ -515,7 +443,7 @@ unsafe impl VecZnxBigSubSmallBImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -546,7 +474,29 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigNegateImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_negate_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxBigToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], Self> = a.to_ref();
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_negate_inplace_impl<A>(module: &Module<Self>, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<Self>,
|
||||
@@ -566,13 +516,13 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64
|
||||
unsafe impl VecZnxBigNormalizeImpl<Self> for FFT64Spqlios
|
||||
where
|
||||
Self: TakeSliceImpl<Self>,
|
||||
{
|
||||
@@ -613,7 +563,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64Spqlios {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism_impl<R, A>(module: &Module<Self>, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
@@ -642,10 +592,21 @@ unsafe impl VecZnxBigAutomorphismImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxBigAutomorphismInplaceTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_big_automorphism_inplace_tmp_bytes_impl(_module: &Module<Self>) -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxBigAutomorphismInplaceImpl<Self> for FFT64Spqlios {
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(module: &Module<Self>, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
fn vec_znx_big_automorphism_inplace_impl<A>(
|
||||
module: &Module<Self>,
|
||||
k: i64,
|
||||
a: &mut A,
|
||||
a_col: usize,
|
||||
_scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
A: VecZnxBigToMut<Self>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], Self> = a.to_mut();
|
||||
|
||||
@@ -1,60 +1,73 @@
|
||||
use poulpy_hal::{
|
||||
api::{TakeSlice, VecZnxIDFTTmpBytes},
|
||||
api::{TakeSlice, VecZnxIdftApplyTmpBytes},
|
||||
layouts::{
|
||||
Backend, Data, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::{
|
||||
DFTImpl, IDFTConsumeImpl, IDFTImpl, IDFTTmpAImpl, VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl,
|
||||
VecZnxDftAllocImpl, VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl,
|
||||
VecZnxDftSubImpl, VecZnxDftZeroImpl, VecZnxIDFTTmpBytesImpl,
|
||||
VecZnxDftAddImpl, VecZnxDftAddInplaceImpl, VecZnxDftAllocBytesImpl, VecZnxDftAllocImpl, VecZnxDftApplyImpl,
|
||||
VecZnxDftCopyImpl, VecZnxDftFromBytesImpl, VecZnxDftSubABInplaceImpl, VecZnxDftSubBAInplaceImpl, VecZnxDftSubImpl,
|
||||
VecZnxDftZeroImpl, VecZnxIdftApplyConsumeImpl, VecZnxIdftApplyImpl, VecZnxIdftApplyTmpAImpl, VecZnxIdftApplyTmpBytesImpl,
|
||||
},
|
||||
reference::{
|
||||
fft64::{
|
||||
reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref},
|
||||
vec_znx_dft::vec_znx_dft_copy,
|
||||
},
|
||||
znx::znx_zero_ref,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{vec_znx_big, vec_znx_dft},
|
||||
};
|
||||
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDft::<Vec<u8>, FFT64>::from_bytes(n, cols, size, bytes)
|
||||
VecZnxDft::<Vec<u8>, Self>::from_bytes(n, cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * cols * size * size_of::<<FFT64 as Backend>::ScalarPrep>()
|
||||
Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Spqlios as Backend>::ScalarPrep>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxDftOwned<Self> {
|
||||
VecZnxDftOwned::alloc(n, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxIDFTTmpBytesImpl<Self> for FFT64 {
|
||||
fn vec_znx_idft_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe impl VecZnxIdftApplyTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_tmp_bytes_impl(module: &Module<Self>) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTImpl<Self> for FFT64 {
|
||||
fn idft_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
unsafe impl VecZnxIdftApplyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), a.n())
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_tmp_bytes());
|
||||
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_idft_apply_tmp_bytes());
|
||||
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
@@ -69,47 +82,43 @@ unsafe impl IDFTImpl<Self> for FFT64 {
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
(min_size..res.size()).for_each(|j| {
|
||||
res.zero_at(res_col, j);
|
||||
});
|
||||
(min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTTmpAImpl<Self> for FFT64 {
|
||||
fn idft_tmp_a_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
unsafe impl VecZnxIdftApplyTmpAImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_tmpa_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<Self>,
|
||||
A: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut res: VecZnxBig<&mut [u8], Self> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_mut.size());
|
||||
let min_size: usize = res.size().min(a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1_u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
(min_size..res.size()).for_each(|j| znx_zero_ref(res.at_mut(res_col, j)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl IDFTConsumeImpl<Self> for FFT64 {
|
||||
fn idft_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
unsafe impl VecZnxIdftApplyConsumeImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_idft_apply_consume_impl<D: Data>(module: &Module<Self>, mut a: VecZnxDft<D, Self>) -> VecZnxBig<D, Self>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<Self>,
|
||||
VecZnxDft<D, Self>: VecZnxDftToMut<Self>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], Self> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
@@ -130,89 +139,129 @@ unsafe impl IDFTConsumeImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl DFTImpl<Self> for FFT64 {
|
||||
fn dft_impl<R, A>(module: &Module<Self>, step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
unsafe impl VecZnxDftApplyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_apply_impl<R, A>(
|
||||
module: &Module<Self>,
|
||||
step: usize,
|
||||
offset: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
) where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let steps: usize = a.size().div_ceil(step);
|
||||
let min_steps: usize = res.size().min(steps);
|
||||
unsafe {
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
if limb < a.size() {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1_u64,
|
||||
a_ref.at_ptr(a_col, limb),
|
||||
a.at_ptr(a_col, limb),
|
||||
1_u64,
|
||||
a_ref.sl() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
(min_steps..res.size()).for_each(|j| reim_zero_ref(res.at_mut(res_col, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAddImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_add_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), a.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_add_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
@@ -220,58 +269,93 @@ unsafe impl VecZnxDftAddInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_impl<R, A, D>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
D: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], Self> = b.to_ref();
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
let res_size: usize = res.size();
|
||||
let a_size: usize = a.size();
|
||||
let b_size: usize = b.size();
|
||||
|
||||
if a_size <= b_size {
|
||||
let sum_size: usize = a_size.min(res_size);
|
||||
let cpy_size: usize = b_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_negate_ref(res.at_mut(res_col, j), b.at(b_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
} else {
|
||||
let sum_size: usize = b_size.min(res_size);
|
||||
let cpy_size: usize = a_size.min(res_size);
|
||||
|
||||
(0..sum_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in sum_size..cpy_size {
|
||||
reim_copy_ref(res.at_mut(res_col, j), a.at(a_col, j));
|
||||
}
|
||||
|
||||
for j in cpy_size..res_size {
|
||||
reim_zero_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_ab_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
@@ -279,34 +363,38 @@ unsafe impl VecZnxDftSubABInplaceImpl<Self> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftSubBAInplaceImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_sub_ba_inplace_impl<R, A>(module: &Module<Self>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxDft<&mut [u8], Self> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], Self> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
let min_size: usize = res.size().min(a.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_sub(
|
||||
module.ptr(),
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
res.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
|
||||
for j in min_size..res.size() {
|
||||
reim_negate_inplace_ref(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
|
||||
unsafe impl VecZnxDftCopyImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_copy_impl<R, A>(
|
||||
_module: &Module<Self>,
|
||||
step: usize,
|
||||
@@ -319,27 +407,25 @@ unsafe impl VecZnxDftCopyImpl<Self> for FFT64 {
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let steps: usize = a_ref.size().div_ceil(step);
|
||||
let min_steps: usize = res_mut.size().min(steps);
|
||||
|
||||
(0..min_steps).for_each(|j| {
|
||||
let limb: usize = offset + j * step;
|
||||
if limb < a_ref.size() {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, limb));
|
||||
}
|
||||
});
|
||||
(min_steps..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
vec_znx_dft_copy(step, offset, res, res_col, a, a_col);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64 {
|
||||
impl ReimCopy for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn reim_copy(res: &mut [f64], a: &[f64]) {
|
||||
reim_copy_ref(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl ReimZero for FFT64Spqlios {
|
||||
#[inline(always)]
|
||||
fn reim_zero(res: &mut [f64]) {
|
||||
reim_zero_ref(res);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Spqlios {
|
||||
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<Self>,
|
||||
|
||||
@@ -6,22 +6,22 @@ use poulpy_hal::{
|
||||
},
|
||||
oep::{
|
||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPMatPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::cpu_spqlios::{
|
||||
FFT64,
|
||||
FFT64Spqlios,
|
||||
ffi::{vec_znx_dft::vec_znx_dft_t, vmp},
|
||||
};
|
||||
|
||||
unsafe impl VmpPMatAllocBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpPMatAllocBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_alloc_bytes_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
FFT64::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
Self::layout_prep_word_count() * n * rows * cols_in * cols_out * size * size_of::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpPMatFromBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_from_bytes_impl(
|
||||
n: usize,
|
||||
rows: usize,
|
||||
@@ -29,19 +29,19 @@ unsafe impl VmpPMatFromBytesImpl<FFT64> for FFT64 {
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<FFT64> {
|
||||
) -> VmpPMatOwned<Self> {
|
||||
VmpPMatOwned::from_bytes(n, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<FFT64> {
|
||||
unsafe impl VmpPMatAllocImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_pmat_alloc_impl(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<Self> {
|
||||
VmpPMatOwned::alloc(n, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe impl VmpPrepareTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<Self>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_tmp_bytes(
|
||||
module.ptr(),
|
||||
@@ -52,13 +52,13 @@ unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch<FFT64>)
|
||||
unsafe impl VmpPrepareImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_prepare_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VmpPMatToMut<FFT64>,
|
||||
R: VmpPMatToMut<Self>,
|
||||
A: MatZnxToRef,
|
||||
{
|
||||
let mut res: VmpPMat<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut res: VmpPMat<&mut [u8], Self> = res.to_mut();
|
||||
let a: MatZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -109,9 +109,9 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftTmpBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -131,12 +131,12 @@ unsafe impl VmpApplyDftToDftTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftImpl<FFT64> for FFT64 {
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<FFT64>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<FFT64>)
|
||||
unsafe impl VmpApplyDftToDftImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_impl<R, A, C>(module: &Module<Self>, res: &mut R, a: &A, b: &C, scratch: &mut Scratch<Self>)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
@@ -186,9 +186,9 @@ unsafe impl VmpApplyDftToDftImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_add_tmp_bytes_impl(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -208,18 +208,18 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl<FFT64> for FFT64 {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl VmpApplyDftToDftAddImpl<FFT64> for FFT64 {
|
||||
unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Spqlios {
|
||||
fn vmp_apply_dft_to_dft_add_impl<R, A, C>(
|
||||
module: &Module<FFT64>,
|
||||
module: &Module<Self>,
|
||||
res: &mut R,
|
||||
a: &A,
|
||||
b: &C,
|
||||
scale: usize,
|
||||
scratch: &mut Scratch<FFT64>,
|
||||
scratch: &mut Scratch<Self>,
|
||||
) where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
C: VmpPMatToRef<FFT64>,
|
||||
R: VecZnxDftToMut<Self>,
|
||||
A: VecZnxDftToRef<Self>,
|
||||
C: VmpPMatToRef<Self>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
@@ -1,17 +1,14 @@
use poulpy_hal::{
    api::TakeSlice,
    layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
    oep::{
        TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl,
        ZnNormalizeInplaceImpl,
    },
    oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
    reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
    source::Source,
};
use rand_distr::Normal;

use crate::cpu_spqlios::{FFT64, ffi::zn64};
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};

unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
    Self: TakeSliceImpl<Self>,
{
@@ -39,113 +36,17 @@ where
    }
}

unsafe impl ZnFillUniformImpl<Self> for FFT64 {
    fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
    fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
    where
        R: ZnToMut,
    {
        let mut a: Zn<&mut [u8]> = res.to_mut();
        let base2k: u64 = 1 << basek;
        let mask: u64 = base2k - 1;
        let base2k_half: i64 = (base2k >> 1) as i64;
        (0..k.div_ceil(basek)).for_each(|j| {
            a.at_mut(res_col, j)[..n]
                .iter_mut()
                .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
        })
        zn_fill_uniform(n, basek, res, res_col, source);
    }
}

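// A hedged sketch (not the actual poulpy_hal code) of what the shared reference
// helper `zn_fill_uniform` plausibly does, reconstructed from the inline body
// removed above. The `k` parameter is gone, so the sketch assumes every limb of
// the destination column is filled; the function name and the use of `a.size()`
// as the limb count are assumptions, the real helper lives in
// poulpy_hal::reference::zn.
use poulpy_hal::{
    layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
    source::Source,
};

pub fn zn_fill_uniform_sketch<R: ZnToMut>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source) {
    let mut a: Zn<&mut [u8]> = res.to_mut();
    let base2k: u64 = 1 << basek;
    let mask: u64 = base2k - 1;
    let base2k_half: i64 = (base2k >> 1) as i64;
    // Sample each coefficient uniformly in [-base2k/2, base2k/2) for every limb of the column.
    (0..a.size()).for_each(|j| {
        a.at_mut(res_col, j)[..n]
            .iter_mut()
            .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
    });
}
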
unsafe impl ZnFillDistF64Impl<Self> for FFT64 {
|
||||
fn zn_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddDistF64Impl<Self> for FFT64 {
|
||||
fn zn_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
let mut a: Zn<&mut [u8]> = res.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = k.div_ceil(basek) - 1;
|
||||
let basek_rem: usize = (limb + 1) * basek - k;
|
||||
|
||||
if basek_rem != 0 {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << basek_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnFillDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_fill_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -158,23 +59,12 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_fill_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_fill_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64
|
||||
where
|
||||
Self: ZnAddDistF64Impl<Self>,
|
||||
{
|
||||
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn zn_add_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
@@ -187,15 +77,6 @@ where
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
Self::zn_add_dist_f64_impl(
|
||||
n,
|
||||
basek,
|
||||
res,
|
||||
res_col,
|
||||
k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
zn_add_normal(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ mod fft64;
mod ntt120;

#[cfg(test)]
mod test;
mod tests;

pub use ffi::*;
pub use fft64::*;
pub use ntt120::*;

Submodule poulpy-backend/src/cpu_spqlios/spqlios-arithmetic updated: 708e5d7e86...b6938df774
@@ -1,2 +0,0 @@
mod vec_znx_fft64;
mod vmp_pmat_fft64;
@@ -1,19 +0,0 @@
use poulpy_hal::{
    api::ModuleNew,
    layouts::Module,
    tests::vec_znx::{test_vec_znx_add_normal, test_vec_znx_fill_uniform},
};

use crate::cpu_spqlios::FFT64;

#[test]
fn test_vec_znx_fill_uniform_fft64() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    test_vec_znx_fill_uniform(&module);
}

#[test]
fn test_vec_znx_add_normal_fft64() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 12);
    test_vec_znx_add_normal(&module);
}
@@ -1,8 +0,0 @@
use poulpy_hal::tests::vmp_pmat::test_vmp_apply;

use crate::cpu_spqlios::FFT64;

#[test]
fn vmp_apply() {
    test_vmp_apply::<FFT64>();
}
117
poulpy-backend/src/cpu_spqlios/tests.rs
Normal file
@@ -0,0 +1,117 @@
use poulpy_hal::{backend_test_suite, cross_backend_test_suite};
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_add => poulpy_hal::test_suite::vec_znx::test_vec_znx_add,
|
||||
test_vec_znx_add_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_inplace,
|
||||
test_vec_znx_add_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar,
|
||||
test_vec_znx_add_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_scalar_inplace,
|
||||
test_vec_znx_sub => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub,
|
||||
test_vec_znx_sub_ab_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ab_inplace,
|
||||
test_vec_znx_sub_ba_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_ba_inplace,
|
||||
test_vec_znx_sub_scalar => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar,
|
||||
test_vec_znx_sub_scalar_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_sub_scalar_inplace,
|
||||
test_vec_znx_rsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh,
|
||||
test_vec_znx_rsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rsh_inplace,
|
||||
test_vec_znx_lsh => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh,
|
||||
test_vec_znx_lsh_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_lsh_inplace,
|
||||
test_vec_znx_negate => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate,
|
||||
test_vec_znx_negate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_negate_inplace,
|
||||
test_vec_znx_rotate => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate,
|
||||
test_vec_znx_rotate_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_rotate_inplace,
|
||||
test_vec_znx_automorphism => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism,
|
||||
test_vec_znx_automorphism_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_automorphism_inplace,
|
||||
test_vec_znx_mul_xp_minus_one => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one,
|
||||
test_vec_znx_mul_xp_minus_one_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_mul_xp_minus_one_inplace,
|
||||
test_vec_znx_normalize => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize,
|
||||
test_vec_znx_normalize_inplace => poulpy_hal::test_suite::vec_znx::test_vec_znx_normalize_inplace,
|
||||
test_vec_znx_switch_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_switch_ring,
|
||||
test_vec_znx_split_ring => poulpy_hal::test_suite::vec_znx::test_vec_znx_split_ring,
|
||||
test_vec_znx_copy => poulpy_hal::test_suite::vec_znx::test_vec_znx_copy,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod svp,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_svp_apply_dft_to_dft => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft,
|
||||
test_svp_apply_dft_to_dft_inplace => poulpy_hal::test_suite::svp::test_svp_apply_dft_to_dft_inplace,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx_big,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_big_add => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add,
|
||||
test_vec_znx_big_add_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_inplace,
|
||||
test_vec_znx_big_add_small => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small,
|
||||
test_vec_znx_big_add_small_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_add_small_inplace,
|
||||
test_vec_znx_big_sub => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub,
|
||||
test_vec_znx_big_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ab_inplace,
|
||||
test_vec_znx_big_automorphism => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism,
|
||||
test_vec_znx_big_automorphism_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_automorphism_inplace,
|
||||
test_vec_znx_big_negate => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate,
|
||||
test_vec_znx_big_negate_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_negate_inplace,
|
||||
test_vec_znx_big_normalize => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_normalize,
|
||||
test_vec_znx_big_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_ba_inplace,
|
||||
test_vec_znx_big_sub_small_a => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a,
|
||||
test_vec_znx_big_sub_small_a_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_a_inplace,
|
||||
test_vec_znx_big_sub_small_b => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b,
|
||||
test_vec_znx_big_sub_small_b_inplace => poulpy_hal::test_suite::vec_znx_big::test_vec_znx_big_sub_small_b_inplace,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vec_znx_dft,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vec_znx_dft_add => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add,
|
||||
test_vec_znx_dft_add_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_add_inplace,
|
||||
test_vec_znx_dft_sub => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub,
|
||||
test_vec_znx_dft_sub_ab_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ab_inplace,
|
||||
test_vec_znx_dft_sub_ba_inplace => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_dft_sub_ba_inplace,
|
||||
test_vec_znx_idft_apply => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply,
|
||||
test_vec_znx_idft_apply_consume => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_consume,
|
||||
test_vec_znx_idft_apply_tmpa => poulpy_hal::test_suite::vec_znx_dft::test_vec_znx_idft_apply_tmpa,
|
||||
}
|
||||
}
|
||||
|
||||
cross_backend_test_suite! {
|
||||
mod vmp,
|
||||
backend_ref = crate::cpu_fft64_ref::FFT64Ref,
|
||||
backend_test = crate::cpu_spqlios::FFT64Spqlios,
|
||||
size = 1 << 5,
|
||||
basek = 12,
|
||||
tests = {
|
||||
test_vmp_apply_dft_to_dft => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft,
|
||||
test_vmp_apply_dft_to_dft_add => poulpy_hal::test_suite::vmp::test_vmp_apply_dft_to_dft_add,
|
||||
}
|
||||
}
|
||||
|
||||
backend_test_suite! {
    mod sampling,
    backend = crate::cpu_spqlios::FFT64Spqlios,
    size = 1 << 12,
    tests = {
        test_vec_znx_fill_uniform => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_uniform,
        test_vec_znx_fill_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_fill_normal,
        test_vec_znx_add_normal => poulpy_hal::test_suite::vec_znx::test_vec_znx_add_normal,
        test_vec_znx_big_sub_small_b_inplace => poulpy_hal::reference::fft64::vec_znx_big::test_vec_znx_big_add_normal,
    }
}

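// The suites above replace the hand-written per-backend tests deleted earlier in
// this diff. The exact expansion of backend_test_suite!/cross_backend_test_suite!
// is defined in poulpy-hal; purely as an illustration (the generic arguments and
// parameter order below are assumptions, not the real macro output), each `tests`
// entry presumably becomes a #[test] wrapper along these lines:
//
// #[test]
// fn test_vec_znx_add() {
//     poulpy_hal::test_suite::vec_znx::test_vec_znx_add::<
//         crate::cpu_fft64_ref::FFT64Ref,   // reference backend
//         crate::cpu_spqlios::FFT64Spqlios, // backend under test
//     >(1 << 5, 12);                        // size, basek
// }
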
@@ -1 +1,7 @@
pub mod cpu_fft64_avx;
pub mod cpu_fft64_ref;
pub mod cpu_spqlios;

pub use cpu_fft64_avx::FFT64Avx;
pub use cpu_fft64_ref::FFT64Ref;
pub use cpu_spqlios::FFT64Spqlios;

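// With all three backends re-exported at the crate root, downstream code can stay
// generic over the backend and pick it at the call site. A minimal sketch under
// assumptions: the helper below is illustrative (not part of this crate), and the
// `ModuleNew` bound / `new(u64)` signature are inferred from call sites elsewhere
// in this diff.
use poulpy_hal::{
    api::ModuleNew,
    layouts::{Backend, Module},
};

pub fn make_module<B: Backend>(log_n: usize) -> Module<B>
where
    Module<B>: ModuleNew<B>,
{
    Module::<B>::new(1u64 << log_n)
}

// Usage (illustrative):
// let m_ref = make_module::<crate::cpu_fft64_ref::FFT64Ref>(12);
// let m_spqlios = make_module::<crate::cpu_spqlios::FFT64Spqlios>(12);
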
@@ -15,6 +15,7 @@ poulpy-hal = {path="../poulpy-hal"}
poulpy-backend = {path="../poulpy-backend"}
itertools = {workspace = true}
byteorder = {workspace = true}
once_cell = {workspace = true}

[[bench]]
name = "external_product_glwe_fft64"

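# Presumably poulpy-core's Cargo.toml: the shared once_cell dependency is added and a
# criterion bench target is registered. Like the vmp bench added to poulpy-backend
# above, each [[bench]] entry is expected to pair with `harness = false` so that
# criterion drives it (e.g. `cargo bench --bench external_product_glwe_fft64`);
# the exact invocation is illustrative, not taken from this diff.
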
@@ -6,7 +6,7 @@ use std::hint::black_box;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
|
||||
use poulpy_hal::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||
layouts::{Module, ScalarZnx, ScratchOwned},
|
||||
@@ -26,7 +26,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) {
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(1 << p.log_n);
|
||||
|
||||
let n: usize = module.n();
|
||||
let basek: usize = p.basek;
|
||||
@@ -43,7 +43,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) {
|
||||
let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_ct_out, rank);
|
||||
let pt_rgsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
|
||||
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::alloc(
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank)
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe_in.k())
|
||||
| GLWECiphertext::external_product_scratch_space(
|
||||
@@ -63,7 +63,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) {
|
||||
|
||||
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_dft: GLWESecretPrepared<Vec<u8>, FFT64> = sk.prepare_alloc(&module, scratch.borrow());
|
||||
let sk_dft: GLWESecretPrepared<Vec<u8>, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
ct_ggsw.encrypt_sk(
|
||||
&module,
|
||||
@@ -82,7 +82,7 @@ fn bench_external_product_glwe_fft64(c: &mut Criterion) {
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, FFT64> = ct_ggsw.prepare_alloc(&module, scratch.borrow());
|
||||
let ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
move || {
|
||||
ct_glwe_out.external_product(&module, &ct_glwe_in, &ggsw_prepared, scratch.borrow());
|
||||
@@ -120,7 +120,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(1 << p.log_n);
|
||||
|
||||
let n = module.n();
|
||||
let basek: usize = p.basek;
|
||||
@@ -135,7 +135,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_glwe, rank);
|
||||
let pt_rgsw: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
|
||||
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::alloc(
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank)
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k())
|
||||
| GLWECiphertext::external_product_inplace_scratch_space(&module, basek, ct_glwe.k(), ct_ggsw.k(), digits, rank),
|
||||
@@ -147,7 +147,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
|
||||
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_dft: GLWESecretPrepared<Vec<u8>, FFT64> = sk.prepare_alloc(&module, scratch.borrow());
|
||||
let sk_dft: GLWESecretPrepared<Vec<u8>, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
ct_ggsw.encrypt_sk(
|
||||
&module,
|
||||
@@ -166,7 +166,7 @@ fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, FFT64> = ct_ggsw.prepare_alloc(&module, scratch.borrow());
|
||||
let ggsw_prepared: GGSWCiphertextPrepared<Vec<u8>, FFT64Spqlios> = ct_ggsw.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
move || {
|
||||
let scratch_borrow = scratch.borrow();
|
||||
|
||||
@@ -5,7 +5,7 @@ use poulpy_core::layouts::{
|
||||
use std::{hint::black_box, time::Duration};
|
||||
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
|
||||
use poulpy_hal::{
|
||||
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
|
||||
layouts::{Module, ScratchOwned},
|
||||
@@ -27,7 +27,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(1 << p.log_n);
|
||||
|
||||
let n = module.n();
|
||||
let basek: usize = p.basek;
|
||||
@@ -44,7 +44,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
|
||||
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_rlwe_in, rank_in);
|
||||
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_rlwe_out, rank_out);
|
||||
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::alloc(
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(
|
||||
GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank_in, rank_out)
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k())
|
||||
| GLWECiphertext::keyswitch_scratch_space(
|
||||
@@ -65,7 +65,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
|
||||
|
||||
let mut sk_in: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_in_dft: GLWESecretPrepared<Vec<u8>, FFT64> = sk_in.prepare_alloc(&module, scratch.borrow());
|
||||
let sk_in_dft: GLWESecretPrepared<Vec<u8>, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
let mut sk_out: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
@@ -132,7 +132,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(1 << p.log_n);
|
||||
|
||||
let n = module.n();
|
||||
let basek: usize = p.basek;
|
||||
@@ -146,7 +146,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
let mut ksk: GGLWESwitchingKey<Vec<u8>> = GGLWESwitchingKey::alloc(n, basek, k_ksk, rows, digits, rank, rank);
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_ct, rank);
|
||||
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::alloc(
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(
|
||||
GGLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank, rank)
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k())
|
||||
| GLWECiphertext::keyswitch_inplace_scratch_space(&module, basek, ct.k(), ksk.k(), digits, rank),
|
||||
@@ -158,7 +158,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
|
||||
let mut sk_in: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_in_dft: GLWESecretPrepared<Vec<u8>, FFT64> = sk_in.prepare_alloc(&module, scratch.borrow());
|
||||
let sk_in_dft: GLWESecretPrepared<Vec<u8>, FFT64Spqlios> = sk_in.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
let mut sk_out: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
@@ -180,7 +180,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let ksk_prepared: GGLWESwitchingKeyPrepared<Vec<u8>, FFT64> = ksk.prepare_alloc(&module, scratch.borrow());
|
||||
let ksk_prepared: GGLWESwitchingKeyPrepared<Vec<u8>, FFT64Spqlios> = ksk.prepare_alloc(&module, scratch.borrow());
|
||||
|
||||
move || {
|
||||
ct.keyswitch_inplace(&module, &ksk_prepared, scratch.borrow());
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use poulpy_backend::cpu_spqlios::FFT64;
|
||||
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
|
||||
use poulpy_core::{
|
||||
GLWEOperations, SIGMA,
|
||||
layouts::{
|
||||
@@ -31,7 +31,7 @@ fn main() {
|
||||
let rank: usize = 1;
|
||||
|
||||
// Instantiate Module (DFT Tables)
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n as u64);
|
||||
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(n as u64);
|
||||
|
||||
// Allocates ciphertext & plaintexts
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_ct, rank);
|
||||
@@ -44,7 +44,7 @@ fn main() {
|
||||
let mut source_xa: Source = Source::new([2u8; 32]);
|
||||
|
||||
// Scratch space
|
||||
let mut scratch: ScratchOwned<FFT64> = ScratchOwned::alloc(
|
||||
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::alloc(
|
||||
GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k()),
|
||||
);
|
||||
@@ -54,10 +54,10 @@ fn main() {
    sk.fill_ternary_prob(0.5, &mut source_xs);

    // Backend-prepared secret
    let sk_prepared: GLWESecretPrepared<Vec<u8>, FFT64> = sk.prepare_alloc(&module, scratch.borrow());
    let sk_prepared: GLWESecretPrepared<Vec<u8>, FFT64Spqlios> = sk.prepare_alloc(&module, scratch.borrow());

    // Uniform plaintext
    module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa);
    module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa);

    // Encryption
    ct.encrypt_sk(
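// Call-site note: vec_znx_fill_uniform no longer takes a target precision, mirroring
// the ZnFillUniform change earlier in this diff; the destination column is presumably
// filled across all of its limbs. Before/after, as above:
//
//     module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, k_pt, &mut source_xa);
//     module.vec_znx_fill_uniform(basek, &mut pt_want.data, 0, &mut source_xa);
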
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft,
|
||||
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
|
||||
};
|
||||
@@ -54,12 +54,12 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphism
|
||||
+ VecZnxAutomorphismInplace,
|
||||
+ VecZnxAutomorphismInplace<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -72,7 +72,7 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
self.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
@@ -113,7 +113,7 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {

                // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
                (0..cols_out).for_each(|i| {
                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
                });
            });
        });
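// The in-place automorphism now threads the scratch space through, in line with the
// VecZnxAutomorphismInplace<B> bounds updated throughout this diff; call sites change from
//     module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
// to
//     module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
// (presumably because some backends need temporary buffers for the permutation).
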
@@ -138,17 +138,56 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphism
|
||||
+ VecZnxAutomorphismInplace,
|
||||
+ VecZnxAutomorphismInplace<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GGLWEAutomorphismKey<DataSelf> = self as *mut GGLWEAutomorphismKey<DataSelf>;
|
||||
self.automorphism(module, &*self_ptr, rhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
let p: i64 = self.p();
|
||||
let p_inv = module.galois_element_inv(p);
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
|
||||
|
||||
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
|
||||
});
|
||||
|
||||
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
|
||||
res_ct.keyswitch_inplace(module, &rhs.key, scratch);
|
||||
|
||||
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace,
|
||||
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
|
||||
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
|
||||
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
@@ -79,16 +79,16 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphismInplace
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigAllocBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -133,7 +133,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
)
|
||||
};
|
||||
|
||||
self.automorphism_internal(module, lhs, auto_key, scratch);
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
|
||||
});
|
||||
self.expand_row(module, tensor_key, scratch);
|
||||
}
|
||||
|
||||
@@ -149,49 +155,25 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphismInplace
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigAllocBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
|
||||
self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
lhs: &GGSWCiphertext<DataLhs>,
|
||||
auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphismInplace,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
(0..self.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
|
||||
.automorphism_inplace(module, auto_key, scratch);
|
||||
});
|
||||
self.expand_row(module, tensor_key, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace,
|
||||
VecZnxBigSubSmallBInplace, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace,
|
||||
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace, VecZnxBigSubSmallBInplace,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
|
||||
};
|
||||
@@ -54,16 +55,16 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphismInplace,
|
||||
+ VecZnxAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
self.keyswitch(module, lhs, &rhs.key, scratch);
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -78,16 +79,16 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxAutomorphismInplace,
|
||||
+ VecZnxAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
self.keyswitch_inplace(module, &rhs.key, scratch);
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -103,8 +104,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
@@ -114,12 +115,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
{
|
||||
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -134,17 +135,24 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.automorphism_add(module, &*self_ptr, rhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn automorphism_sub_ab<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
|
||||
@@ -159,8 +167,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
@@ -171,12 +179,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
{
|
||||
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -191,18 +199,25 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallAInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &self.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn automorphism_sub_ba<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
|
||||
@@ -217,8 +232,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
@@ -229,12 +244,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
{
|
||||
self.assert_keyswitch(module, lhs, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -249,17 +264,24 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
self.assert_keyswitch_inplace(module, &rhs.key, scratch);
|
||||
}
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
|
||||
let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
|
||||
(0..self.cols()).for_each(|i| {
|
||||
module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
|
||||
module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &self.data, i);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
|
||||
};
|
||||
@@ -60,8 +61,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeGLWECt,
|
||||
@@ -71,8 +72,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
|
||||
assert_eq!(self.basek(), a.basek());
|
||||
assert_eq!(a.n(), ks.n());
|
||||
}
|
||||
let (mut tmp_glwe, scratch1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1);
|
||||
tmp_glwe.keyswitch(module, a, &ks.0, scratch1);
|
||||
let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(a.n(), a.basek(), self.k(), 1);
|
||||
tmp_glwe.keyswitch(module, a, &ks.0, scratch_1);
|
||||
self.sample_extract(&tmp_glwe);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
|
||||
};
|
||||
@@ -43,8 +44,8 @@ impl<D: DataMut> GLWECiphertext<D> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeGLWECt,
|
||||
@@ -55,7 +56,7 @@ impl<D: DataMut> GLWECiphertext<D> {
|
||||
assert_eq!(self.basek(), self.basek());
|
||||
}
|
||||
|
||||
let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1);
|
||||
let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), lwe.basek(), lwe.k(), 1);
|
||||
glwe.data.zero();
|
||||
|
||||
let n_lwe: usize = lwe.n();
|
||||
@@ -66,6 +67,6 @@ impl<D: DataMut> GLWECiphertext<D> {
|
||||
glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
|
||||
});
|
||||
|
||||
self.keyswitch(module, &glwe, &ksk.0, scratch1);
|
||||
self.keyswitch(module, &glwe, &ksk.0, scratch_1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, SvpApplyInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes,
|
||||
SvpApplyDftToDftInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch},
|
||||
};
|
||||
@@ -26,9 +26,9 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
|
||||
sk: &GLWESecretPrepared<DataSk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
Module<B>: VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddInplace<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
@@ -50,9 +50,9 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
|
||||
(1..cols).for_each(|i| {
|
||||
// ci_dft = DFT(a[i]) * DFT(s[i])
|
||||
let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n(), 1, self.size()); // TODO optimize size when pt << ct
|
||||
module.dft(1, 0, &mut ci_dft, 0, &self.data, i);
|
||||
module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1);
|
||||
let ci_big = module.vec_znx_idft_consume(ci_dft);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &self.data, i);
|
||||
module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
|
||||
let ci_big = module.vec_znx_idft_apply_consume(ci_dft);
|
||||
|
||||
// c0_big += a[i] * s[i]
|
||||
module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
|
||||
VecZnxSubABInplace, VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes,
|
||||
VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
source::Source,
|
||||
@@ -41,12 +41,12 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKeyCompressed<DataSelf> {
|
||||
Module<B>: VecZnxAutomorphism
|
||||
+ SvpPrepare<B>
|
||||
+ SvpPPolAllocBytes
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
|
||||
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
|
||||
source::Source,
|
||||
@@ -37,9 +37,9 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
|
||||
Module<B>: VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes,
|
||||
VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
|
||||
VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
|
||||
VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
|
||||
source::Source,
|
||||
@@ -44,12 +44,12 @@ impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
|
||||
) where
|
||||
Module<B>: SvpPrepare<B>
|
||||
+ SvpPPolAllocBytes
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -90,9 +90,9 @@ impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
|
||||
|
||||
let n: usize = sk_in.n().max(sk_out.n());
|
||||
|
||||
let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
|
||||
let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank());
|
||||
(0..sk_in.rank()).for_each(|i| {
|
||||
module.vec_znx_switch_degree(
|
||||
module.vec_znx_switch_ring(
|
||||
&mut sk_in_tmp.as_vec_znx_mut(),
|
||||
i,
|
||||
&sk_in.data.as_vec_znx(),
|
||||
@@ -100,11 +100,11 @@ impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
|
||||
);
|
||||
});
|
||||
|
||||
let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
|
||||
let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank());
|
||||
{
|
||||
let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
|
||||
let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
|
||||
(0..sk_out.rank()).for_each(|i| {
|
||||
module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
|
||||
module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
|
||||
module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
|
||||
});
|
||||
}
|
||||
@@ -115,7 +115,7 @@ impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
|
||||
&sk_out_tmp,
|
||||
seed_xa,
|
||||
source_xe,
|
||||
scratch2,
|
||||
scratch_2,
|
||||
);
|
||||
self.sk_in_n = sk_in.n();
|
||||
self.sk_out_n = sk_out.n();
|
||||
|
||||
@@ -1,9 +1,9 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare,
TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
source::Source,
@@ -33,13 +33,13 @@ impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
source_xe: &mut Source,
scratch: &mut Scratch<B>,
) where
Module<B>: SvpApply<B>
+ IDFTTmpA<B>
Module<B>: SvpApplyDftToDft<B>
+ VecZnxIdftApplyTmpA<B>
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
@@ -48,7 +48,7 @@ impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
+ VecZnxAddNormal
+ VecZnxNormalize<B>
+ VecZnxSub
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ VecZnxAddScalarInplace
+ SvpPrepare<B>
+ SvpPPolAllocBytes
@@ -70,39 +70,39 @@ impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
let n: usize = sk.n();
let rank: usize = self.rank();

let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
sk_dft_prep.prepare(module, sk, scratch1);
let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
sk_dft_prep.prepare(module, sk, scratch_1);

let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);

(0..rank).for_each(|i| {
module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
});

let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);

let mut source_xa: Source = Source::new(seed_xa);

(0..rank).for_each(|i| {
(i..rank).for_each(|j| {
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);

module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
module.vec_znx_big_normalize(
self.basek(),
&mut sk_ij.data.as_vec_znx_mut(),
0,
&sk_ij_big,
0,
scratch5,
scratch_5,
);

let (seed_xa_tmp, _) = source_xa.branch();

self.at_mut(i, j)
.encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch5);
.encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5);
});
})
}

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
source::Source,
@@ -37,9 +37,9 @@ impl<DataSelf: DataMut> GGSWCiphertextCompressed<DataSelf> {
Module<B>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
source::Source,
@@ -35,9 +35,9 @@ impl<D: DataMut> GLWECiphertextCompressed<D> {
) where
Module<B>: VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
@@ -63,9 +63,9 @@ impl<D: DataMut> GLWECiphertextCompressed<D> {
) where
Module<B>: VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace

@@ -1,9 +1,9 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize,
VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
VecZnxSubABInplace, VecZnxSwithcDegree,
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes,
VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch},
source::Source,
@@ -41,9 +41,9 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
Module<B>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
@@ -53,7 +53,7 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
+ VecZnxNormalize<B>
+ VecZnxSub
+ SvpPrepare<B>
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ SvpPPolAllocBytes
+ VecZnxAutomorphism,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,

@@ -1,8 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
source::Source,
@@ -41,9 +41,9 @@ impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
Module<B>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace

@@ -1,9 +1,9 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes,
VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxSwithcDegree,
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
VecZnxSubABInplace, VecZnxSwitchRing,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
source::Source,
@@ -55,9 +55,9 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
Module<B>: VecZnxAddScalarInplace
+ VecZnxDftAllocBytes
+ VecZnxBigNormalize<B>
+ DFT<B>
+ SvpApplyInplace<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ SvpApplyDftToDftInplace<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
@@ -67,7 +67,7 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
+ VecZnxNormalize<B>
+ VecZnxSub
+ SvpPrepare<B>
+ VecZnxSwithcDegree
+ VecZnxSwitchRing
+ SvpPPolAllocBytes,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
{
@@ -100,9 +100,9 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {

let n: usize = sk_in.n().max(sk_out.n());

let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank());
(0..sk_in.rank()).for_each(|i| {
module.vec_znx_switch_degree(
module.vec_znx_switch_ring(
&mut sk_in_tmp.as_vec_znx_mut(),
i,
&sk_in.data.as_vec_znx(),
@@ -110,11 +110,11 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
);
});

let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank());
{
let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
(0..sk_out.rank()).for_each(|i| {
module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
});
}
@@ -125,7 +125,7 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
&sk_out_tmp,
source_xa,
source_xe,
scratch2,
scratch_2,
);
self.sk_in_n = sk_in.n();
self.sk_out_n = sk_out.n();

@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
|
||||
TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
|
||||
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
|
||||
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
|
||||
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
source::Source,
|
||||
@@ -41,14 +41,14 @@ impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
|
||||
source_xe: &mut Source,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: SvpApply<B>
|
||||
+ IDFTTmpA<B>
|
||||
Module<B>: SvpApplyDftToDft<B>
|
||||
+ VecZnxIdftApplyTmpA<B>
|
||||
+ VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -58,7 +58,7 @@ impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxSub
|
||||
+ SvpPrepare<B>
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ SvpPPolAllocBytes,
|
||||
Scratch<B>:
|
||||
TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
|
||||
@@ -73,35 +73,35 @@ impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
|
||||
|
||||
let rank: usize = self.rank();
|
||||
|
||||
let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
|
||||
sk_dft_prep.prepare(module, sk, scratch1);
|
||||
let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
|
||||
sk_dft_prep.prepare(module, sk, scratch_1);
|
||||
|
||||
let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
|
||||
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
|
||||
|
||||
(0..rank).for_each(|i| {
|
||||
module.dft(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
|
||||
module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
|
||||
});
|
||||
|
||||
let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
|
||||
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
|
||||
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
|
||||
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
|
||||
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
|
||||
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
|
||||
|
||||
(0..rank).for_each(|i| {
|
||||
(i..rank).for_each(|j| {
|
||||
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
|
||||
module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
|
||||
|
||||
module.idft_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
|
||||
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
|
||||
module.vec_znx_big_normalize(
|
||||
self.basek(),
|
||||
&mut sk_ij.data.as_vec_znx_mut(),
|
||||
0,
|
||||
&sk_ij_big,
|
||||
0,
|
||||
scratch5,
|
||||
scratch_5,
|
||||
);
|
||||
|
||||
self.at_mut(i, j)
|
||||
.encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch5);
|
||||
.encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
|
||||
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero},
|
||||
source::Source,
|
||||
@@ -40,9 +40,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
Module<B>: VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -67,14 +67,14 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
let rank: usize = self.rank();
|
||||
let digits: usize = self.digits();
|
||||
|
||||
let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k);
|
||||
let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(self.n(), basek, k);
|
||||
|
||||
(0..self.rows()).for_each(|row_i| {
|
||||
tmp_pt.data.zero();
|
||||
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1);
|
||||
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_1);
|
||||
|
||||
(0..rank + 1).for_each(|col_j| {
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
@@ -85,7 +85,7 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
scratch1,
|
||||
scratch_1,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
|
||||
ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeSvpPPol,
|
||||
TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigAddNormal, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume,
|
||||
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero},
|
||||
source::Source,
|
||||
@@ -53,9 +53,9 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -92,9 +92,9 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -138,9 +138,9 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -179,8 +179,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: SvpPrepare<B>
|
||||
+ SvpApply<B>
|
||||
+ IDFTConsume<B>
|
||||
+ SvpApplyDftToDft<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddNormal<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
@@ -198,8 +198,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: SvpPrepare<B>
|
||||
+ SvpApply<B>
|
||||
+ IDFTConsume<B>
|
||||
+ SvpApplyDftToDft<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddNormal<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
@@ -226,8 +226,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
Module<B>: SvpPrepare<B>
|
||||
+ SvpApply<B>
|
||||
+ IDFTConsume<B>
|
||||
+ SvpApplyDftToDft<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddNormal<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
@@ -273,10 +273,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
(0..cols).for_each(|i| {
|
||||
let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), 1, size_pk);
|
||||
// ci_dft = DFT(u) * DFT(pk[i])
|
||||
module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
|
||||
module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i);
|
||||
|
||||
// ci_big = u * p[i]
|
||||
let mut ci_big = module.vec_znx_idft_consume(ci_dft);
|
||||
let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft);
|
||||
|
||||
// ci_big = u * pk[i] + e
|
||||
module.vec_znx_big_add_normal(basek, &mut ci_big, 0, pk.k(), source_xe, SIGMA, SIGMA_BOUND);
|
||||
@@ -311,9 +311,9 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
|
||||
) where
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -350,7 +350,7 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
|
||||
let col_ct: usize = if compressed { 0 } else { i };
|
||||
|
||||
// ct[i] = uniform (+ pt)
|
||||
module.vec_znx_fill_uniform(basek, ct, col_ct, k, source_xa);
|
||||
module.vec_znx_fill_uniform(basek, ct, col_ct, source_xa);
|
||||
|
||||
let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size);
|
||||
|
||||
@@ -361,16 +361,16 @@ pub(crate) fn glwe_encrypt_sk_internal<DataCt: DataMut, DataPt: DataRef, DataSk:
|
||||
if i == col {
|
||||
module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut ci, 0, scratch_3);
|
||||
module.dft(1, 0, &mut ci_dft, 0, &ci, 0);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0);
|
||||
} else {
|
||||
module.dft(1, 0, &mut ci_dft, 0, ct, col_ct);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
|
||||
}
|
||||
} else {
|
||||
module.dft(1, 0, &mut ci_dft, 0, ct, col_ct);
|
||||
module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct);
|
||||
}
|
||||
|
||||
module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1);
|
||||
let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_consume(ci_dft);
|
||||
module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
|
||||
let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft);
|
||||
|
||||
// use c[0] as buffer, which is overwritten later by the normalization step
|
||||
module.vec_znx_big_normalize(basek, &mut ci, 0, &ci_big, 0, scratch_3);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddInplace, VecZnxAddNormal,
|
||||
VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, ScratchOwned},
|
||||
@@ -22,9 +22,9 @@ impl<D: DataMut> GLWEPublicKey<D> {
|
||||
Module<B>:,
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
|
||||
VecZnxSubABInplace, VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
|
||||
source::Source,
|
||||
@@ -38,13 +38,13 @@ impl<D: DataMut> GLWEToLWESwitchingKey<D> {
|
||||
) where
|
||||
DLwe: DataRef,
|
||||
DGlwe: DataRef,
|
||||
Module<B>: VecZnxAutomorphismInplace
|
||||
Module<B>: VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -54,7 +54,7 @@ impl<D: DataMut> GLWEToLWESwitchingKey<D> {
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxSub
|
||||
+ SvpPrepare<B>
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ SvpPPolAllocBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
|
||||
{
|
||||
@@ -63,10 +63,10 @@ impl<D: DataMut> GLWEToLWESwitchingKey<D> {
|
||||
assert!(sk_lwe.n() <= module.n());
|
||||
}
|
||||
|
||||
let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1);
|
||||
let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), 1);
|
||||
sk_lwe_as_glwe.data.zero();
|
||||
sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0));
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1);
|
||||
|
||||
self.0.encrypt_sk(
|
||||
module,
|
||||
@@ -74,7 +74,7 @@ impl<D: DataMut> GLWEToLWESwitchingKey<D> {
|
||||
&sk_lwe_as_glwe,
|
||||
source_xa,
|
||||
source_xe,
|
||||
scratch1,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ impl<DataSelf: DataMut> LWECiphertext<DataSelf> {
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
|
||||
module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, k, source_xa);
|
||||
module.zn_fill_uniform(self.n() + 1, basek, &mut self.data, 0, source_xa);
|
||||
|
||||
let mut tmp_znx: Zn<Vec<u8>> = Zn::alloc(1, 1, self.size());
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
|
||||
VecZnxSubABInplace, VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut},
|
||||
source::Source,
|
||||
@@ -38,13 +38,13 @@ impl<D: DataMut> LWESwitchingKey<D> {
|
||||
) where
|
||||
DIn: DataRef,
|
||||
DOut: DataRef,
|
||||
Module<B>: VecZnxAutomorphismInplace
|
||||
Module<B>: VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -54,7 +54,7 @@ impl<D: DataMut> LWESwitchingKey<D> {
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxSub
|
||||
+ SvpPrepare<B>
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ SvpPPolAllocBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
|
||||
{
|
||||
@@ -65,16 +65,16 @@ impl<D: DataMut> LWESwitchingKey<D> {
|
||||
assert!(self.n() <= module.n());
|
||||
}
|
||||
|
||||
let (mut sk_in_glwe, scratch1) = scratch.take_glwe_secret(self.n(), 1);
|
||||
let (mut sk_out_glwe, scratch2) = scratch1.take_glwe_secret(self.n(), 1);
|
||||
let (mut sk_in_glwe, scratch_1) = scratch.take_glwe_secret(self.n(), 1);
|
||||
let (mut sk_out_glwe, scratch_2) = scratch_1.take_glwe_secret(self.n(), 1);
|
||||
|
||||
sk_out_glwe.data.at_mut(0, 0)[..sk_lwe_out.n()].copy_from_slice(sk_lwe_out.data.at(0, 0));
|
||||
sk_out_glwe.data.at_mut(0, 0)[sk_lwe_out.n()..].fill(0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_out_glwe.data.as_vec_znx_mut(), 0, scratch_2);
|
||||
|
||||
sk_in_glwe.data.at_mut(0, 0)[..sk_lwe_in.n()].copy_from_slice(sk_lwe_in.data.at(0, 0));
|
||||
sk_in_glwe.data.at_mut(0, 0)[sk_lwe_in.n()..].fill(0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_in_glwe.data.as_vec_znx_mut(), 0, scratch_2);
|
||||
|
||||
self.0.encrypt_sk(
|
||||
module,
|
||||
@@ -82,7 +82,7 @@ impl<D: DataMut> LWESwitchingKey<D> {
|
||||
&sk_out_glwe,
|
||||
source_xa,
|
||||
source_xe,
|
||||
scratch2,
|
||||
scratch_2,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
|
||||
TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
|
||||
VecZnxSubABInplace, VecZnxSwithcDegree,
|
||||
ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
|
||||
VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut},
|
||||
source::Source,
|
||||
@@ -36,13 +36,13 @@ impl<D: DataMut> LWEToGLWESwitchingKey<D> {
|
||||
) where
|
||||
DLwe: DataRef,
|
||||
DGlwe: DataRef,
|
||||
Module<B>: VecZnxAutomorphismInplace
|
||||
Module<B>: VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxAddScalarInplace
|
||||
+ VecZnxDftAllocBytes
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ DFT<B>
|
||||
+ SvpApplyInplace<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ SvpApplyDftToDftInplace<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxSubABInplace
|
||||
@@ -52,7 +52,7 @@ impl<D: DataMut> LWEToGLWESwitchingKey<D> {
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxSub
|
||||
+ SvpPrepare<B>
|
||||
+ VecZnxSwithcDegree
|
||||
+ VecZnxSwitchRing
|
||||
+ SvpPPolAllocBytes,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
|
||||
{
|
||||
@@ -61,10 +61,10 @@ impl<D: DataMut> LWEToGLWESwitchingKey<D> {
|
||||
assert!(sk_lwe.n() <= module.n());
|
||||
}
|
||||
|
||||
let (mut sk_lwe_as_glwe, scratch1) = scratch.take_glwe_secret(sk_glwe.n(), 1);
|
||||
let (mut sk_lwe_as_glwe, scratch_1) = scratch.take_glwe_secret(sk_glwe.n(), 1);
|
||||
sk_lwe_as_glwe.data.at_mut(0, 0)[..sk_lwe.n()].copy_from_slice(sk_lwe.data.at(0, 0));
|
||||
sk_lwe_as_glwe.data.at_mut(0, 0)[sk_lwe.n()..].fill(0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0);
|
||||
module.vec_znx_automorphism_inplace(-1, &mut sk_lwe_as_glwe.data.as_vec_znx_mut(), 0, scratch_1);
|
||||
|
||||
self.0.encrypt_sk(
|
||||
module,
|
||||
@@ -72,7 +72,7 @@ impl<D: DataMut> LWEToGLWESwitchingKey<D> {
|
||||
sk_glwe,
|
||||
source_xa,
|
||||
source_xe,
|
||||
scratch1,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
|
||||
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
};
|
||||
@@ -51,10 +51,10 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
@@ -70,10 +70,10 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
|
||||
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
|
||||
};
|
||||
@@ -51,10 +51,10 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
@@ -106,10 +106,10 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
|
||||
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
|
||||
};
|
||||
@@ -51,10 +51,10 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
@@ -116,10 +116,10 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
|
||||
VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnxBig},
|
||||
};
|
||||
@@ -65,10 +65,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
@@ -101,8 +101,8 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
let digits: usize = rhs.digits();
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
|
||||
let (mut a_dft, scratch2) = scratch1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits));
|
||||
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
|
||||
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, lhs.size().div_ceil(digits));
|
||||
|
||||
a_dft.data_mut().fill(0);
|
||||
|
||||
@@ -121,21 +121,21 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
|
||||
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.dft(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
|
||||
module.vec_znx_dft_apply(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch2);
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2);
|
||||
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_consume(res_dft);
|
||||
let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1);
|
||||
module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -148,16 +148,81 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
Module<B>: VecZnxDftAllocBytes
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.external_product(module, &*self_ptr, rhs, scratch);
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use poulpy_hal::api::ScratchAvailable;
|
||||
|
||||
assert_eq!(rhs.rank(), self.rank());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(rhs.n(), self.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertext::external_product_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
self.k(),
|
||||
rhs.k(),
|
||||
rhs.digits(),
|
||||
rhs.rank(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
let digits: usize = rhs.digits();
|
||||
|
||||
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), cols, rhs.size()); // Todo optimise
|
||||
let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n(), cols, self.size().div_ceil(digits));
|
||||
|
||||
a_dft.data_mut().fill(0);
|
||||
|
||||
{
|
||||
(0..digits).for_each(|di| {
|
||||
// (lhs.size() + di) / digits = (a - (digit - di - 1)).div_ceil(digits)
|
||||
a_dft.set_size((self.size() + di) / digits);
|
||||
|
||||
// Small optimization for digits > 2
|
||||
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
|
||||
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
|
||||
// As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
|
||||
// It is possible to further ignore the last digits-1 limbs, but this introduce
|
||||
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
|
||||
// noise is kept with respect to the ideal functionality.
|
||||
res_dft.set_size(rhs.size() - ((digits - di) as isize - 2).max(0) as usize);
|
||||
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft_apply(
|
||||
digits,
|
||||
digits - 1 - di,
|
||||
&mut a_dft,
|
||||
col_i,
|
||||
&self.data,
|
||||
col_i,
|
||||
);
|
||||
});
|
||||
|
||||
if di == 0 {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
|
||||
} else {
|
||||
module.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(res_dft);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch_1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ use std::collections::HashMap;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace,
|
||||
VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxBigSubSmallBInplace, VecZnxCopy, VecZnxDftAllocBytes, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate,
|
||||
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxCopy,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNegateInplace, VecZnxNormalizeInplace, VecZnxRotate,
|
||||
VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
@@ -126,20 +126,20 @@ impl GLWEPacker {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplace
|
||||
+ VecZnxRotateInplace<B>
|
||||
+ VecZnxSub
|
||||
+ VecZnxNegateInplace
|
||||
+ VecZnxRshInplace
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
@@ -204,20 +204,20 @@ fn pack_core<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplace
|
||||
+ VecZnxRotateInplace<B>
|
||||
+ VecZnxSub
|
||||
+ VecZnxNegateInplace
|
||||
+ VecZnxRshInplace
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
@@ -301,20 +301,20 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxCopy
|
||||
+ VecZnxRotateInplace
|
||||
+ VecZnxRotateInplace<B>
|
||||
+ VecZnxSub
|
||||
+ VecZnxNegateInplace
|
||||
+ VecZnxRshInplace
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAutomorphismInplace
|
||||
+ VecZnxAutomorphismInplace<B>
|
||||
+ VecZnxBigSubSmallBInplace<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
|
||||
@@ -349,15 +349,15 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
|
||||
|
||||
// a = a * X^-t
|
||||
a.rotate_inplace(module, -t);
|
||||
a.rotate_inplace(module, -t, scratch_1);
|
||||
|
||||
// tmp_b = a * X^-t - b
|
||||
tmp_b.sub(module, a, b);
|
||||
tmp_b.rsh(module, 1);
|
||||
tmp_b.rsh(module, 1, scratch_1);
|
||||
|
||||
// a = a * X^-t + b
|
||||
a.add_inplace(module, b);
|
||||
a.rsh(module, 1);
|
||||
a.rsh(module, 1, scratch_1);
|
||||
|
||||
tmp_b.normalize_inplace(module, scratch_1);
|
||||
|
||||
@@ -375,9 +375,9 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
// a = a + b * X^t - phi(a * X^-t - b) * X^t
|
||||
// = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
|
||||
// = a + b * X^t + phi(a - b * X^t)
|
||||
a.rotate_inplace(module, t);
|
||||
a.rotate_inplace(module, t, scratch_1);
|
||||
} else {
|
||||
a.rsh(module, 1);
|
||||
a.rsh(module, 1, scratch);
|
||||
// a = a + phi(a)
|
||||
if let Some(key) = auto_keys.get(&gal_el) {
|
||||
a.automorphism_add_inplace(module, key, scratch);
|
||||
@@ -388,7 +388,7 @@ fn combine<D: DataRef, DataAK: DataRef, B: Backend>(
|
||||
} else if let Some(b) = b {
|
||||
let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(n, basek, k, rank);
|
||||
tmp_b.rotate(module, 1 << (log_n - i - 1), b);
|
||||
tmp_b.rsh(module, 1);
|
||||
tmp_b.rsh(module, 1, scratch_1);
|
||||
|
||||
// a = (b* X^t - phi(b* X^t))
|
||||
if let Some(key) = auto_keys.get(&gal_el) {
|
||||
|
||||
@@ -2,9 +2,9 @@ use std::collections::HashMap;
|
||||
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace,
|
||||
VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxRshInplace, VmpApplyDftToDft,
|
||||
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxRshInplace,
|
||||
VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch},
|
||||
};
|
||||
@@ -73,12 +73,12 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxRshInplace
|
||||
+ VecZnxRshInplace<B>
|
||||
+ VecZnxCopy,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
@@ -99,16 +99,16 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxBigAutomorphismInplace<B>
|
||||
+ VecZnxRshInplace,
|
||||
+ VecZnxRshInplace<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
{
|
||||
(start..end).for_each(|i| {
|
||||
self.rsh(module, 1);
|
||||
self.rsh(module, 1, scratch);
|
||||
|
||||
let p: i64 = if i == 0 {
|
||||
-1
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
|
||||
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
|
||||
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
|
||||
};
|
||||
@@ -56,8 +57,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
@@ -76,8 +77,8 @@ impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
|
||||
@@ -132,8 +133,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
@@ -161,6 +162,12 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
assert!(
|
||||
self.rows() <= lhs.rows(),
|
||||
"self.rows()={} > lhs.rows()={}",
|
||||
self.rows(),
|
||||
lhs.rows()
|
||||
);
|
||||
}
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
@@ -188,8 +195,8 @@ impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace,
|
||||
VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace,
|
||||
VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxDftCopy,
|
||||
VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
|
||||
VmpApplyDftToDftTmpBytes,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VmpPMat, ZnxInfos},
|
||||
@@ -114,13 +114,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VmpApplyDftToDftTmpBytes
|
||||
+ VecZnxBigAllocBytes
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ DFT<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -150,8 +150,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VmpApplyDftToDft<B>
|
||||
+ VmpApplyDftToDftAdd<B>
|
||||
+ DFT<B>
|
||||
+ IDFTConsume<B>
|
||||
+ VecZnxDftApply<B>
|
||||
+ VecZnxIdftApplyConsume<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ VecZnxBigNormalize<B>
|
||||
+ VecZnxDftAllocBytes
|
||||
@@ -159,10 +159,15 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
|
||||
+ VecZnxNormalizeTmpBytes
|
||||
+ VecZnxDftCopy<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ IDFTTmpA<B>,
|
||||
+ VecZnxIdftApplyTmpA<B>,
|
||||
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
self.keyswitch_internal(module, lhs, ksk, scratch);
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
|
||||
self.at_mut(row_i, 0)
|
||||
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
|
||||
});
|
||||
self.expand_row(module, tsk, scratch);
|
||||
}
|
||||
|
||||
@@ -178,8 +183,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>
+ VecZnxDftAllocBytes
@@ -187,13 +192,16 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VecZnxNormalizeTmpBytes
+ VecZnxDftCopy<B>
+ VecZnxDftAddInplace<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
}
(0..self.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.keyswitch_inplace(module, ksk, scratch);
});
self.expand_row(module, tsk, scratch);
}

pub fn expand_row<DataTsk: DataRef, B: Backend>(
@@ -206,13 +214,13 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigAllocBytes
+ VecZnxNormalizeTmpBytes
+ DFT<B>
+ VecZnxDftApply<B>
+ VecZnxDftCopy<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxBigNormalize<B>
+ IDFTTmpA<B>,
+ VecZnxIdftApplyTmpA<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
assert!(
@@ -234,9 +242,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// Keyswitch the j-th row of the col 0
(0..self.rows()).for_each(|row_i| {
// Pre-compute DFT of (a0, a1, a2)
let (mut ci_dft, scratch1) = scratch.take_vec_znx_dft(n, cols, self.size());
let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(n, cols, self.size());
(0..cols).for_each(|i| {
module.dft(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
module.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &self.at(row_i, 0).data, i);
});

(1..cols).for_each(|col_j| {
@@ -262,8 +270,8 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {

let digits: usize = tsk.digits();

let (mut tmp_dft_i, scratch2) = scratch1.take_vec_znx_dft(n, cols, tsk.size());
let (mut tmp_a, scratch3) = scratch2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));
let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(n, cols, tsk.size());
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(n, 1, ci_dft.size().div_ceil(digits));

{
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
@@ -295,9 +303,9 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {

module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
if di == 0 && col_i == 1 {
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch3);
module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
} else {
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch3);
module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
}
});
});
@@ -313,46 +321,19 @@ impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
// =
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0);
let (mut tmp_idft, scratch3) = scratch2.take_vec_znx_big(n, 1, tsk.size());
let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(n, 1, tsk.size());
(0..cols).for_each(|i| {
module.idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_big_normalize(
self.basek(),
&mut self.at_mut(row_i, col_j).data,
i,
&tmp_idft,
0,
scratch3,
scratch_3,
);
});
})
})
}

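For reference, the limb-accumulation loop inside `expand_row` is the least obvious part of the hunk above, so here it is isolated as a minimal sketch with comments. This sketch is not part of the commit: it reuses only the variable names and calls visible in the hunk (`ci_dft`, `tmp_a`, `tmp_dft_i`, `pmat`, `scratch_3`, `digits`), and it assumes the enclosing loops over `col_i` (1..cols) and `di` (0..digits) that the truncated hunk does not show.

// Minimal sketch, assuming the surrounding scope of `expand_row` shown above.
(0..digits).for_each(|di| {
    // Stage the di-th decomposition limb of column `col_i` of `ci_dft`
    // into the single-limb buffer `tmp_a`.
    module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
    if di == 0 && col_i == 1 {
        // The very first product initialises the accumulator `tmp_dft_i`.
        module.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
    } else {
        // Every later product is accumulated into `tmp_dft_i` at limb offset `di`.
        module.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
    }
});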
fn keyswitch_internal<DataLhs: DataRef, DataKsk: DataRef, B: Backend>(
&mut self,
module: &Module<B>,
lhs: &GGSWCiphertext<DataLhs>,
ksk: &GGLWESwitchingKeyPrepared<DataKsk, B>,
scratch: &mut Scratch<B>,
) where
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.at_mut(row_i, 0)
.keyswitch(module, &lhs.at(row_i, 0), ksk, scratch);
})
}
}

@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
};
@@ -117,6 +118,63 @@ impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
)
);
}

#[allow(dead_code)]
pub(crate) fn assert_keyswitch_inplace<B: Backend, DataRhs>(
&self,
module: &Module<B>,
rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
scratch: &Scratch<B>,
) where
DataRhs: DataRef,
Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
Scratch<B>: ScratchAvailable,
{
let basek: usize = self.basek();
assert_eq!(
self.rank(),
rhs.rank_out(),
"self.rank(): {} != rhs.rank_out(): {}",
self.rank(),
rhs.rank_out()
);
assert_eq!(self.basek(), basek);
assert_eq!(rhs.n(), self.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
),
"scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
)={}",
scratch.available(),
GLWECiphertext::keyswitch_scratch_space(
module,
self.basek(),
self.k(),
self.k(),
rhs.k(),
rhs.digits(),
rhs.rank_in(),
rhs.rank_out(),
)
);
}
}

impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
@@ -130,11 +188,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
@@ -143,10 +200,10 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
{
self.assert_keyswitch(module, lhs, rhs, scratch);
}
let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}

@@ -162,16 +219,21 @@ impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(module, &*self_ptr, rhs, scratch);
#[cfg(debug_assertions)]
{
self.assert_keyswitch_inplace(module, rhs, scratch);
}
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
let res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, rhs, scratch_1);
(0..self.cols()).for_each(|i| {
module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch_1);
})
}
}

@@ -192,8 +254,8 @@ impl<D: DataRef> GLWECiphertext<D> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B>,
@@ -224,16 +286,17 @@ where
DataRes: DataMut,
DataIn: DataRef,
DataVmp: DataRef,
Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApplyDftToDft<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
Module<B>:
VecZnxDftAllocBytes + VecZnxDftApply<B> + VmpApplyDftToDft<B> + VecZnxIdftApplyConsume<B> + VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
(0..cols - 1).for_each(|col_i| {
module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}
@@ -251,16 +314,16 @@ where
DataIn: DataRef,
DataVmp: DataRef,
Module<B>: VecZnxDftAllocBytes
+ DFT<B>
+ VecZnxDftApply<B>
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ IDFTConsume<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>,
Scratch<B>: TakeVecZnxDft<B>,
{
let cols: usize = a.cols();
let size: usize = a.size();
let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));

ai_dft.data_mut().fill(0);

@@ -277,18 +340,18 @@ where
res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);

(0..cols - 1).for_each(|col_i| {
module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
module.vec_znx_dft_apply(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
});

if di == 0 {
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch_1);
} else {
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1);
module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch_1);
}
});

res_dft.set_size(res_dft.max_size());
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_apply_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
res_big
}

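For reference, a worked example of the `res_dft` truncation in the digit loop above. The values (`digits = 3`, `mat.size() = 6`) are illustrative only and do not come from the commit; the snippet just evaluates the same expression the hunk uses.

// Standalone sketch with illustrative values; mirrors the expression
// `mat.size() - ((digits - di) as isize - 2).max(0) as usize` used above.
fn main() {
    let digits: usize = 3;
    let mat_size: usize = 6; // stands in for mat.size()
    for di in 0..digits {
        // Number of trailing limbs dropped from `res_dft` at digit `di`.
        let skip = ((digits - di) as isize - 2).max(0) as usize;
        println!("di = {di}: res_dft sized to {} of {} limbs", mat_size - skip, mat_size);
    }
    // Prints 5, 6, 6: only the first digit works on a truncated `res_dft`.
    // The staging buffer `ai_dft` above is allocated with size.div_ceil(digits).
}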
@@ -1,7 +1,8 @@
use poulpy_hal::{
api::{
DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
};
@@ -26,8 +27,8 @@ impl LWECiphertext<Vec<u8>> {
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
{
@@ -51,8 +52,8 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<B>
+ VmpApplyDftToDftAdd<B>
+ DFT<B>
+ IDFTConsume<B>
+ VecZnxDftApply<B>
+ VecZnxIdftApplyConsume<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
@@ -67,7 +68,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
let max_k: usize = self.k().max(a.k());
let basek: usize = self.basek();

let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
glwe.data.zero();

let n_lwe: usize = a.n();
@@ -78,7 +79,7 @@ impl<DLwe: DataMut> LWECiphertext<DLwe> {
glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
});

glwe.keyswitch_inplace(module, &ksk.0, scratch1);
glwe.keyswitch_inplace(module, &ksk.0, scratch_1);

self.sample_extract(&glwe);
}

@@ -24,8 +24,8 @@ impl<D: DataRef> fmt::Debug for GGLWEAutomorphismKeyCompressed<D> {
}

impl<D: DataMut> FillUniform for GGLWEAutomorphismKeyCompressed<D> {
fn fill_uniform(&mut self, source: &mut Source) {
self.key.fill_uniform(source);
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.key.fill_uniform(log_bound, source);
}
}
Some files were not shown because too many files have changed in this diff.