Ref. + AVX code & generic tests + benches (#85)
Committed via GitHub · parent 99b9e3e10e · commit 56dbd29c59
poulpy-hal/src/reference/fft64/mod.rs (new file)
@@ -0,0 +1,24 @@
pub mod reim;
pub mod reim4;
pub mod svp;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;

pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) {
    assert_eq!(a.len(), b.len(), "Slices have different lengths");

    for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
        let diff: f64 = (x - y).abs();
        let scale: f64 = x.abs().max(y.abs()).max(1.0);
        assert!(
            diff <= tol * scale,
            "Difference at index {}: left={} right={} rel_diff={} > tol={}",
            i,
            x,
            y,
            diff / scale,
            tol
        );
    }
}
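Note on assert_approx_eq_slice: since scale is floored at 1.0, the check acts as an absolute tolerance for values below 1 and a relative tolerance above. An illustrative call (the values here are mine, not from the test suite):

    // Passes at tol = 1e-10 both near zero (absolute regime) and at 1e9 (relative regime).
    let a = [0.0, 1.0e9];
    let b = [1.0e-12, 1.0e9 + 1.0e-3];
    assert_approx_eq_slice(&a, &b, 1e-10);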
poulpy-hal/src/reference/fft64/reim/conversion.rs (new file)
@@ -0,0 +1,31 @@
#[inline(always)]
pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }

    for i in 0..res.len() {
        res[i] = a[i] as f64
    }
}

#[inline(always)]
pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    let inv_div = 1. / divisor;
    for i in 0..res.len() {
        res[i] = (a[i] * inv_div).round() as i64
    }
}

/// Rounds `*ri / divisor` to the nearest integer and stores the resulting i64
/// bit pattern in place inside the f64 buffer (a bit-level cast, not a numeric
/// conversion); callers reinterpret the buffer as `[i64]` afterwards.
#[inline(always)]
pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) {
    let inv_div = 1. / divisor;
    for ri in res {
        *ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64)
    }
}
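A sketch of the intended round trip through these conversions (illustrative values; conversion to f64 is exact as long as the integers stay below 2^53 in magnitude):

    let a: [i64; 4] = [3, -7, 11, -13];
    let mut buf = [0f64; 4];
    reim_from_znx_i64_ref(&mut buf, &a);
    // ... forward DFT, pointwise products and inverse DFT would go here ...
    let mut back = [0i64; 4];
    reim_to_znx_i64_ref(&mut back, 1.0, &buf); // divisor 1.0: plain rounding
    assert_eq!(back, a);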
poulpy-hal/src/reference/fft64/reim/fft_ref.rs (new file)
@@ -0,0 +1,327 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::reference::fft64::reim::{as_arr, as_arr_mut};

#[inline(always)]
pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
    assert!(data.len() == 2 * m);
    let (re, im) = data.split_at_mut(m);

    if m <= 16 {
        match m {
            1 => {}
            2 => fft2_ref(
                as_arr_mut::<2, R>(re),
                as_arr_mut::<2, R>(im),
                *as_arr::<2, R>(omg),
            ),
            4 => fft4_ref(
                as_arr_mut::<4, R>(re),
                as_arr_mut::<4, R>(im),
                *as_arr::<4, R>(omg),
            ),
            8 => fft8_ref(
                as_arr_mut::<8, R>(re),
                as_arr_mut::<8, R>(im),
                *as_arr::<8, R>(omg),
            ),
            16 => fft16_ref(
                as_arr_mut::<16, R>(re),
                as_arr_mut::<16, R>(im),
                *as_arr::<16, R>(omg),
            ),
            _ => {}
        }
    } else if m <= 2048 {
        fft_bfs_16_ref(m, re, im, omg, 0);
    } else {
        fft_rec_16_ref(m, re, im, omg, 0);
    }
}

#[inline(always)]
fn fft_rec_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fft_bfs_16_ref(m, re, im, omg, pos);
    };

    let h = m >> 1;
    twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
    pos += 2;
    pos = fft_rec_16_ref(h, re, im, omg, pos);
    pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
    pos
}

#[inline(always)]
fn cplx_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let dr: R = *rb * omg_re - *ib * omg_im;
    let di: R = *rb * omg_im + *ib * omg_re;
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra = *ra + dr;
    *ia = *ia + di;
}

#[inline(always)]
fn cplx_i_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let dr: R = *rb * omg_im + *ib * omg_re;
    let di: R = *rb * omg_re - *ib * omg_im;
    *rb = *ra + dr;
    *ib = *ia - di;
    *ra = *ra - dr;
    *ia = *ia + di;
}

#[inline(always)]
fn fft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
    let [ra, rb] = re;
    let [ia, ib] = im;
    let [romg, iomg] = omg;
    cplx_twiddle(ra, ia, rb, ib, romg, iomg);
}

#[inline(always)]
fn fft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
    let [re_0, re_1, re_2, re_3] = re;
    let [im_0, im_1, im_2, im_3] = im;

    {
        let omg_0 = omg[0];
        let omg_1 = omg[1];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
    }

    {
        let omg_0 = omg[2];
        let omg_1 = omg[3];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
    }
}

#[inline(always)]
fn fft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
    let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
    let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;

    {
        let omg_0 = omg[0];
        let omg_1 = omg[1];
        cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
        cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
        cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
    }

    {
        let omg_2 = omg[2];
        let omg_3 = omg[3];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
        cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
        cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
    }

    {
        let omg_4 = omg[4];
        let omg_5 = omg[5];
        let omg_6 = omg[6];
        let omg_7 = omg[7];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
        cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
        cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
    }
}

#[inline(always)]
fn fft16_ref<R: Float + FloatConst + Debug>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
    let [
        re_0,
        re_1,
        re_2,
        re_3,
        re_4,
        re_5,
        re_6,
        re_7,
        re_8,
        re_9,
        re_10,
        re_11,
        re_12,
        re_13,
        re_14,
        re_15,
    ] = re;
    let [
        im_0,
        im_1,
        im_2,
        im_3,
        im_4,
        im_5,
        im_6,
        im_7,
        im_8,
        im_9,
        im_10,
        im_11,
        im_12,
        im_13,
        im_14,
        im_15,
    ] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];
        cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
        cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
        cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);

        cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
        cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
        cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
        cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
    }

    {
        let omg_2: R = omg[2];
        let omg_3: R = omg[3];
        cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
        cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
        cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
        cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);

        cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
        cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
        cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
        cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[4];
        let omg_1: R = omg[5];
        let omg_2: R = omg[6];
        let omg_3: R = omg[7];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
        cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
        cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);

        cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
        cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
        cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
        cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[8];
        let omg_1: R = omg[9];
        let omg_2: R = omg[10];
        let omg_3: R = omg[11];
        let omg_4: R = omg[12];
        let omg_5: R = omg[13];
        let omg_6: R = omg[14];
        let omg_7: R = omg[15];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
        cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
        cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
        cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);

        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
        cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
        cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
        cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
    }
}

#[inline(always)]
fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut mm: usize = m;

    if !log_m.is_multiple_of(2) {
        let h: usize = mm >> 1;
        twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
        pos += 2;
        mm = h
    }

    while mm > 16 {
        let h: usize = mm >> 2;
        for off in (0..m).step_by(mm) {
            bitwiddle_fft_ref(
                h,
                &mut re[off..],
                &mut im[off..],
                as_arr::<4, R>(&omg[pos..]),
            );
            pos += 4;
        }
        mm = h
    }

    for off in (0..m).step_by(16) {
        fft16_ref(
            as_arr_mut::<16, R>(&mut re[off..]),
            as_arr_mut::<16, R>(&mut im[off..]),
            *as_arr::<16, R>(&omg[pos..]),
        );
        pos += 16;
    }

    pos
}

#[inline(always)]
fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
    let romg = omg[0];
    let iomg = omg[1];

    let (re_lhs, re_rhs) = re.split_at_mut(h);
    let (im_lhs, im_rhs) = im.split_at_mut(h);

    for i in 0..h {
        cplx_twiddle(
            &mut re_lhs[i],
            &mut im_lhs[i],
            &mut re_rhs[i],
            &mut im_rhs[i],
            romg,
            iomg,
        );
    }
}

#[inline(always)]
fn bitwiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
    let (r0, r2) = re.split_at_mut(2 * h);
    let (r0, r1) = r0.split_at_mut(h);
    let (r2, r3) = r2.split_at_mut(h);

    let (i0, i2) = im.split_at_mut(2 * h);
    let (i0, i1) = i0.split_at_mut(h);
    let (i2, i3) = i2.split_at_mut(h);

    let omg_0: R = omg[0];
    let omg_1: R = omg[1];
    let omg_2: R = omg[2];
    let omg_3: R = omg[3];

    for i in 0..h {
        cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1);
        cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
    }

    for i in 0..h {
        cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3);
        cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
    }
}
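For orientation (my gloss, not part of the commit): writing $a = ra + i\,ia$, $b = rb + i\,ib$ and $\omega = \mathrm{omg\_re} + i\,\mathrm{omg\_im}$, the two butterflies compute

    cplx_twiddle:   $(a, b) \mapsto (a + \omega b,\ a - \omega b)$
    cplx_i_twiddle: $(a, b) \mapsto (a + i\omega b,\ a - i\omega b)$

i.e. the second is the first with $\omega$ replaced by $i\omega$, which is what lets the radix-4 stage in bitwiddle_fft_ref serve two of its four butterfly groups from a single stored (cos, sin) pair.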
poulpy-hal/src/reference/fft64/reim/fft_vec.rs (new file)
@@ -0,0 +1,156 @@
#[inline(always)]
pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] + b[i]
    }
}

#[inline(always)]
pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] += a[i]
    }
}

#[inline(always)]
pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] - b[i]
    }
}

#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] -= a[i]
    }
}

#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] - res[i]
    }
}

#[inline(always)]
pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = -a[i]
    }
}

#[inline(always)]
pub fn reim_negate_inplace_ref(res: &mut [f64]) {
    for ri in res {
        *ri = -*ri
    }
}

#[inline(always)]
pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = br[i];
        let _bi: f64 = bi[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] += _rr;
        ri[i] += _ri;
    }
}

#[inline(always)]
pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = rr[i];
        let _bi: f64 = ri[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] = _rr;
        ri[i] = _ri;
    }
}

#[inline(always)]
pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = br[i];
        let _bi: f64 = bi[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] = _rr;
        ri[i] = _ri;
    }
}
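The reim layout keeps the m real parts in the first half of each slice and the m imaginary parts in the second half, so the products above are ordinary complex multiplications. A minimal illustrative check (values are mine):

    // m = 2 complex values per operand, stored [re0, re1 | im0, im1]:
    // (1 + 2i)(3 + 4i) = -5 + 10i and (0 + 1i)(0 + 1i) = -1 + 0i.
    let a = [1.0, 0.0, 2.0, 1.0];
    let b = [3.0, 0.0, 4.0, 1.0];
    let mut res = [0.0; 4];
    reim_mul_ref(&mut res, &a, &b);
    assert_eq!(res, [-5.0, -1.0, 10.0, 0.0]);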
poulpy-hal/src/reference/fft64/reim/ifft_ref.rs (new file)
@@ -0,0 +1,322 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::reference::fft64::reim::{as_arr, as_arr_mut};

pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
    assert!(data.len() == 2 * m);
    let (re, im) = data.split_at_mut(m);

    if m <= 16 {
        match m {
            1 => {}
            2 => ifft2_ref(
                as_arr_mut::<2, R>(re),
                as_arr_mut::<2, R>(im),
                *as_arr::<2, R>(omg),
            ),
            4 => ifft4_ref(
                as_arr_mut::<4, R>(re),
                as_arr_mut::<4, R>(im),
                *as_arr::<4, R>(omg),
            ),
            8 => ifft8_ref(
                as_arr_mut::<8, R>(re),
                as_arr_mut::<8, R>(im),
                *as_arr::<8, R>(omg),
            ),
            16 => ifft16_ref(
                as_arr_mut::<16, R>(re),
                as_arr_mut::<16, R>(im),
                *as_arr::<16, R>(omg),
            ),
            _ => {}
        }
    } else if m <= 2048 {
        ifft_bfs_16_ref(m, re, im, omg, 0);
    } else {
        ifft_rec_16_ref(m, re, im, omg, 0);
    }
}

#[inline(always)]
fn ifft_rec_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    if m <= 2048 {
        return ifft_bfs_16_ref(m, re, im, omg, pos);
    };
    let h: usize = m >> 1;
    pos = ifft_rec_16_ref(h, re, im, omg, pos);
    pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
    inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
    pos += 2;
    pos
}

#[inline(always)]
fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;

    for off in (0..m).step_by(16) {
        ifft16_ref(
            as_arr_mut::<16, R>(&mut re[off..]),
            as_arr_mut::<16, R>(&mut im[off..]),
            *as_arr::<16, R>(&omg[pos..]),
        );
        pos += 16;
    }

    let mut h: usize = 16;
    let m_half: usize = m >> 1;

    while h < m_half {
        let mm: usize = h << 2;
        for off in (0..m).step_by(mm) {
            inv_bitwiddle_ifft_ref(
                h,
                &mut re[off..],
                &mut im[off..],
                as_arr::<4, R>(&omg[pos..]),
            );
            pos += 4;
        }
        h = mm;
    }

    if !log_m.is_multiple_of(2) {
        inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
        pos += 2;
    }

    pos
}

#[inline(always)]
fn inv_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let r_diff: R = *ra - *rb;
    let i_diff: R = *ia - *ib;
    *ra = *ra + *rb;
    *ia = *ia + *ib;
    *rb = r_diff * omg_re - i_diff * omg_im;
    *ib = r_diff * omg_im + i_diff * omg_re;
}

#[inline(always)]
fn inv_itwiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let r_diff: R = *ra - *rb;
    let i_diff: R = *ia - *ib;
    *ra = *ra + *rb;
    *ia = *ia + *ib;
    *rb = r_diff * omg_im + i_diff * omg_re;
    *ib = -r_diff * omg_re + i_diff * omg_im;
}

#[inline(always)]
fn ifft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
    let [ra, rb] = re;
    let [ia, ib] = im;
    let [romg, iomg] = omg;
    inv_twiddle(ra, ia, rb, ib, romg, iomg);
}

#[inline(always)]
fn ifft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
    let [re_0, re_1, re_2, re_3] = re;
    let [im_0, im_1, im_2, im_3] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];

        inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
    }

    {
        let omg_0: R = omg[2];
        let omg_1: R = omg[3];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
    }
}

#[inline(always)]
fn ifft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
    let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
    let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;

    {
        let omg_4: R = omg[0];
        let omg_5: R = omg[1];
        let omg_6: R = omg[2];
        let omg_7: R = omg[3];
        inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
        inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
        inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
    }

    {
        let omg_2: R = omg[4];
        let omg_3: R = omg[5];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
        inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
        inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[6];
        let omg_1: R = omg[7];
        inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
        inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
        inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
    }
}

#[inline(always)]
fn ifft16_ref<R: Float + FloatConst>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
    let [
        re_0,
        re_1,
        re_2,
        re_3,
        re_4,
        re_5,
        re_6,
        re_7,
        re_8,
        re_9,
        re_10,
        re_11,
        re_12,
        re_13,
        re_14,
        re_15,
    ] = re;
    let [
        im_0,
        im_1,
        im_2,
        im_3,
        im_4,
        im_5,
        im_6,
        im_7,
        im_8,
        im_9,
        im_10,
        im_11,
        im_12,
        im_13,
        im_14,
        im_15,
    ] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];
        let omg_2: R = omg[2];
        let omg_3: R = omg[3];
        let omg_4: R = omg[4];
        let omg_5: R = omg[5];
        let omg_6: R = omg[6];
        let omg_7: R = omg[7];
        inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
        inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
        inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
        inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
        inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
        inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
        inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
    }

    {
        let omg_0: R = omg[8];
        let omg_1: R = omg[9];
        let omg_2: R = omg[10];
        let omg_3: R = omg[11];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
        inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
        inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
        inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
        inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
        inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
        inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_2: R = omg[12];
        let omg_3: R = omg[13];
        inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
        inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
        inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
        inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
        inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
        inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
        inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
        inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[14];
        let omg_1: R = omg[15];
        inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
        inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
        inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
        inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
        inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
        inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
        inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
    }
}

#[inline(always)]
fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
    let romg = omg[0];
    let iomg = omg[1];

    let (re_lhs, re_rhs) = re.split_at_mut(h);
    let (im_lhs, im_rhs) = im.split_at_mut(h);

    for i in 0..h {
        inv_twiddle(
            &mut re_lhs[i],
            &mut im_lhs[i],
            &mut re_rhs[i],
            &mut im_rhs[i],
            romg,
            iomg,
        );
    }
}

#[inline(always)]
fn inv_bitwiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
    let (r0, r2) = re.split_at_mut(2 * h);
    let (r0, r1) = r0.split_at_mut(h);
    let (r2, r3) = r2.split_at_mut(h);

    let (i0, i2) = im.split_at_mut(2 * h);
    let (i0, i1) = i0.split_at_mut(h);
    let (i2, i3) = i2.split_at_mut(h);

    let omg_0: R = omg[0];
    let omg_1: R = omg[1];
    let omg_2: R = omg[2];
    let omg_3: R = omg[3];

    for i in 0..h {
        inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1);
        inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
    }

    for i in 0..h {
        inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3);
        inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
    }
}
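Continuing the gloss from fft_ref.rs: the inverse butterflies compute

    inv_twiddle:   $(a, b) \mapsto (a + b,\ (a - b)\,\omega)$
    inv_itwiddle:  $(a, b) \mapsto (a + b,\ -i\,(a - b)\,\omega)$

With the conjugated twiddles (the $-\sin$ entries) stored by the IFFT tables, each stage undoes the matching forward stage up to a factor of 2, so a full fft_ref/ifft_ref round trip is expected to scale the data by m; callers compensate through the `divisor` argument of reim_to_znx_i64_ref.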
poulpy-hal/src/reference/fft64/reim/mod.rs (new file)
@@ -0,0 +1,128 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
// under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
//
// ----------------------------------------------------------------------

#![allow(bad_asm_style)]

use rand_distr::num_traits::{Float, FloatConst};

mod conversion;
mod fft_ref;
mod fft_vec;
mod ifft_ref;
mod table_fft;
mod table_ifft;
mod zero;

pub use conversion::*;
pub use fft_ref::*;
pub use fft_vec::*;
pub use ifft_ref::*;
pub use table_fft::*;
pub use table_ifft::*;
pub use zero::*;

#[inline(always)]
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
    debug_assert!(x.len() >= size);
    unsafe { &*(x.as_ptr() as *const [R; size]) }
}

#[inline(always)]
pub(crate) fn as_arr_mut<const size: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; size] {
    debug_assert!(x.len() >= size);
    unsafe { &mut *(x.as_mut_ptr() as *mut [R; size]) }
}

#[inline(always)]
pub(crate) fn frac_rev_bits<R: Float + FloatConst>(x: usize) -> R {
    let half: R = R::from(0.5).unwrap();

    match x {
        0 => R::zero(),
        1 => half,
        _ => {
            if x.is_multiple_of(2) {
                frac_rev_bits::<R>(x >> 1) * half
            } else {
                frac_rev_bits::<R>(x >> 1) * half + half
            }
        }
    }
}

pub trait ReimDFTExecute<D, T> {
    fn reim_dft_execute(table: &D, data: &mut [T]);
}

pub trait ReimFromZnx {
    fn reim_from_znx(res: &mut [f64], a: &[i64]);
}

pub trait ReimToZnx {
    fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]);
}

pub trait ReimToZnxInplace {
    fn reim_to_znx_inplace(res: &mut [f64], divisor: f64);
}

pub trait ReimAdd {
    fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimAddInplace {
    fn reim_add_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimSub {
    fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimSubABInplace {
    fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimSubBAInplace {
    fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimNegate {
    fn reim_negate(res: &mut [f64], a: &[f64]);
}

pub trait ReimNegateInplace {
    fn reim_negate_inplace(res: &mut [f64]);
}

pub trait ReimMul {
    fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimMulInplace {
    fn reim_mul_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimAddMul {
    fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimCopy {
    fn reim_copy(res: &mut [f64], a: &[f64]);
}

pub trait ReimZero {
    fn reim_zero(res: &mut [f64]);
}
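frac_rev_bits(x) mirrors the binary digits of x across the radix point: for x with bits b_k…b_1b_0 it returns the binary fraction 0.b_0b_1…b_k. The omega-table builders below use it to enumerate twiddle angles in bit-reversed order. An illustrative check:

    // 6 = 0b110, mirrored across the radix point -> 0.011 in binary = 0.375.
    assert_eq!(frac_rev_bits::<f64>(6), 0.375);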
poulpy-hal/src/reference/fft64/reim/table_fft.rs (new file)
@@ -0,0 +1,207 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::{
    alloc_aligned,
    reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits},
};

pub struct ReimFFTRef;

impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTRef {
    fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
        fft_ref(table.m, &table.omg, data);
    }
}

pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
    m: usize,
    omg: Vec<R>,
}

impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
    pub fn new(m: usize) -> Self {
        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
        let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);

        let quarter: R = R::from(1. / 4.).unwrap();

        if m <= 16 {
            match m {
                1 => {}
                2 => {
                    fill_fft2_omegas(quarter, &mut omg, 0);
                }
                4 => {
                    fill_fft4_omegas(quarter, &mut omg, 0);
                }
                8 => {
                    fill_fft8_omegas(quarter, &mut omg, 0);
                }
                16 => {
                    fill_fft16_omegas(quarter, &mut omg, 0);
                }
                _ => {}
            }
        } else if m <= 2048 {
            fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0);
        } else {
            fill_fft_rec_16_omegas(m, quarter, &mut omg, 0);
        }

        Self { m, omg }
    }

    pub fn m(&self) -> usize {
        self.m
    }

    pub fn omg(&self) -> &[R] {
        &self.omg
    }
}

#[inline(always)]
fn fill_fft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 2);
    let angle: R = j / R::from(2).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle);
    omg_pos[1] = R::sin(two_pi * angle);
    pos + 2
}

#[inline(always)]
fn fill_fft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 4);
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    pos + 4
}

#[inline(always)]
fn fill_fft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 8);
    let _8th: R = R::from(1. / 8.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    omg_pos[4] = R::cos(two_pi * angle_4);
    omg_pos[5] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[6] = R::sin(two_pi * angle_4);
    omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
    pos + 8
}

#[inline(always)]
fn fill_fft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 16);
    let _8th: R = R::from(1. / 8.).unwrap();
    let _16th: R = R::from(1. / 16.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let angle_8: R = j / R::from(16).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    omg_pos[4] = R::cos(two_pi * angle_4);
    omg_pos[5] = R::sin(two_pi * angle_4);
    omg_pos[6] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
    omg_pos[8] = R::cos(two_pi * angle_8);
    omg_pos[9] = R::cos(two_pi * (angle_8 + _8th));
    omg_pos[10] = R::cos(two_pi * (angle_8 + _16th));
    omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th));
    omg_pos[12] = R::sin(two_pi * angle_8);
    omg_pos[13] = R::sin(two_pi * (angle_8 + _8th));
    omg_pos[14] = R::sin(two_pi * (angle_8 + _16th));
    omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th));
    pos + 16
}

#[inline(always)]
fn fill_fft_bfs_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut mm: usize = m;
    let mut jj: R = j;

    let two_pi: R = R::from(2).unwrap() * R::PI();

    if !log_m.is_multiple_of(2) {
        let h = mm >> 1;
        let j: R = jj * R::from(0.5).unwrap();
        omg[pos] = R::cos(two_pi * j);
        omg[pos + 1] = R::sin(two_pi * j);
        pos += 2;
        mm = h;
        jj = j
    }

    while mm > 16 {
        let h: usize = mm >> 2;
        let j: R = jj * R::from(1. / 4.).unwrap();
        for i in (0..m).step_by(mm) {
            let rs_0 = j + frac_rev_bits::<R>(i / mm) * R::from(1. / 4.).unwrap();
            let rs_1 = R::from(2).unwrap() * rs_0;
            omg[pos] = R::cos(two_pi * rs_1);
            omg[pos + 1] = R::sin(two_pi * rs_1);
            omg[pos + 2] = R::cos(two_pi * rs_0);
            omg[pos + 3] = R::sin(two_pi * rs_0);
            pos += 4;
        }
        mm = h;
        jj = j;
    }

    for i in (0..m).step_by(16) {
        let j = jj + frac_rev_bits(i >> 4);
        fill_fft16_omegas(j, omg, pos);
        pos += 16
    }

    pos
}

#[inline(always)]
fn fill_fft_rec_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fill_fft_bfs_16_omegas(m, j, omg, pos);
    }
    let h: usize = m >> 1;
    let s: R = j * R::from(0.5).unwrap();
    let _2pi = R::from(2).unwrap() * R::PI();
    omg[pos] = R::cos(_2pi * s);
    omg[pos + 1] = R::sin(_2pi * s);
    pos += 2;
    pos = fill_fft_rec_16_omegas(h, s, omg, pos);
    pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
    pos
}

#[inline(always)]
fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) {
    let dr: f64 = *rb * omg_re - *ib * omg_im;
    let di: f64 = *rb * omg_im + *ib * omg_re;
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra += dr;
    *ia += di;
}
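A hedged usage sketch of the forward table (assumes the ReimDFTExecute trait is in scope; the value of m is illustrative):

    let m = 32;
    let table = ReimFFTTable::<f64>::new(m);  // precomputes 2*m twiddles in traversal order
    let mut data = vec![0f64; 2 * m];         // split layout: [re(m) | im(m)]
    data[0] = 1.0;                            // unit impulse, i.e. the constant polynomial 1
    ReimFFTRef::reim_dft_execute(&table, &mut data); // in-place forward DFT
    // For an evaluation-style DFT one expects an all-ones spectrum here.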
poulpy-hal/src/reference/fft64/reim/table_ifft.rs (new file)
@@ -0,0 +1,201 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::{
    alloc_aligned,
    reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref},
};

pub struct ReimIFFTRef;

impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTRef {
    fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
        ifft_ref(table.m, &table.omg, data);
    }
}

pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
    m: usize,
    omg: Vec<R>,
}

impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
    pub fn new(m: usize) -> Self {
        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
        let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);

        let quarter: R = R::exp2(R::from(-2).unwrap());

        if m <= 16 {
            match m {
                1 => {}
                2 => {
                    fill_ifft2_omegas::<R>(quarter, &mut omg, 0);
                }
                4 => {
                    fill_ifft4_omegas(quarter, &mut omg, 0);
                }
                8 => {
                    fill_ifft8_omegas(quarter, &mut omg, 0);
                }
                16 => {
                    fill_ifft16_omegas(quarter, &mut omg, 0);
                }
                _ => {}
            }
        } else if m <= 2048 {
            fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0);
        } else {
            fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0);
        }

        Self { m, omg }
    }

    pub fn execute(&self, data: &mut [R]) {
        ifft_ref(self.m, &self.omg, data);
    }

    pub fn m(&self) -> usize {
        self.m
    }

    pub fn omg(&self) -> &[R] {
        &self.omg
    }
}

#[inline(always)]
fn fill_ifft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 2);
    let angle: R = j / R::from(2).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle);
    omg_pos[1] = -R::sin(two_pi * angle);
    pos + 2
}

#[inline(always)]
fn fill_ifft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 4);
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_2);
    omg_pos[1] = -R::sin(two_pi * angle_2);
    omg_pos[2] = R::cos(two_pi * angle_1);
    omg_pos[3] = -R::sin(two_pi * angle_1);
    pos + 4
}

#[inline(always)]
fn fill_ifft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 8);
    let _8th: R = R::from(1. / 8.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_4);
    omg_pos[1] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[2] = -R::sin(two_pi * angle_4);
    omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th));
    omg_pos[4] = R::cos(two_pi * angle_2);
    omg_pos[5] = -R::sin(two_pi * angle_2);
    omg_pos[6] = R::cos(two_pi * angle_1);
    omg_pos[7] = -R::sin(two_pi * angle_1);
    pos + 8
}

#[inline(always)]
fn fill_ifft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 16);
    let _8th: R = R::from(1. / 8.).unwrap();
    let _16th: R = R::from(1. / 16.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let angle_8: R = j / R::from(16).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_8);
    omg_pos[1] = R::cos(two_pi * (angle_8 + _8th));
    omg_pos[2] = R::cos(two_pi * (angle_8 + _16th));
    omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th));
    omg_pos[4] = -R::sin(two_pi * angle_8);
    omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th));
    omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th));
    omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th));
    omg_pos[8] = R::cos(two_pi * angle_4);
    omg_pos[9] = -R::sin(two_pi * angle_4);
    omg_pos[10] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th));
    omg_pos[12] = R::cos(two_pi * angle_2);
    omg_pos[13] = -R::sin(two_pi * angle_2);
    omg_pos[14] = R::cos(two_pi * angle_1);
    omg_pos[15] = -R::sin(two_pi * angle_1);
    pos + 16
}

#[inline(always)]
fn fill_ifft_bfs_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap();

    for i in (0..m).step_by(16) {
        let j = jj + frac_rev_bits(i >> 4);
        fill_ifft16_omegas(j, omg, pos);
        pos += 16
    }

    let mut h: usize = 16;
    let m_half: usize = m >> 1;

    let two_pi: R = R::from(2).unwrap() * R::PI();

    while h < m_half {
        let mm: usize = h << 2;
        for i in (0..m).step_by(mm) {
            let rs_0 = jj + frac_rev_bits::<R>(i / mm) / R::from(4).unwrap();
            let rs_1 = R::from(2).unwrap() * rs_0;
            omg[pos] = R::cos(two_pi * rs_0);
            omg[pos + 1] = -R::sin(two_pi * rs_0);
            omg[pos + 2] = R::cos(two_pi * rs_1);
            omg[pos + 3] = -R::sin(two_pi * rs_1);
            pos += 4;
        }
        h = mm;
        jj = jj * R::from(4).unwrap();
    }

    if !log_m.is_multiple_of(2) {
        omg[pos] = R::cos(two_pi * jj);
        omg[pos + 1] = -R::sin(two_pi * jj);
        pos += 2;
        jj = jj * R::from(2).unwrap();
    }

    assert_eq!(jj, j);

    pos
}

#[inline(always)]
fn fill_ifft_rec_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fill_ifft_bfs_16_omegas(m, j, omg, pos);
    }
    let h: usize = m >> 1;
    let s: R = j / R::from(2).unwrap();
    pos = fill_ifft_rec_16_omegas(h, s, omg, pos);
    pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
    let _2pi = R::from(2).unwrap() * R::PI();
    omg[pos] = R::cos(_2pi * s);
    omg[pos + 1] = -R::sin(_2pi * s);
    pos += 2;
    pos
}
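And a hedged round trip combining both tables (per the note after ifft_ref.rs, the unnormalized forward/inverse pair is expected to scale the data by m; the tolerance and sizes below are illustrative):

    let m = 32;
    let fft = ReimFFTTable::<f64>::new(m);
    let ifft = ReimIFFTTable::<f64>::new(m);
    let mut data: Vec<f64> = (0..2 * m).map(|i| i as f64).collect();
    let orig = data.clone();
    ReimFFTRef::reim_dft_execute(&fft, &mut data);
    ifft.execute(&mut data);
    for (x, y) in data.iter().zip(orig.iter()) {
        assert!((x - m as f64 * y).abs() <= 1e-9 * (m as f64 * y.abs()).max(1.0));
    }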
poulpy-hal/src/reference/fft64/reim/zero.rs (new file)
@@ -0,0 +1,11 @@
pub fn reim_zero_ref(res: &mut [f64]) {
    res.fill(0.);
}

pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    res.copy_from_slice(a);
}
poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs (new file)
@@ -0,0 +1,209 @@
use crate::reference::fft64::reim::as_arr;

#[inline(always)]
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= 2 * rows * 4);

    for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
        chunk.copy_from_slice(&src[offset..offset + 4]);
        offset += m
    }
}

#[inline(always)]
pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= offset + m + 4);
    debug_assert!(src.len() >= 8);

    let dst_off = &mut dst[offset..offset + 4];

    if OVERWRITE {
        dst_off.copy_from_slice(&src[0..4]);
    } else {
        dst_off[0] += src[0];
        dst_off[1] += src[1];
        dst_off[2] += src[2];
        dst_off[3] += src[3];
    }

    offset += m;

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[4..8]);
    } else {
        dst_off[0] += src[4];
        dst_off[1] += src[5];
        dst_off[2] += src[6];
        dst_off[3] += src[7];
    }
}

#[inline(always)]
pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= offset + 3 * m + 4);
    debug_assert!(src.len() >= 16);

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[0..4]);
    } else {
        dst_off[0] += src[0];
        dst_off[1] += src[1];
        dst_off[2] += src[2];
        dst_off[3] += src[3];
    }

    offset += m;
    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[4..8]);
    } else {
        dst_off[0] += src[4];
        dst_off[1] += src[5];
        dst_off[2] += src[6];
        dst_off[3] += src[7];
    }

    offset += m;

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[8..12]);
    } else {
        dst_off[0] += src[8];
        dst_off[1] += src[9];
        dst_off[2] += src[10];
        dst_off[3] += src[11];
    }

    offset += m;
    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[12..16]);
    } else {
        dst_off[0] += src[12];
        dst_off[1] += src[13];
        dst_off[2] += src[14];
        dst_off[3] += src[15];
    }
}

#[inline(always)]
pub fn reim4_vec_mat1col_product_ref(
    nrows: usize,
    dst: &mut [f64], // 8 doubles: [re1(4), im1(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 8 doubles: [ar(4) | ai(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert!(dst.len() >= 8, "dst must have at least 8 doubles");
        assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
        assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
    }

    let mut acc: [f64; 8] = [0f64; 8];
    let mut j = 0;
    for _ in 0..nrows {
        reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..]));
        j += 8;
    }
    dst[0..8].copy_from_slice(&acc);
}

#[inline(always)]
pub fn reim4_vec_mat2cols_product_ref(
    nrows: usize,
    dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(dst.len(), 16, "dst must have 16 doubles");
        assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
        assert!(
            v.len() >= nrows * 16,
            "v must be at least nrows * 16 doubles"
        );
    }

    // zero accumulators
    let mut acc_0: [f64; 8] = [0f64; 8];
    let mut acc_1: [f64; 8] = [0f64; 8];
    for i in 0..nrows {
        let _1j: usize = i << 3;
        let _2j: usize = i << 4;
        let u_j: &[f64; 8] = as_arr(&u[_1j..]);
        reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..]));
        reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..]));
    }
    dst[0..8].copy_from_slice(&acc_0);
    dst[8..16].copy_from_slice(&acc_1);
}

#[inline(always)]
pub fn reim4_vec_mat2cols_2ndcol_product_ref(
    nrows: usize,
    dst: &mut [f64], // 8 doubles: [re2(4), im2(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert!(
            dst.len() >= 8,
            "dst must be at least 8 doubles but is {}",
            dst.len()
        );
        assert!(
            u.len() >= nrows * 8,
            "u must be at least nrows={} * 8 doubles but is {}",
            nrows,
            u.len()
        );
        assert!(
            v.len() >= nrows * 16,
            "v must be at least nrows={} * 16 doubles but is {}",
            nrows,
            v.len()
        );
    }

    // zero accumulators
    let mut acc: [f64; 8] = [0f64; 8];
    for i in 0..nrows {
        let _1j: usize = i << 3;
        let _2j: usize = i << 4;
        reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..]));
    }
    dst[0..8].copy_from_slice(&acc);
}

#[inline(always)]
pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
    for k in 0..4 {
        let ar: f64 = a[k];
        let br: f64 = b[k];
        let ai: f64 = a[k + 4];
        let bi: f64 = b[k + 4];
        dst[k] += ar * br - ai * bi;
        dst[k + 4] += ar * bi + ai * br;
    }
}
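reim4_add_mul treats its 8-double arguments as 4 interleaved complex lanes stored [re x4 | im x4], the block shape produced by the extract/save helpers above. An illustrative single-lane check (values are mine):

    // Lane 0: (1 + 2i)(3 + 4i) = -5 + 10i accumulated into dst.
    let mut acc = [0f64; 8];
    let a = [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0];
    let b = [3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0];
    reim4_add_mul(&mut acc, &a, &b);
    assert_eq!((acc[0], acc[4]), (-5.0, 10.0));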
poulpy-hal/src/reference/fft64/reim4/mod.rs (new file)
@@ -0,0 +1,27 @@
mod arithmetic_ref;

pub use arithmetic_ref::*;

pub trait Reim4Extract1Blk {
    fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save1Blk {
    fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save2Blks {
    fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Mat1ColProd {
    fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2ColsProd {
    fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2Cols2ndColProd {
    fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
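A sketch of how a backend might wire one of these traits to the reference kernels (the RefBackend type here is hypothetical, not from this commit):

    struct RefBackend;

    impl Reim4Mat1ColProd for RefBackend {
        fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
            reim4_vec_mat1col_product_ref(nrows, dst, u, v);
        }
    }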
poulpy-hal/src/reference/fft64/svp.rs (new file)
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
layouts::{
|
||||
Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
},
|
||||
reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero},
|
||||
};
|
||||
|
||||
pub fn svp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx,
|
||||
R: SvpPPolToMut<BE>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: SvpPPol<&mut [u8], BE> = res.to_mut();
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0));
|
||||
BE::reim_dft_execute(table, res.at_mut(res_col, 0));
|
||||
}
|
||||
|
||||
pub fn svp_apply_dft<R, A, B, BE>(
|
||||
table: &ReimFFTTable<f64>,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
b: &B,
|
||||
b_col: usize,
|
||||
) where
|
||||
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimZero + ReimFromZnx + ReimMulInplace,
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: SvpPPolToRef<BE>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
|
||||
let a: SvpPPol<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
|
||||
let res_size: usize = res.size();
|
||||
let b_size: usize = b.size();
|
||||
let min_size: usize = res_size.min(b_size);
|
||||
|
||||
let ppol: &[f64] = a.at(a_col, 0);
|
||||
for j in 0..min_size {
|
||||
let out: &mut [f64] = res.at_mut(res_col, j);
|
||||
BE::reim_from_znx(out, b.at(b_col, j));
|
||||
BE::reim_dft_execute(table, out);
|
||||
BE::reim_mul_inplace(out, ppol);
|
||||
}
|
||||
|
||||
for j in min_size..res_size {
|
||||
BE::reim_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}

pub fn svp_apply_dft_to_dft<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimMul + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let res_size: usize = res.size();
    let b_size: usize = b.size();
    let min_size: usize = res_size.min(b_size);

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..min_size {
        BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
    }

    for j in min_size..res_size {
        BE::reim_zero(res.at_mut(res_col, j));
    }
}

pub fn svp_apply_dft_to_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAddMul + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let res_size: usize = res.size();
    let b_size: usize = b.size();
    let min_size: usize = res_size.min(b_size);

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..min_size {
        BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
    }

    for j in min_size..res_size {
        BE::reim_zero(res.at_mut(res_col, j));
    }
}

pub fn svp_apply_dft_to_dft_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimMulInplace,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..res.size() {
        BE::reim_mul_inplace(res.at_mut(res_col, j), ppol);
    }
}
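
// NOTE (editorial): the `svp_apply_dft_to_dft*` variants differ only in how
// each limb is combined with the prepared scalar in the DFT domain:
// `reim_mul` overwrites (res = a * b), `reim_addmul` accumulates
// (res += a * b), and `reim_mul_inplace` rescales res itself (res *= a).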
521
poulpy-hal/src/reference/fft64/vec_znx_big.rs
Normal file
@@ -0,0 +1,521 @@
use std::f64::consts::SQRT_2;

use crate::{
    api::VecZnxBigAddNormal,
    layouts::{
        Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut,
    },
    oep::VecZnxBigAllocBytesImpl,
    reference::{
        vec_znx::{
            vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
            vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
        },
        znx::{
            ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
            ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
            ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
        },
    },
    source::Source,
};

pub fn vec_znx_big_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();
    let b: VecZnxBig<&[u8], BE> = b.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    let b_vznx: VecZnx<&[u8]> = VecZnx {
        data: b.data,
        n: b.n,
        cols: b.cols,
        size: b.size,
        max_size: b.max_size,
    };

    vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}
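
// NOTE (editorial): with `ScalarBig = i64`, a `VecZnxBig` shares its memory
// layout with a plain `VecZnx`, so every routine in this file rewraps the
// big variant's fields into a `VecZnx` view and delegates to the generic
// `reference::vec_znx` implementation; the rewrap itself copies no data.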

pub fn vec_znx_big_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAddInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_add_small<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}

pub fn vec_znx_big_add_small_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAddInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_big_automorphism<R, A, BE>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();
    let a: VecZnxBig<&[u8], _> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_automorphism_inplace<R, BE>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxCopy,
    R: VecZnxBigToMut<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp);
}

pub fn vec_znx_big_negate<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxNegate + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();
    let a: VecZnxBig<&[u8], _> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_negate_inplace<R, BE>(res: &mut R, res_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col);
}

pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxBigToRef<BE>,
    BE: Backend<ScalarBig = i64>
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxZero,
{
    let a: VecZnxBig<&[u8], _> = a.to_ref();
    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
}
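
// NOTE (editorial sketch, not part of this commit): calling the
// normalization with a correctly sized carry buffer;
// `vec_znx_big_normalize_tmp_bytes(n)` is n * 8 bytes, i.e. one i64 carry
// per coefficient. `res_small` and `a_big` are assumed to exist with
// matching ring degree n:
//
//     let mut carry: Vec<i64> = vec![0i64; vec_znx_big_normalize_tmp_bytes(n) / size_of::<i64>()];
//     vec_znx_big_normalize(basek, &mut res_small, 0, &a_big, 0, &mut carry);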

pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    sigma: f64,
    bound: f64,
    source: &mut Source,
) where
    R: VecZnxBigToMut<B>,
{
    let mut res: VecZnxBig<&mut [u8], B> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_add_normal_f64_ref(
        res.at_mut(res_col, limb),
        sigma * scale,
        bound * scale,
        source,
    )
}
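
// NOTE (editorial, worked example): with basek = 17 and k = 30, the noise is
// added to limb = ceil(30/17) - 1 = 1. That limb represents bits up to
// 2 * 17 = 34, so sigma and bound are pre-scaled by 2^(34 - 30) = 16, making
// the sampled noise carry standard deviation sigma relative to 2^-k.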

pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
where
    Module<B>: VecZnxBigAddNormal<B>,
    B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
    let n: usize = module.n();
    let basek: usize = 17;
    let k: usize = 2 * 17;
    let size: usize = 5;
    let sigma: f64 = 3.2;
    let bound: f64 = 6.0 * sigma;
    let mut source: Source = Source::new([0u8; 32]);
    let cols: usize = 2;
    let zero: Vec<i64> = vec![0; n];
    let k_f64: f64 = (1u64 << k as u64) as f64;
    let sqrt2: f64 = SQRT_2;
    (0..cols).for_each(|col_i| {
        let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
        (0..cols).for_each(|col_j| {
            if col_j != col_i {
                (0..size).for_each(|limb_i| {
                    assert_eq!(a.at(col_j, limb_i), zero);
                })
            } else {
                let std: f64 = a.std(basek, col_i) * k_f64;
                assert!(
                    (std - sigma * sqrt2).abs() < 0.1,
                    "std={} ~!= {}",
                    std,
                    sigma * sqrt2
                );
            }
        })
    });
}
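
// NOTE (editorial): the test adds two independent N(0, sigma^2) noise
// vectors into the same column, so the expected standard deviation of the
// sum is sqrt(2) * sigma (variances add), which is what the assertion checks
// after rescaling the normalized std by 2^k. Columns that were never written
// must remain identically zero.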

/// R <- A - B
pub fn vec_znx_big_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();
    let b: VecZnxBig<&[u8], BE> = b.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    let b_vznx: VecZnx<&[u8]> = VecZnx {
        data: b.data,
        n: b.n,
        cols: b.cols,
        size: b.size,
        max_size: b.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}

/// R <- R - A
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

/// R <- A - R
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

/// R <- A - B
pub fn vec_znx_big_sub_small_a<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
    B: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let b: VecZnxBig<&[u8], BE> = b.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let b_vznx: VecZnx<&[u8]> = VecZnx {
        data: b.data,
        n: b.n,
        cols: b.cols,
        size: b.size,
        max_size: b.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col);
}

/// R <- A - B
pub fn vec_znx_big_sub_small_b<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}

/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
369
poulpy-hal/src/reference/fft64/vec_znx_dft.rs
Normal file
@@ -0,0 +1,369 @@
use bytemuck::cast_slice_mut;

use crate::{
    layouts::{
        Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
        ZnxView, ZnxViewMut,
    },
    reference::{
        fft64::reim::{
            ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
            ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
        },
        znx::ZnxZero,
    },
};

pub fn vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAdd + ReimCopy + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    }
}
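
// NOTE (editorial): the operands may carry different limb counts, and the
// three loop ranges account for that: limbs present in both inputs are
// summed, limbs present only in the longer input are copied through, and any
// remaining limbs of `res` are zeroed. E.g. with a_size = 2, b_size = 4,
// res_size = 5: limbs 0..2 hold a + b, limbs 2..4 copy b, limb 4 is zeroed.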

pub fn vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAddInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), a.n())
    }

    let steps: usize = a.size().div_ceil(step);
    let min_steps: usize = res.size().min(steps);

    (0..min_steps).for_each(|j| {
        let limb: usize = offset + j * step;
        if limb < a.size() {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb));
        }
    });
    (min_steps..res.size()).for_each(|j| {
        BE::reim_zero(res.at_mut(res_col, j));
    })
}

pub fn vec_znx_dft_apply<R, A, BE>(
    table: &ReimFFTTable<f64>,
    step: usize,
    offset: usize,
    res: &mut R,
    res_col: usize,
    a: &A,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxToRef,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(step > 0);
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let a_size: usize = a.size();
    let res_size: usize = res.size();

    let steps: usize = a_size.div_ceil(step);
    let min_steps: usize = res_size.min(steps);

    for j in 0..min_steps {
        let limb = offset + j * step;
        if limb < a_size {
            BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb));
            BE::reim_dft_execute(table, res.at_mut(res_col, j));
        }
    }

    (min_steps..res.size()).for_each(|j| {
        BE::reim_zero(res.at_mut(res_col, j));
    });
}
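
// NOTE (editorial sketch, not part of this commit): `step` and `offset`
// select every `step`-th limb of `a` starting at `offset`, which lets
// callers de-interleave decomposed inputs. The whole-vector case is
// step = 1, offset = 0, e.g. (with `a_dft` and `a` assumed to share ring
// degree n = 2 * table.m()):
//
//     vec_znx_dft_apply::<_, _, BE>(&table, 1, 0, &mut a_dft, 0, &a, 0);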

pub fn vec_znx_idft_apply<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64>
        + ReimDFTExecute<ReimIFFTTable<f64>, f64>
        + ReimCopy
        + ReimToZnxInplace
        + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let min_size: usize = res_size.min(a.size());

    let divisor: f64 = table.m() as f64;

    for j in 0..min_size {
        let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j));
        BE::reim_copy(res_slice_f64, a.at(a_col, j));
        BE::reim_dft_execute(table, res_slice_f64);
        BE::reim_to_znx_inplace(res_slice_f64, divisor);
    }

    for j in min_size..res_size {
        BE::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_idft_apply_tmpa<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnx + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToMut<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let res_size = res.size();
    let min_size: usize = res_size.min(a.size());

    let divisor: f64 = table.m() as f64;

    for j in 0..min_size {
        BE::reim_dft_execute(table, a.at_mut(a_col, j));
        BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j));
    }

    for j in min_size..res_size {
        BE::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_idft_apply_consume<D: Data, BE>(table: &ReimIFFTTable<f64>, mut res: VecZnxDft<D, BE>) -> VecZnxBig<D, BE>
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnxInplace,
    VecZnxDft<D, BE>: VecZnxDftToMut<BE>,
{
    {
        let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

        #[cfg(debug_assertions)]
        {
            assert_eq!(table.m() << 1, res.n());
        }

        let divisor: f64 = table.m() as f64;

        for i in 0..res.cols() {
            for j in 0..res.size() {
                BE::reim_dft_execute(table, res.at_mut(i, j));
                BE::reim_to_znx_inplace(res.at_mut(i, j), divisor);
            }
        }
    }

    res.into_big()
}
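
// NOTE (editorial): the three inverse-DFT variants trade memory for
// flexibility: `vec_znx_idft_apply` keeps `a` intact by copying each limb
// into `res` first, `vec_znx_idft_apply_tmpa` avoids that copy but clobbers
// `a`, and `vec_znx_idft_apply_consume` takes ownership and reinterprets the
// buffer in place via `into_big()`. All three divide by m = n/2, the scaling
// accumulated across the forward/backward transform pair.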

pub fn vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSub + ReimNegate + ReimZero + ReimCopy,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    }
}

pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in sum_size..res_size {
        BE::reim_negate_inplace(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_dft_zero<R, BE>(res: &mut R)
where
    R: VecZnxDftToMut<BE>,
    BE: Backend<ScalarPrep = f64> + ReimZero,
{
    BE::reim_zero(res.to_mut().raw_mut());
}
365
poulpy-hal/src/reference/fft64/vmp.rs
Normal file
@@ -0,0 +1,365 @@
use crate::{
    cast_mut,
    layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
    oep::VecZnxDftAllocBytesImpl,
    reference::fft64::{
        reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
        reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
        vec_znx_dft::vec_znx_dft_apply,
    },
};

use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos};

pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
    R: VmpPMatToMut<BE>,
    A: MatZnxToRef,
{
    let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut();
    let a: MatZnx<&[u8]> = mat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(
            res.cols_in(),
            a.cols_in(),
            "res.cols_in: {} != a.cols_in: {}",
            res.cols_in(),
            a.cols_in()
        );
        assert_eq!(
            res.rows(),
            a.rows(),
            "res.rows: {} != a.rows: {}",
            res.rows(),
            a.rows()
        );
        assert_eq!(
            res.cols_out(),
            a.cols_out(),
            "res.cols_out: {} != a.cols_out: {}",
            res.cols_out(),
            a.cols_out()
        );
        assert_eq!(
            res.size(),
            a.size(),
            "res.size: {} != a.size: {}",
            res.size(),
            a.size()
        );
    }

    let nrows: usize = a.cols_in() * a.rows();
    let ncols: usize = a.cols_out() * a.size();
    vmp_prepare_core::<BE>(table, res.raw_mut(), a.raw(), nrows, ncols, tmp);
}

pub(crate) fn vmp_prepare_core<REIM>(
    table: &ReimFFTTable<f64>,
    pmat: &mut [f64],
    mat: &[i64],
    nrows: usize,
    ncols: usize,
    tmp: &mut [f64],
) where
    REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
{
    let m: usize = table.m();
    let n: usize = m << 1;

    #[cfg(debug_assertions)]
    {
        assert!(n >= 8);
        assert_eq!(mat.len(), n * nrows * ncols);
        assert_eq!(pmat.len(), n * nrows * ncols);
        assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::<i64>())
    }

    let offset: usize = nrows * ncols * 8;

    for row_i in 0..nrows {
        for col_i in 0..ncols {
            let pos: usize = n * (row_i * ncols + col_i);

            REIM::reim_from_znx(tmp, &mat[pos..pos + n]);
            REIM::reim_dft_execute(table, tmp);

            let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) {
                &mut pmat[col_i * nrows * 8 + row_i * 8..]
            } else {
                &mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..]
            };

            for blk_i in 0..m >> 2 {
                REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
            }
        }
    }
}
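
// NOTE (editorial): `vmp_prepare_core` lays the DFT'd matrix out in blocks
// of 4 complex coefficients (8 f64), with columns interleaved in pairs so
// that `reim4_mat2cols_prod` can produce two output columns per pass; a
// trailing odd column is stored unpaired, which is the special case in the
// `dst` index computation above.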

pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize {
    let row_max: usize = (a_size).min(prows);
    (16 + (n + 8) * row_max * pcols_in) * size_of::<f64>()
}
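
// NOTE (editorial, worked example): with n = 1024, a_size = 4, prows = 4 and
// pcols_in = 1 this yields (16 + (1024 + 8) * 4) * 8 = 33_152 bytes: 16 f64
// for the mat2cols accumulator, plus per retained row one DFT limb (n f64)
// and one extracted reim4 block (8 f64).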

pub fn vmp_apply_dft<R, A, M, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + VecZnxDftAllocBytesImpl<BE>
        + ReimDFTExecute<ReimFFTTable<f64>, f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk
        + ReimFromZnx,
    R: VecZnxDftToMut<BE>,
    A: VecZnxToRef,
    M: VmpPMatToRef<BE>,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    let n: usize = a.n();
    let cols: usize = pmat.cols_in();
    let size: usize = a.size().min(pmat.rows());

    #[cfg(debug_assertions)]
    {
        assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols));
        assert!(a.cols() <= cols);
    }

    let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size));

    let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size);

    let offset: usize = cols - a.cols();
    for j in 0..cols {
        vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j);
    }

    vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes);
}

pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize {
    let row_max: usize = (a_size).min(prows);
    (16 + 8 * row_max * pcols_in) * size_of::<f64>()
}

pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    M: VmpPMatToRef<BE>,
{
    use crate::layouts::{ZnxView, ZnxViewMut};

    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), pmat.n());
        assert_eq!(a.n(), pmat.n());
        assert_eq!(res.cols(), pmat.cols_out());
        assert_eq!(a.cols(), pmat.cols_in());
    }

    let n: usize = res.n();
    let nrows: usize = pmat.cols_in() * pmat.rows();
    let ncols: usize = pmat.cols_out() * pmat.size();

    let pmat_raw: &[f64] = pmat.raw();
    let a_raw: &[f64] = a.raw();
    let res_raw: &mut [f64] = res.raw_mut();

    vmp_apply_dft_to_dft_core::<true, BE>(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes)
}

pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    M: VmpPMatToRef<BE>,
{
    use crate::layouts::{ZnxView, ZnxViewMut};

    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), pmat.n());
        assert_eq!(a.n(), pmat.n());
        assert_eq!(res.cols(), pmat.cols_out());
        assert_eq!(a.cols(), pmat.cols_in());
    }

    let n: usize = res.n();
    let nrows: usize = pmat.cols_in() * pmat.rows();
    let ncols: usize = pmat.cols_out() * pmat.size();

    let pmat_raw: &[f64] = pmat.raw();
    let a_raw: &[f64] = a.raw();
    let res_raw: &mut [f64] = res.raw_mut();

    vmp_apply_dft_to_dft_core::<false, BE>(
        n,
        res_raw,
        a_raw,
        pmat_raw,
        limb_offset,
        nrows,
        ncols,
        tmp_bytes,
    )
}

#[allow(clippy::too_many_arguments)]
fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
    n: usize,
    res: &mut [f64],
    a: &[f64],
    pmat: &[f64],
    limb_offset: usize,
    nrows: usize,
    ncols: usize,
    tmp_bytes: &mut [f64],
) where
    REIM: ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
{
    #[cfg(debug_assertions)]
    {
        assert!(n >= 8);
        assert!(n.is_power_of_two());
        assert_eq!(pmat.len(), n * nrows * ncols);
        assert!(res.len() & (n - 1) == 0);
        assert!(a.len() & (n - 1) == 0);
    }

    let a_size: usize = a.len() / n;
    let res_size: usize = res.len() / n;

    let m: usize = n >> 1;

    let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16);

    let row_max: usize = nrows.min(a_size);
    let col_max: usize = ncols.min(res_size);

    if limb_offset >= col_max {
        if OVERWRITE {
            REIM::reim_zero(res);
        }
        return;
    }

    for blk_i in 0..(m >> 2) {
        let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];

        REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);

        if limb_offset.is_multiple_of(2) {
            for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
                let col_offset: usize = col_pmat * (8 * nrows);
                REIM::reim4_mat2cols_prod(
                    row_max,
                    mat2cols_output,
                    extracted_blk,
                    &mat_blk_start[col_offset..],
                );
                REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
            }
        } else {
            let col_offset: usize = (limb_offset - 1) * (8 * nrows);
            REIM::reim4_mat2cols_2ndcol_prod(
                row_max,
                mat2cols_output,
                extracted_blk,
                &mat_blk_start[col_offset..],
            );

            REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);

            for (col_res, col_pmat) in (1..)
                .step_by(2)
                .zip((limb_offset + 1..col_max - 1).step_by(2))
            {
                let col_offset: usize = col_pmat * (8 * nrows);
                REIM::reim4_mat2cols_prod(
                    row_max,
                    mat2cols_output,
                    extracted_blk,
                    &mat_blk_start[col_offset..],
                );
                REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
            }
        }

        if !col_max.is_multiple_of(2) {
            let last_col: usize = col_max - 1;
            let col_offset: usize = last_col * (8 * nrows);

            if last_col >= limb_offset {
                if ncols == col_max {
                    REIM::reim4_mat1col_prod(
                        row_max,
                        mat2cols_output,
                        extracted_blk,
                        &mat_blk_start[col_offset..],
                    );
                } else {
                    REIM::reim4_mat2cols_prod(
                        row_max,
                        mat2cols_output,
                        extracted_blk,
                        &mat_blk_start[col_offset..],
                    );
                }
                REIM::reim4_save_1blk::<OVERWRITE>(
                    m,
                    blk_i,
                    &mut res[(last_col - limb_offset) * n..],
                    mat2cols_output,
                );
            }
        }
    }

    REIM::reim_zero(&mut res[col_max * n..]);
}
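
// NOTE (editorial sketch, not part of this commit): end-to-end vector-matrix
// product. The alloc helpers and the `ReimFFTTable::new` constructor are
// assumptions for illustration; the `vmp_*` functions are the ones defined
// above.
//
//     let table: ReimFFTTable<f64> = ReimFFTTable::new(n >> 1); // hypothetical ctor
//     let mut pmat: VmpPMat<Vec<u8>, BE> = VmpPMat::alloc(n, rows, cols_in, cols_out, size);
//     let mut tmp: Vec<f64> = vec![0.0; vmp_prepare_tmp_bytes(n) / size_of::<f64>()];
//     vmp_prepare(&table, &mut pmat, &mat, &mut tmp);
//
//     let mut tmp: Vec<f64> =
//         vec![0.0; vmp_apply_dft_tmp_bytes(n, a.size(), rows, cols_in) / size_of::<f64>()];
//     vmp_apply_dft(&table, &mut res_dft, &a, &pmat, &mut tmp);
//
// `vmp_apply_dft` first DFTs `a` into scratch space and then delegates to
// `vmp_apply_dft_to_dft`, so its tmp requirement exceeds the dft-to-dft one
// by exactly the room needed for that intermediate `VecZnxDft`.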