Mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 13:16:44 +01:00.
Ref. + AVX code & generic tests + benches (#85)
Commit 56dbd29c59 (parent 99b9e3e10e), committed via GitHub.
24  poulpy-hal/src/reference/fft64/mod.rs  (new file)
@@ -0,0 +1,24 @@
pub mod reim;
pub mod reim4;
pub mod svp;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;

pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) {
    assert_eq!(a.len(), b.len(), "Slices have different lengths");

    for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
        let diff: f64 = (x - y).abs();
        let scale: f64 = x.abs().max(y.abs()).max(1.0);
        assert!(
            diff <= tol * scale,
            "Difference at index {}: left={} right={} rel_diff={} > tol={}",
            i,
            x,
            y,
            diff / scale,
            tol
        );
    }
}
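The helper compares slices with a relative tolerance, and the max(1.0) clamp on the scale turns the check into an absolute |x - y| <= tol for values below 1.0. A minimal test sketch, assuming it runs inside the crate (the helper is pub(crate)):

use crate::reference::fft64::assert_approx_eq_slice;

#[test]
fn approx_eq_tolerates_small_rounding() {
    let a = [1.0, 1e-12, 1e9];
    let b = [1.0 + 1e-12, 0.0, 1e9 + 1.0];
    // 1e-12 vs 0.0 passes because the scale is clamped to 1.0;
    // 1e9 vs 1e9 + 1 passes because the relative error is 1e-9.
    assert_approx_eq_slice(&a, &b, 1e-8);
}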
31  poulpy-hal/src/reference/fft64/reim/conversion.rs  (new file)
@@ -0,0 +1,31 @@
#[inline(always)]
pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }

    for i in 0..res.len() {
        res[i] = a[i] as f64
    }
}

#[inline(always)]
pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    let inv_div = 1. / divisor;
    for i in 0..res.len() {
        res[i] = (a[i] * inv_div).round() as i64
    }
}

#[inline(always)]
pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) {
    let inv_div = 1. / divisor;
    for ri in res {
        // Intentional bit-cast: the rounded i64 is stored back into the f64
        // slot verbatim, so after this call the buffer holds i64 bit patterns,
        // not numeric f64 values.
        *ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64)
    }
}
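These conversions implement scale-and-round: reim_to_znx_i64_ref divides by divisor before rounding, which is how the unnormalized FFT/IFFT pair is renormalized on the way back to integer coefficients. A standalone sketch of the round trip (the loop bodies are inlined here so it runs outside the crate):

fn main() {
    let coeffs: [i64; 4] = [3, -7, 0, 42];

    // znx -> reim: plain i64 -> f64 cast, as in reim_from_znx_i64_ref.
    let reim: Vec<f64> = coeffs.iter().map(|&c| c as f64).collect();

    // Pretend a transform pair scaled everything by 4.0 ...
    let scaled: Vec<f64> = reim.iter().map(|&x| x * 4.0).collect();

    // ... then reim -> znx with divisor = 4.0 recovers the inputs exactly.
    let back: Vec<i64> = scaled.iter().map(|&x| (x / 4.0).round() as i64).collect();
    assert_eq!(back, coeffs);
}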
327  poulpy-hal/src/reference/fft64/reim/fft_ref.rs  (new file)
@@ -0,0 +1,327 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::reference::fft64::reim::{as_arr, as_arr_mut};

#[inline(always)]
pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
    assert!(data.len() == 2 * m);
    let (re, im) = data.split_at_mut(m);

    if m <= 16 {
        match m {
            1 => {}
            2 => fft2_ref(
                as_arr_mut::<2, R>(re),
                as_arr_mut::<2, R>(im),
                *as_arr::<2, R>(omg),
            ),
            4 => fft4_ref(
                as_arr_mut::<4, R>(re),
                as_arr_mut::<4, R>(im),
                *as_arr::<4, R>(omg),
            ),
            8 => fft8_ref(
                as_arr_mut::<8, R>(re),
                as_arr_mut::<8, R>(im),
                *as_arr::<8, R>(omg),
            ),
            16 => fft16_ref(
                as_arr_mut::<16, R>(re),
                as_arr_mut::<16, R>(im),
                *as_arr::<16, R>(omg),
            ),
            _ => {}
        }
    } else if m <= 2048 {
        fft_bfs_16_ref(m, re, im, omg, 0);
    } else {
        fft_rec_16_ref(m, re, im, omg, 0);
    }
}

#[inline(always)]
fn fft_rec_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fft_bfs_16_ref(m, re, im, omg, pos);
    };

    let h = m >> 1;
    twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
    pos += 2;
    pos = fft_rec_16_ref(h, re, im, omg, pos);
    pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
    pos
}

#[inline(always)]
fn cplx_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let dr: R = *rb * omg_re - *ib * omg_im;
    let di: R = *rb * omg_im + *ib * omg_re;
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra = *ra + dr;
    *ia = *ia + di;
}

#[inline(always)]
fn cplx_i_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let dr: R = *rb * omg_im + *ib * omg_re;
    let di: R = *rb * omg_re - *ib * omg_im;
    *rb = *ra + dr;
    *ib = *ia - di;
    *ra = *ra - dr;
    *ia = *ia + di;
}

#[inline(always)]
fn fft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
    let [ra, rb] = re;
    let [ia, ib] = im;
    let [romg, iomg] = omg;
    cplx_twiddle(ra, ia, rb, ib, romg, iomg);
}

#[inline(always)]
fn fft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
    let [re_0, re_1, re_2, re_3] = re;
    let [im_0, im_1, im_2, im_3] = im;

    {
        let omg_0 = omg[0];
        let omg_1 = omg[1];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
    }

    {
        let omg_0 = omg[2];
        let omg_1 = omg[3];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
    }
}

#[inline(always)]
fn fft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
    let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
    let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;

    {
        let omg_0 = omg[0];
        let omg_1 = omg[1];
        cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
        cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
        cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
    }

    {
        let omg_2 = omg[2];
        let omg_3 = omg[3];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
        cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
        cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
    }

    {
        let omg_4 = omg[4];
        let omg_5 = omg[5];
        let omg_6 = omg[6];
        let omg_7 = omg[7];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
        cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
        cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
    }
}

#[inline(always)]
fn fft16_ref<R: Float + FloatConst + Debug>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
    let [
        re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7,
        re_8, re_9, re_10, re_11, re_12, re_13, re_14, re_15,
    ] = re;
    let [
        im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7,
        im_8, im_9, im_10, im_11, im_12, im_13, im_14, im_15,
    ] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];
        cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
        cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
        cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);

        cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
        cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
        cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
        cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
    }

    {
        let omg_2: R = omg[2];
        let omg_3: R = omg[3];
        cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
        cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
        cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
        cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);

        cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
        cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
        cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
        cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[4];
        let omg_1: R = omg[5];
        let omg_2: R = omg[6];
        let omg_3: R = omg[7];
        cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
        cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
        cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);

        cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
        cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
        cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
        cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[8];
        let omg_1: R = omg[9];
        let omg_2: R = omg[10];
        let omg_3: R = omg[11];
        let omg_4: R = omg[12];
        let omg_5: R = omg[13];
        let omg_6: R = omg[14];
        let omg_7: R = omg[15];
        cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
        cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
        cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
        cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);

        cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
        cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
        cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
        cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
    }
}

#[inline(always)]
fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut mm: usize = m;

    if !log_m.is_multiple_of(2) {
        let h: usize = mm >> 1;
        twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
        pos += 2;
        mm = h
    }

    while mm > 16 {
        let h: usize = mm >> 2;
        for off in (0..m).step_by(mm) {
            bitwiddle_fft_ref(
                h,
                &mut re[off..],
                &mut im[off..],
                as_arr::<4, R>(&omg[pos..]),
            );
            pos += 4;
        }
        mm = h
    }

    for off in (0..m).step_by(16) {
        fft16_ref(
            as_arr_mut::<16, R>(&mut re[off..]),
            as_arr_mut::<16, R>(&mut im[off..]),
            *as_arr::<16, R>(&omg[pos..]),
        );
        pos += 16;
    }

    pos
}

#[inline(always)]
fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
    let romg = omg[0];
    let iomg = omg[1];

    let (re_lhs, re_rhs) = re.split_at_mut(h);
    let (im_lhs, im_rhs) = im.split_at_mut(h);

    for i in 0..h {
        cplx_twiddle(
            &mut re_lhs[i],
            &mut im_lhs[i],
            &mut re_rhs[i],
            &mut im_rhs[i],
            romg,
            iomg,
        );
    }
}

#[inline(always)]
fn bitwiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
    let (r0, r2) = re.split_at_mut(2 * h);
    let (r0, r1) = r0.split_at_mut(h);
    let (r2, r3) = r2.split_at_mut(h);

    let (i0, i2) = im.split_at_mut(2 * h);
    let (i0, i1) = i0.split_at_mut(h);
    let (i2, i3) = i2.split_at_mut(h);

    let omg_0: R = omg[0];
    let omg_1: R = omg[1];
    let omg_2: R = omg[2];
    let omg_3: R = omg[3];

    for i in 0..h {
        cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1);
        cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
    }

    for i in 0..h {
        cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3);
        cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
    }
}
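Every stage of this FFT reduces to cplx_twiddle: d = omega * b, then (a, b) <- (a + d, a - d), with the omega table consumed sequentially through pos so the fill and execute passes must traverse stages in the same order. A standalone f64 copy of the butterfly with a hand-checked value (omega = 0.6 + 0.8i is a unit twiddle):

fn cplx_twiddle(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, wr: f64, wi: f64) {
    let dr = *rb * wr - *ib * wi; // Re(w * b)
    let di = *rb * wi + *ib * wr; // Im(w * b)
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra += dr;
    *ia += di;
}

fn main() {
    // a = 2 + 0i, b = 1 + 1i, w = 0.6 + 0.8i  =>  w*b = -0.2 + 1.4i
    let (mut ra, mut ia, mut rb, mut ib) = (2.0, 0.0, 1.0, 1.0);
    cplx_twiddle(&mut ra, &mut ia, &mut rb, &mut ib, 0.6, 0.8);
    assert!((ra - 1.8).abs() < 1e-12 && (ia - 1.4).abs() < 1e-12); // a + w*b
    assert!((rb - 2.2).abs() < 1e-12 && (ib + 1.4).abs() < 1e-12); // a - w*b
}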
156  poulpy-hal/src/reference/fft64/reim/fft_vec.rs  (new file)
@@ -0,0 +1,156 @@
#[inline(always)]
pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] + b[i]
    }
}

#[inline(always)]
pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] += a[i]
    }
}

#[inline(always)]
pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] - b[i]
    }
}

#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] -= a[i]
    }
}

#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = a[i] - res[i]
    }
}

#[inline(always)]
pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    for i in 0..res.len() {
        res[i] = -a[i]
    }
}

#[inline(always)]
pub fn reim_negate_inplace_ref(res: &mut [f64]) {
    for ri in res {
        *ri = -*ri
    }
}

#[inline(always)]
pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = br[i];
        let _bi: f64 = bi[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] += _rr;
        ri[i] += _ri;
    }
}

#[inline(always)]
pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = rr[i];
        let _bi: f64 = ri[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] = _rr;
        ri[i] = _ri;
    }
}

#[inline(always)]
pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }

    let m: usize = res.len() >> 1;

    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);

    for i in 0..m {
        let _ar: f64 = ar[i];
        let _ai: f64 = ai[i];
        let _br: f64 = br[i];
        let _bi: f64 = bi[i];
        let _rr: f64 = _ar * _br - _ai * _bi;
        let _ri: f64 = _ar * _bi + _ai * _br;
        rr[i] = _rr;
        ri[i] = _ri;
    }
}
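The element-wise complex routines assume the same split layout as the FFT: a slice of length 2m holds m real parts followed by m imaginary parts, which is why res.len() >> 1 recovers the lane count. A standalone check of one product, (1+2i)(3+4i) = -5+10i, using an inlined copy of the reim_mul_ref loop:

fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) {
    let m = res.len() >> 1;
    let (rr, ri) = res.split_at_mut(m);
    let (ar, ai) = a.split_at(m);
    let (br, bi) = b.split_at(m);
    for i in 0..m {
        rr[i] = ar[i] * br[i] - ai[i] * bi[i];
        ri[i] = ar[i] * bi[i] + ai[i] * br[i];
    }
}

fn main() {
    // m = 2 complex lanes, stored as [re0, re1 | im0, im1].
    let a = [1.0, 1.0, 2.0, 0.0]; // lane 0: 1+2i, lane 1: 1+0i
    let b = [3.0, 5.0, 4.0, 0.0]; // lane 0: 3+4i, lane 1: 5+0i
    let mut res = [0.0; 4];
    reim_mul(&mut res, &a, &b);
    assert_eq!(res, [-5.0, 5.0, 10.0, 0.0]); // lane 0: -5+10i, lane 1: 5+0i
}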
322  poulpy-hal/src/reference/fft64/reim/ifft_ref.rs  (new file)
@@ -0,0 +1,322 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::reference::fft64::reim::{as_arr, as_arr_mut};

pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
    assert!(data.len() == 2 * m);
    let (re, im) = data.split_at_mut(m);

    if m <= 16 {
        match m {
            1 => {}
            2 => ifft2_ref(
                as_arr_mut::<2, R>(re),
                as_arr_mut::<2, R>(im),
                *as_arr::<2, R>(omg),
            ),
            4 => ifft4_ref(
                as_arr_mut::<4, R>(re),
                as_arr_mut::<4, R>(im),
                *as_arr::<4, R>(omg),
            ),
            8 => ifft8_ref(
                as_arr_mut::<8, R>(re),
                as_arr_mut::<8, R>(im),
                *as_arr::<8, R>(omg),
            ),
            16 => ifft16_ref(
                as_arr_mut::<16, R>(re),
                as_arr_mut::<16, R>(im),
                *as_arr::<16, R>(omg),
            ),
            _ => {}
        }
    } else if m <= 2048 {
        ifft_bfs_16_ref(m, re, im, omg, 0);
    } else {
        ifft_rec_16_ref(m, re, im, omg, 0);
    }
}

#[inline(always)]
fn ifft_rec_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    if m <= 2048 {
        return ifft_bfs_16_ref(m, re, im, omg, pos);
    };
    let h: usize = m >> 1;
    pos = ifft_rec_16_ref(h, re, im, omg, pos);
    pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
    inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
    pos += 2;
    pos
}

#[inline(always)]
fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;

    for off in (0..m).step_by(16) {
        ifft16_ref(
            as_arr_mut::<16, R>(&mut re[off..]),
            as_arr_mut::<16, R>(&mut im[off..]),
            *as_arr::<16, R>(&omg[pos..]),
        );
        pos += 16;
    }

    let mut h: usize = 16;
    let m_half: usize = m >> 1;

    while h < m_half {
        let mm: usize = h << 2;
        for off in (0..m).step_by(mm) {
            inv_bitwiddle_ifft_ref(
                h,
                &mut re[off..],
                &mut im[off..],
                as_arr::<4, R>(&omg[pos..]),
            );
            pos += 4;
        }
        h = mm;
    }

    if !log_m.is_multiple_of(2) {
        inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
        pos += 2;
    }

    pos
}

#[inline(always)]
fn inv_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let r_diff: R = *ra - *rb;
    let i_diff: R = *ia - *ib;
    *ra = *ra + *rb;
    *ia = *ia + *ib;
    *rb = r_diff * omg_re - i_diff * omg_im;
    *ib = r_diff * omg_im + i_diff * omg_re;
}

#[inline(always)]
fn inv_itwiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
    let r_diff: R = *ra - *rb;
    let i_diff: R = *ia - *ib;
    *ra = *ra + *rb;
    *ia = *ia + *ib;
    *rb = r_diff * omg_im + i_diff * omg_re;
    *ib = -r_diff * omg_re + i_diff * omg_im;
}

#[inline(always)]
fn ifft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
    let [ra, rb] = re;
    let [ia, ib] = im;
    let [romg, iomg] = omg;
    inv_twiddle(ra, ia, rb, ib, romg, iomg);
}

#[inline(always)]
fn ifft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
    let [re_0, re_1, re_2, re_3] = re;
    let [im_0, im_1, im_2, im_3] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];

        inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
    }

    {
        let omg_0: R = omg[2];
        let omg_1: R = omg[3];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
    }
}

#[inline(always)]
fn ifft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
    let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
    let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;

    {
        let omg_4: R = omg[0];
        let omg_5: R = omg[1];
        let omg_6: R = omg[2];
        let omg_7: R = omg[3];
        inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
        inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
        inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
    }

    {
        let omg_2: R = omg[4];
        let omg_3: R = omg[5];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
        inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
        inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[6];
        let omg_1: R = omg[7];
        inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
        inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
        inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
    }
}

#[inline(always)]
fn ifft16_ref<R: Float + FloatConst>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
    let [
        re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7,
        re_8, re_9, re_10, re_11, re_12, re_13, re_14, re_15,
    ] = re;
    let [
        im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7,
        im_8, im_9, im_10, im_11, im_12, im_13, im_14, im_15,
    ] = im;

    {
        let omg_0: R = omg[0];
        let omg_1: R = omg[1];
        let omg_2: R = omg[2];
        let omg_3: R = omg[3];
        let omg_4: R = omg[4];
        let omg_5: R = omg[5];
        let omg_6: R = omg[6];
        let omg_7: R = omg[7];
        inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
        inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
        inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
        inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
        inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
        inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
        inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
        inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
    }

    {
        let omg_0: R = omg[8];
        let omg_1: R = omg[9];
        let omg_2: R = omg[10];
        let omg_3: R = omg[11];
        inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
        inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
        inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
        inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
        inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
        inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
        inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_2: R = omg[12];
        let omg_3: R = omg[13];
        inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
        inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
        inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
        inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
        inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
        inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
        inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
        inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
    }

    {
        let omg_0: R = omg[14];
        let omg_1: R = omg[15];
        inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
        inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
        inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
        inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
        inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
        inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
        inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
        inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
    }
}

#[inline(always)]
fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
    let romg = omg[0];
    let iomg = omg[1];

    let (re_lhs, re_rhs) = re.split_at_mut(h);
    let (im_lhs, im_rhs) = im.split_at_mut(h);

    for i in 0..h {
        inv_twiddle(
            &mut re_lhs[i],
            &mut im_lhs[i],
            &mut re_rhs[i],
            &mut im_rhs[i],
            romg,
            iomg,
        );
    }
}

#[inline(always)]
fn inv_bitwiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
    let (r0, r2) = re.split_at_mut(2 * h);
    let (r0, r1) = r0.split_at_mut(h);
    let (r2, r3) = r2.split_at_mut(h);

    let (i0, i2) = im.split_at_mut(2 * h);
    let (i0, i1) = i0.split_at_mut(h);
    let (i2, i3) = i2.split_at_mut(h);

    let omg_0: R = omg[0];
    let omg_1: R = omg[1];
    let omg_2: R = omg[2];
    let omg_3: R = omg[3];

    for i in 0..h {
        inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1);
        inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
    }

    for i in 0..h {
        inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3);
        inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
    }
}
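inv_twiddle mirrors cplx_twiddle: it sums and differences first, then multiplies the difference by the twiddle, and the IFFT tables store conjugated (negative-sine) twiddles. Composing the two butterflies therefore returns the inputs scaled by 2, which over the log2(m) stage levels is where the overall factor m of this unnormalized fft/ifft pair comes from. A standalone check of the factor-2 identity, using f64 copies of the two generic butterflies above:

fn cplx_twiddle(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, wr: f64, wi: f64) {
    let dr = *rb * wr - *ib * wi;
    let di = *rb * wi + *ib * wr;
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra += dr;
    *ia += di;
}

fn inv_twiddle(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, wr: f64, wi: f64) {
    let rd = *ra - *rb;
    let id = *ia - *ib;
    *ra += *rb;
    *ia += *ib;
    *rb = rd * wr - id * wi;
    *ib = rd * wi + id * wr;
}

fn main() {
    let (wr, wi) = (0.6, 0.8); // unit twiddle: 0.36 + 0.64 = 1
    let (mut ra, mut ia, mut rb, mut ib) = (1.5, -0.5, 0.25, 2.0);
    cplx_twiddle(&mut ra, &mut ia, &mut rb, &mut ib, wr, wi);
    // The inverse stage uses the conjugate twiddle (wr, -wi).
    inv_twiddle(&mut ra, &mut ia, &mut rb, &mut ib, wr, -wi);
    assert!((ra - 3.0).abs() < 1e-12 && (rb - 0.5).abs() < 1e-12); // 2 * inputs
}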
128  poulpy-hal/src/reference/fft64/reim/mod.rs  (new file)
@@ -0,0 +1,128 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library (https://github.com/tfhe/spqlios-arithmetic),
// which is licensed under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
// ----------------------------------------------------------------------

#![allow(bad_asm_style)]

use rand_distr::num_traits::{Float, FloatConst};

mod conversion;
mod fft_ref;
mod fft_vec;
mod ifft_ref;
mod table_fft;
mod table_ifft;
mod zero;

pub use conversion::*;
pub use fft_ref::*;
pub use fft_vec::*;
pub use ifft_ref::*;
pub use table_fft::*;
pub use table_ifft::*;
pub use zero::*;

#[inline(always)]
pub(crate) fn as_arr<const size: usize, R: Float + FloatConst>(x: &[R]) -> &[R; size] {
    debug_assert!(x.len() >= size);
    unsafe { &*(x.as_ptr() as *const [R; size]) }
}

#[inline(always)]
pub(crate) fn as_arr_mut<const size: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; size] {
    debug_assert!(x.len() >= size);
    unsafe { &mut *(x.as_mut_ptr() as *mut [R; size]) }
}

#[inline(always)]
pub(crate) fn frac_rev_bits<R: Float + FloatConst>(x: usize) -> R {
    let half: R = R::from(0.5).unwrap();

    match x {
        0 => R::zero(),
        1 => half,
        _ => {
            if x.is_multiple_of(2) {
                frac_rev_bits::<R>(x >> 1) * half
            } else {
                frac_rev_bits::<R>(x >> 1) * half + half
            }
        }
    }
}

pub trait ReimDFTExecute<D, T> {
    fn reim_dft_execute(table: &D, data: &mut [T]);
}

pub trait ReimFromZnx {
    fn reim_from_znx(res: &mut [f64], a: &[i64]);
}

pub trait ReimToZnx {
    fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]);
}

pub trait ReimToZnxInplace {
    fn reim_to_znx_inplace(res: &mut [f64], divisor: f64);
}

pub trait ReimAdd {
    fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimAddInplace {
    fn reim_add_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimSub {
    fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimSubABInplace {
    fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimSubBAInplace {
    fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimNegate {
    fn reim_negate(res: &mut [f64], a: &[f64]);
}

pub trait ReimNegateInplace {
    fn reim_negate_inplace(res: &mut [f64]);
}

pub trait ReimMul {
    fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimMulInplace {
    fn reim_mul_inplace(res: &mut [f64], a: &[f64]);
}

pub trait ReimAddMul {
    fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]);
}

pub trait ReimCopy {
    fn reim_copy(res: &mut [f64], a: &[f64]);
}

pub trait ReimZero {
    fn reim_zero(res: &mut [f64]);
}
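frac_rev_bits maps an index to its bit reversal read as a binary fraction: the lowest bit of x becomes the first fractional bit. This is what orders the per-block twiddle angles in the breadth-first omega fills below. A standalone f64 copy with a couple of checked values:

fn frac_rev_bits(x: usize) -> f64 {
    match x {
        0 => 0.0,
        1 => 0.5,
        _ if x % 2 == 0 => frac_rev_bits(x >> 1) * 0.5,
        _ => frac_rev_bits(x >> 1) * 0.5 + 0.5,
    }
}

fn main() {
    assert_eq!(frac_rev_bits(0b110), 0.375); // bits 110 -> 0.011b = 3/8
    assert_eq!(frac_rev_bits(0b101), 0.625); // bits 101 -> 0.101b = 5/8
}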
207  poulpy-hal/src/reference/fft64/reim/table_fft.rs  (new file)
@@ -0,0 +1,207 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::{
    alloc_aligned,
    reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits},
};

pub struct ReimFFTRef;

impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTRef {
    fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
        fft_ref(table.m, &table.omg, data);
    }
}

pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
    m: usize,
    omg: Vec<R>,
}

impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
    pub fn new(m: usize) -> Self {
        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
        let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);

        let quarter: R = R::from(1. / 4.).unwrap();

        if m <= 16 {
            match m {
                1 => {}
                2 => {
                    fill_fft2_omegas(quarter, &mut omg, 0);
                }
                4 => {
                    fill_fft4_omegas(quarter, &mut omg, 0);
                }
                8 => {
                    fill_fft8_omegas(quarter, &mut omg, 0);
                }
                16 => {
                    fill_fft16_omegas(quarter, &mut omg, 0);
                }
                _ => {}
            }
        } else if m <= 2048 {
            fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0);
        } else {
            fill_fft_rec_16_omegas(m, quarter, &mut omg, 0);
        }

        Self { m, omg }
    }

    pub fn m(&self) -> usize {
        self.m
    }

    pub fn omg(&self) -> &[R] {
        &self.omg
    }
}

#[inline(always)]
fn fill_fft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 2);
    let angle: R = j / R::from(2).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle);
    omg_pos[1] = R::sin(two_pi * angle);
    pos + 2
}

#[inline(always)]
fn fill_fft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 4);
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    pos + 4
}

#[inline(always)]
fn fill_fft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 8);
    let _8th: R = R::from(1. / 8.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    omg_pos[4] = R::cos(two_pi * angle_4);
    omg_pos[5] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[6] = R::sin(two_pi * angle_4);
    omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
    pos + 8
}

#[inline(always)]
fn fill_fft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 16);
    let _8th: R = R::from(1. / 8.).unwrap();
    let _16th: R = R::from(1. / 16.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let angle_8: R = j / R::from(16).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_1);
    omg_pos[1] = R::sin(two_pi * angle_1);
    omg_pos[2] = R::cos(two_pi * angle_2);
    omg_pos[3] = R::sin(two_pi * angle_2);
    omg_pos[4] = R::cos(two_pi * angle_4);
    omg_pos[5] = R::sin(two_pi * angle_4);
    omg_pos[6] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
    omg_pos[8] = R::cos(two_pi * angle_8);
    omg_pos[9] = R::cos(two_pi * (angle_8 + _8th));
    omg_pos[10] = R::cos(two_pi * (angle_8 + _16th));
    omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th));
    omg_pos[12] = R::sin(two_pi * angle_8);
    omg_pos[13] = R::sin(two_pi * (angle_8 + _8th));
    omg_pos[14] = R::sin(two_pi * (angle_8 + _16th));
    omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th));
    pos + 16
}

#[inline(always)]
fn fill_fft_bfs_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut mm: usize = m;
    let mut jj: R = j;

    let two_pi: R = R::from(2).unwrap() * R::PI();

    if !log_m.is_multiple_of(2) {
        let h = mm >> 1;
        let j: R = jj * R::from(0.5).unwrap();
        omg[pos] = R::cos(two_pi * j);
        omg[pos + 1] = R::sin(two_pi * j);
        pos += 2;
        mm = h;
        jj = j
    }

    while mm > 16 {
        let h: usize = mm >> 2;
        let j: R = jj * R::from(1. / 4.).unwrap();
        for i in (0..m).step_by(mm) {
            let rs_0 = j + frac_rev_bits::<R>(i / mm) * R::from(1. / 4.).unwrap();
            let rs_1 = R::from(2).unwrap() * rs_0;
            omg[pos] = R::cos(two_pi * rs_1);
            omg[pos + 1] = R::sin(two_pi * rs_1);
            omg[pos + 2] = R::cos(two_pi * rs_0);
            omg[pos + 3] = R::sin(two_pi * rs_0);
            pos += 4;
        }
        mm = h;
        jj = j;
    }

    for i in (0..m).step_by(16) {
        let j = jj + frac_rev_bits(i >> 4);
        fill_fft16_omegas(j, omg, pos);
        pos += 16
    }

    pos
}

#[inline(always)]
fn fill_fft_rec_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fill_fft_bfs_16_omegas(m, j, omg, pos);
    }
    let h: usize = m >> 1;
    let s: R = j * R::from(0.5).unwrap();
    let _2pi = R::from(2).unwrap() * R::PI();
    omg[pos] = R::cos(_2pi * s);
    omg[pos + 1] = R::sin(_2pi * s);
    pos += 2;
    pos = fill_fft_rec_16_omegas(h, s, omg, pos);
    pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
    pos
}

#[inline(always)]
fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) {
    let dr: f64 = *rb * omg_re - *ib * omg_im;
    let di: f64 = *rb * omg_im + *ib * omg_re;
    *rb = *ra - dr;
    *ib = *ia - di;
    *ra += dr;
    *ia += di;
}
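Table construction precomputes the full twiddle sequence in execution order, so fft_ref only streams through omg with a cursor. A usage sketch via the ReimDFTExecute impl above, assuming these modules are re-exported from the crate root under poulpy_hal::reference as the diff suggests:

use poulpy_hal::reference::fft64::reim::{ReimDFTExecute, ReimFFTRef, ReimFFTTable};

fn main() {
    let m = 64;
    let table = ReimFFTTable::<f64>::new(m);
    let mut data = vec![0.0f64; 2 * m]; // split layout: [re(m) | im(m)]
    data[0] = 1.0; // delta in coefficient 0
    ReimFFTRef::reim_dft_execute(&table, &mut data);
    // The constant coefficient evaluates to 1 at every point, so each
    // output lane should be 1 + 0i (up to rounding).
    assert!((data[0] - 1.0).abs() < 1e-9 && data[m].abs() < 1e-9);
}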
201  poulpy-hal/src/reference/fft64/reim/table_ifft.rs  (new file)
@@ -0,0 +1,201 @@
use std::fmt::Debug;

use rand_distr::num_traits::{Float, FloatConst};

use crate::{
    alloc_aligned,
    reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref},
};

pub struct ReimIFFTRef;

impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTRef {
    fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
        ifft_ref(table.m, &table.omg, data);
    }
}

pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
    m: usize,
    omg: Vec<R>,
}

impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
    pub fn new(m: usize) -> Self {
        assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
        let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);

        let quarter: R = R::exp2(R::from(-2).unwrap());

        if m <= 16 {
            match m {
                1 => {}
                2 => {
                    fill_ifft2_omegas::<R>(quarter, &mut omg, 0);
                }
                4 => {
                    fill_ifft4_omegas(quarter, &mut omg, 0);
                }
                8 => {
                    fill_ifft8_omegas(quarter, &mut omg, 0);
                }
                16 => {
                    fill_ifft16_omegas(quarter, &mut omg, 0);
                }
                _ => {}
            }
        } else if m <= 2048 {
            fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0);
        } else {
            fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0);
        }

        Self { m, omg }
    }

    pub fn execute(&self, data: &mut [R]) {
        ifft_ref(self.m, &self.omg, data);
    }

    pub fn m(&self) -> usize {
        self.m
    }

    pub fn omg(&self) -> &[R] {
        &self.omg
    }
}

#[inline(always)]
fn fill_ifft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 2);
    let angle: R = j / R::from(2).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle);
    omg_pos[1] = -R::sin(two_pi * angle);
    pos + 2
}

#[inline(always)]
fn fill_ifft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 4);
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_2);
    omg_pos[1] = -R::sin(two_pi * angle_2);
    omg_pos[2] = R::cos(two_pi * angle_1);
    omg_pos[3] = -R::sin(two_pi * angle_1);
    pos + 4
}

#[inline(always)]
fn fill_ifft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 8);
    let _8th: R = R::from(1. / 8.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    // angle_4 mirrors fill_fft8_omegas (j / 8); the stages run in reverse
    // order here, with conjugated (negative-sine) twiddles.
    let angle_4: R = j / R::from(8).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_4);
    omg_pos[1] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[2] = -R::sin(two_pi * angle_4);
    omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th));
    omg_pos[4] = R::cos(two_pi * angle_2);
    omg_pos[5] = -R::sin(two_pi * angle_2);
    omg_pos[6] = R::cos(two_pi * angle_1);
    omg_pos[7] = -R::sin(two_pi * angle_1);
    pos + 8
}

#[inline(always)]
fn fill_ifft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
    let omg_pos: &mut [R] = &mut omg[pos..];
    assert!(omg_pos.len() >= 16);
    let _8th: R = R::from(1. / 8.).unwrap();
    let _16th: R = R::from(1. / 16.).unwrap();
    let angle_1: R = j / R::from(2).unwrap();
    let angle_2: R = j / R::from(4).unwrap();
    let angle_4: R = j / R::from(8).unwrap();
    let angle_8: R = j / R::from(16).unwrap();
    let two_pi: R = R::from(2).unwrap() * R::PI();
    omg_pos[0] = R::cos(two_pi * angle_8);
    omg_pos[1] = R::cos(two_pi * (angle_8 + _8th));
    omg_pos[2] = R::cos(two_pi * (angle_8 + _16th));
    omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th));
    omg_pos[4] = -R::sin(two_pi * angle_8);
    omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th));
    omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th));
    omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th));
    omg_pos[8] = R::cos(two_pi * angle_4);
    omg_pos[9] = -R::sin(two_pi * angle_4);
    omg_pos[10] = R::cos(two_pi * (angle_4 + _8th));
    omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th));
    omg_pos[12] = R::cos(two_pi * angle_2);
    omg_pos[13] = -R::sin(two_pi * angle_2);
    omg_pos[14] = R::cos(two_pi * angle_1);
    omg_pos[15] = -R::sin(two_pi * angle_1);
    pos + 16
}

#[inline(always)]
fn fill_ifft_bfs_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
    let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap();

    for i in (0..m).step_by(16) {
        let j = jj + frac_rev_bits(i >> 4);
        fill_ifft16_omegas(j, omg, pos);
        pos += 16
    }

    let mut h: usize = 16;
    let m_half: usize = m >> 1;

    let two_pi: R = R::from(2).unwrap() * R::PI();

    while h < m_half {
        let mm: usize = h << 2;
        for i in (0..m).step_by(mm) {
            let rs_0 = jj + frac_rev_bits::<R>(i / mm) / R::from(4).unwrap();
            let rs_1 = R::from(2).unwrap() * rs_0;
            omg[pos] = R::cos(two_pi * rs_0);
            omg[pos + 1] = -R::sin(two_pi * rs_0);
            omg[pos + 2] = R::cos(two_pi * rs_1);
            omg[pos + 3] = -R::sin(two_pi * rs_1);
            pos += 4;
        }
        h = mm;
        jj = jj * R::from(4).unwrap();
    }

    if !log_m.is_multiple_of(2) {
        omg[pos] = R::cos(two_pi * jj);
        omg[pos + 1] = -R::sin(two_pi * jj);
        pos += 2;
        jj = jj * R::from(2).unwrap();
    }

    assert_eq!(jj, j);

    pos
}

#[inline(always)]
fn fill_ifft_rec_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
    if m <= 2048 {
        return fill_ifft_bfs_16_omegas(m, j, omg, pos);
    }
    let h: usize = m >> 1;
    let s: R = j / R::from(2).unwrap();
    pos = fill_ifft_rec_16_omegas(h, s, omg, pos);
    pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
    let _2pi = R::from(2).unwrap() * R::PI();
    omg[pos] = R::cos(_2pi * s);
    omg[pos + 1] = -R::sin(_2pi * s);
    pos += 2;
    pos
}
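With both tables in hand, fft followed by ifft should reproduce the input up to a global scalar, which is exactly what the divisor argument of reim_to_znx_i64_ref exists to divide out. A round-trip sketch, assuming the same crate-root paths as above; it prints the empirical scaling rather than asserting a value (the per-stage factor 2 shown earlier suggests the scalar is m for this unnormalized pair):

use poulpy_hal::reference::fft64::reim::{fft_ref, ifft_ref, ReimFFTTable, ReimIFFTTable};

fn main() {
    let m = 128;
    let fft = ReimFFTTable::<f64>::new(m);
    let ifft = ReimIFFTTable::<f64>::new(m);

    let mut data: Vec<f64> = (0..2 * m).map(|i| (i as f64).sin()).collect();
    let orig = data.clone();

    fft_ref(m, fft.omg(), &mut data);
    ifft_ref(m, ifft.omg(), &mut data);

    println!("round-trip scaling = {}", data[1] / orig[1]);
}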
11  poulpy-hal/src/reference/fft64/reim/zero.rs  (new file)
@@ -0,0 +1,11 @@
pub fn reim_zero_ref(res: &mut [f64]) {
    res.fill(0.);
}

pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    res.copy_from_slice(a);
}
209  poulpy-hal/src/reference/fft64/reim4/arithmetic_ref.rs  (new file)
@@ -0,0 +1,209 @@
use crate::reference::fft64::reim::as_arr;

#[inline(always)]
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= 2 * rows * 4);

    for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
        chunk.copy_from_slice(&src[offset..offset + 4]);
        offset += m
    }
}

#[inline(always)]
pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= offset + m + 4);
    debug_assert!(src.len() >= 8);

    let dst_off = &mut dst[offset..offset + 4];

    if OVERWRITE {
        dst_off.copy_from_slice(&src[0..4]);
    } else {
        dst_off[0] += src[0];
        dst_off[1] += src[1];
        dst_off[2] += src[2];
        dst_off[3] += src[3];
    }

    offset += m;

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[4..8]);
    } else {
        dst_off[0] += src[4];
        dst_off[1] += src[5];
        dst_off[2] += src[6];
        dst_off[3] += src[7];
    }
}

#[inline(always)]
pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
    let mut offset: usize = blk << 2;

    debug_assert!(blk < (m >> 2));
    debug_assert!(dst.len() >= offset + 3 * m + 4);
    debug_assert!(src.len() >= 16);

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[0..4]);
    } else {
        dst_off[0] += src[0];
        dst_off[1] += src[1];
        dst_off[2] += src[2];
        dst_off[3] += src[3];
    }

    offset += m;
    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[4..8]);
    } else {
        dst_off[0] += src[4];
        dst_off[1] += src[5];
        dst_off[2] += src[6];
        dst_off[3] += src[7];
    }

    offset += m;

    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[8..12]);
    } else {
        dst_off[0] += src[8];
        dst_off[1] += src[9];
        dst_off[2] += src[10];
        dst_off[3] += src[11];
    }

    offset += m;
    let dst_off = &mut dst[offset..offset + 4];
    if OVERWRITE {
        dst_off.copy_from_slice(&src[12..16]);
    } else {
        dst_off[0] += src[12];
        dst_off[1] += src[13];
        dst_off[2] += src[14];
        dst_off[3] += src[15];
    }
}

#[inline(always)]
pub fn reim4_vec_mat1col_product_ref(
    nrows: usize,
    dst: &mut [f64], // 8 doubles: [re1(4), im1(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 8 doubles: [ar(4) | ai(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert!(dst.len() >= 8, "dst must have at least 8 doubles");
        assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
        assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
    }

    let mut acc: [f64; 8] = [0f64; 8];
    let mut j = 0;
    for _ in 0..nrows {
        reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..]));
        j += 8;
    }
    dst[0..8].copy_from_slice(&acc);
}

#[inline(always)]
pub fn reim4_vec_mat2cols_product_ref(
    nrows: usize,
    dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(dst.len(), 16, "dst must have 16 doubles");
        assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
        assert!(
            v.len() >= nrows * 16,
            "v must be at least nrows * 16 doubles"
        );
    }

    // zero accumulators
    let mut acc_0: [f64; 8] = [0f64; 8];
    let mut acc_1: [f64; 8] = [0f64; 8];
    for i in 0..nrows {
        let _1j: usize = i << 3;
        let _2j: usize = i << 4;
        let u_j: &[f64; 8] = as_arr(&u[_1j..]);
        reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..]));
        reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..]));
    }
    dst[0..8].copy_from_slice(&acc_0);
    dst[8..16].copy_from_slice(&acc_1);
}

#[inline(always)]
pub fn reim4_vec_mat2cols_2ndcol_product_ref(
    nrows: usize,
    dst: &mut [f64], // 8 doubles: [re2(4), im2(4)]
    u: &[f64],       // nrows * 8 doubles: [ur(4) | ui(4)] per row
    v: &[f64],       // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row
) {
    #[cfg(debug_assertions)]
    {
        assert!(
            dst.len() >= 8,
            "dst must be at least 8 doubles but is {}",
            dst.len()
        );
        assert!(
            u.len() >= nrows * 8,
            "u must be at least nrows={} * 8 doubles but is {}",
            nrows,
            u.len()
        );
        assert!(
            v.len() >= nrows * 16,
            "v must be at least nrows={} * 16 doubles but is {}",
            nrows,
            v.len()
        );
    }

    // zero accumulators
    let mut acc: [f64; 8] = [0f64; 8];
    for i in 0..nrows {
        let _1j: usize = i << 3;
        let _2j: usize = i << 4;
        reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..]));
    }
    dst[0..8].copy_from_slice(&acc);
}

#[inline(always)]
pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
    for k in 0..4 {
        let ar: f64 = a[k];
        let br: f64 = b[k];
        let ai: f64 = a[k + 4];
        let bi: f64 = b[k + 4];
        dst[k] += ar * br - ai * bi;
        dst[k + 4] += ar * bi + ai * br;
    }
}
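reim4 blocks pack four complex lanes as [re(4) | im(4)], so a row-times-row multiply-accumulate is eight fused lane operations; the vector-matrix products above are just this kernel iterated over rows. A standalone check of reim4_add_mul on one lane (lane 0 computes (1+2i)(3+4i) = -5+10i):

fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
    for k in 0..4 {
        let (ar, ai) = (a[k], a[k + 4]);
        let (br, bi) = (b[k], b[k + 4]);
        dst[k] += ar * br - ai * bi;     // real part
        dst[k + 4] += ar * bi + ai * br; // imaginary part
    }
}

fn main() {
    let a = [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0]; // lane 0 = 1 + 2i
    let b = [3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0]; // lane 0 = 3 + 4i
    let mut acc = [0.0f64; 8];
    reim4_add_mul(&mut acc, &a, &b);
    assert_eq!((acc[0], acc[4]), (-5.0, 10.0));
}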
27  poulpy-hal/src/reference/fft64/reim4/mod.rs  (new file)
@@ -0,0 +1,27 @@
mod arithmetic_ref;

pub use arithmetic_ref::*;

pub trait Reim4Extract1Blk {
    fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save1Blk {
    fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Save2Blks {
    fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}

pub trait Reim4Mat1ColProd {
    fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2ColsProd {
    fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}

pub trait Reim4Mat2Cols2ndColProd {
    fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
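These traits let each backend (reference here, AVX elsewhere in this PR) expose the same reim4 entry points, with a backend simply forwarding to its kernel. A hypothetical wiring sketch; RefBackend is illustrative and not a type from this diff, and the import path assumes the crate-root re-exports:

use poulpy_hal::reference::fft64::reim4::{reim4_vec_mat1col_product_ref, Reim4Mat1ColProd};

struct RefBackend;

impl Reim4Mat1ColProd for RefBackend {
    fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]) {
        // Delegate straight to the reference kernel defined above.
        reim4_vec_mat1col_product_ref(nrows, dst, u, v);
    }
}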
119  poulpy-hal/src/reference/fft64/svp.rs  (new file)
@@ -0,0 +1,119 @@
use crate::{
    layouts::{
        Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut,
        VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
    },
    reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero},
};

pub fn svp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx,
    R: SvpPPolToMut<BE>,
    A: ScalarZnxToRef,
{
    let mut res: SvpPPol<&mut [u8], BE> = res.to_mut();
    let a: ScalarZnx<&[u8]> = a.to_ref();
    BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0));
    BE::reim_dft_execute(table, res.at_mut(res_col, 0));
}

pub fn svp_apply_dft<R, A, B, BE>(
    table: &ReimFFTTable<f64>,
    res: &mut R,
    res_col: usize,
    a: &A,
    a_col: usize,
    b: &B,
    b_col: usize,
) where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimZero + ReimFromZnx + ReimMulInplace,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
    B: VecZnxToRef,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();
    let b: VecZnx<&[u8]> = b.to_ref();

    let res_size: usize = res.size();
    let b_size: usize = b.size();
    let min_size: usize = res_size.min(b_size);

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..min_size {
        let out: &mut [f64] = res.at_mut(res_col, j);
        BE::reim_from_znx(out, b.at(b_col, j));
        BE::reim_dft_execute(table, out);
        BE::reim_mul_inplace(out, ppol);
    }

    for j in min_size..res_size {
        BE::reim_zero(res.at_mut(res_col, j));
    }
}

pub fn svp_apply_dft_to_dft<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimMul + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let res_size: usize = res.size();
    let b_size: usize = b.size();
    let min_size: usize = res_size.min(b_size);

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..min_size {
        BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
    }

    for j in min_size..res_size {
        BE::reim_zero(res.at_mut(res_col, j));
    }
}

pub fn svp_apply_dft_to_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAddMul + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    let res_size: usize = res.size();
    let b_size: usize = b.size();
    let min_size: usize = res_size.min(b_size);

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..min_size {
        BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
    }

    for j in min_size..res_size {
        BE::reim_zero(res.at_mut(res_col, j));
    }
}

pub fn svp_apply_dft_to_dft_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimMulInplace,
    R: VecZnxDftToMut<BE>,
    A: SvpPPolToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: SvpPPol<&[u8], BE> = a.to_ref();

    let ppol: &[f64] = a.at(a_col, 0);
    for j in 0..res.size() {
        BE::reim_mul_inplace(res.at_mut(res_col, j), ppol);
    }
}
521
poulpy-hal/src/reference/fft64/vec_znx_big.rs
Normal file
521
poulpy-hal/src/reference/fft64/vec_znx_big.rs
Normal file
@@ -0,0 +1,521 @@
|
||||
use std::f64::consts::SQRT_2;
|
||||
|
||||
use crate::{
|
||||
api::VecZnxBigAddNormal,
|
||||
layouts::{
|
||||
Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut,
|
||||
},
|
||||
oep::VecZnxBigAllocBytesImpl,
|
||||
reference::{
|
||||
vec_znx::{
|
||||
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
|
||||
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
|
||||
},
|
||||
znx::{
|
||||
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
|
||||
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
|
||||
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
|
||||
},
|
||||
},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn vec_znx_big_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
|
||||
R: VecZnxBigToMut<BE>,
|
||||
A: VecZnxBigToRef<BE>,
|
||||
B: VecZnxBigToRef<BE>,
|
||||
{
|
||||
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
|
||||
let a: VecZnxBig<&[u8], BE> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], BE> = b.to_ref();
|
||||
|
||||
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
|
||||
data: res.data,
|
||||
n: res.n,
|
||||
cols: res.cols,
|
||||
size: res.size,
|
||||
max_size: res.max_size,
|
||||
};
|
||||
|
||||
let a_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: a.data,
|
||||
n: a.n,
|
||||
cols: a.cols,
|
||||
size: a.size,
|
||||
max_size: a.max_size,
|
||||
};
|
||||
|
||||
let b_vznx: VecZnx<&[u8]> = VecZnx {
|
||||
data: b.data,
|
||||
n: b.n,
|
||||
cols: b.cols,
|
||||
size: b.size,
|
||||
max_size: b.max_size,
|
||||
};
|
||||
|
||||
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
|
||||
}
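
// Note (added): the `vec_znx_big_*` wrappers in this file all follow the same
// pattern: a `VecZnxBig` with `ScalarBig = i64` shares its memory layout with a
// plain `VecZnx`, so each wrapper re-views its arguments as `VecZnx` and
// forwards to the generic `reference::vec_znx` routine. A minimal sketch of the
// re-view, assuming the field layout used above:
//
//     let view: VecZnx<&[u8]> = VecZnx {
//         data: big.data, n: big.n, cols: big.cols, size: big.size, max_size: big.max_size,
//     };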

pub fn vec_znx_big_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAddInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_add_small<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}

pub fn vec_znx_big_add_small_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAddInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_big_automorphism<R, A, BE>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();
    let a: VecZnxBig<&[u8], _> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_automorphism_inplace<R, BE>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxCopy,
    R: VecZnxBigToMut<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp);
}

pub fn vec_znx_big_negate<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxNegate + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();
    let a: VecZnxBig<&[u8], _> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

pub fn vec_znx_big_negate_inplace<R, BE>(res: &mut R, res_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
{
    let res: VecZnxBig<&mut [u8], _> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col);
}

pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}
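
// Example (added): the caller owns the carry buffer consumed by
// `vec_znx_big_normalize`; its length in i64 words follows from the byte count
// returned above:
//
//     let carry_len: usize = vec_znx_big_normalize_tmp_bytes(n) / size_of::<i64>(); // == n
//     let mut carry: Vec<i64> = vec![0i64; carry_len];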

pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxBigToRef<BE>,
    BE: Backend<ScalarBig = i64>
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxZero,
{
    let a: VecZnxBig<&[u8], _> = a.to_ref();
    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
}

pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    sigma: f64,
    bound: f64,
    source: &mut Source,
) where
    R: VecZnxBigToMut<B>,
{
    let mut res: VecZnxBig<&mut [u8], B> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_add_normal_f64_ref(
        res.at_mut(res_col, limb),
        sigma * scale,
        bound * scale,
        source,
    )
}
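
// Worked example (added): with basek = 17 and k = 34, the noise goes to
// limb = ceil(34/17) - 1 = 1, and scale = 2^((1+1)*17 - 34) = 2^0 = 1, i.e. the
// Gaussian lands exactly at precision k. With basek = 17 and k = 30, limb = 1
// and scale = 2^(34 - 30) = 16, compensating for the 4 unused bits of the limb.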

pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
where
    Module<B>: VecZnxBigAddNormal<B>,
    B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
    let n: usize = module.n();
    let basek: usize = 17;
    let k: usize = 2 * 17;
    let size: usize = 5;
    let sigma: f64 = 3.2;
    let bound: f64 = 6.0 * sigma;
    let mut source: Source = Source::new([0u8; 32]);
    let cols: usize = 2;
    let zero: Vec<i64> = vec![0; n];
    let k_f64: f64 = (1u64 << k as u64) as f64;
    let sqrt2: f64 = SQRT_2;
    (0..cols).for_each(|col_i| {
        let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
        module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
        (0..cols).for_each(|col_j| {
            if col_j != col_i {
                (0..size).for_each(|limb_i| {
                    assert_eq!(a.at(col_j, limb_i), zero);
                })
            } else {
                // Noise was added twice, so the expected standard deviation is sigma * sqrt(2).
                let std: f64 = a.std(basek, col_i) * k_f64;
                assert!(
                    (std - sigma * sqrt2).abs() < 0.1,
                    "std={} ~!= {}",
                    std,
                    sigma * sqrt2
                );
            }
        })
    });
}

/// R <- A - B
pub fn vec_znx_big_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();
    let b: VecZnxBig<&[u8], BE> = b.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    let b_vznx: VecZnx<&[u8]> = VecZnx {
        data: b.data,
        n: b.n,
        cols: b.cols,
        size: b.size,
        max_size: b.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}

/// R <- R - A
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

/// R <- A - R
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}

/// R <- A - B
pub fn vec_znx_big_sub_small_a<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
    B: VecZnxBigToRef<BE>,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let b: VecZnxBig<&[u8], BE> = b.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let b_vznx: VecZnx<&[u8]> = VecZnx {
        data: b.data,
        n: b.n,
        cols: b.cols,
        size: b.size,
        max_size: b.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col);
}

/// R <- A - B
pub fn vec_znx_big_sub_small_b<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
    R: VecZnxBigToMut<BE>,
    A: VecZnxBigToRef<BE>,
    B: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxBig<&[u8], BE> = a.to_ref();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    let a_vznx: VecZnx<&[u8]> = VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    };

    vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}

/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}

/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
    R: VecZnxBigToMut<BE>,
    A: VecZnxToRef,
{
    let res: VecZnxBig<&mut [u8], BE> = res.to_mut();

    let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
        data: res.data,
        n: res.n,
        cols: res.cols,
        size: res.size,
        max_size: res.max_size,
    };

    vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
369
poulpy-hal/src/reference/fft64/vec_znx_dft.rs
Normal file
@@ -0,0 +1,369 @@
use bytemuck::cast_slice_mut;

use crate::{
    layouts::{
        Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
        ZnxView, ZnxViewMut,
    },
    reference::{
        fft64::reim::{
            ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
            ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
        },
        znx::ZnxZero,
    },
};

pub fn vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAdd + ReimCopy + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    }
}
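
// Worked example (added): with res.size() = 4, a.size() = 2, b.size() = 3, the
// `a_size <= b_size` branch gives sum_size = 2 and cpy_size = 3: limbs 0..2
// hold a + b, limb 2 is copied from b (a is implicitly zero there), and limb 3
// is zeroed. The symmetric branch handles a_size > b_size with the roles of a
// and b swapped.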

pub fn vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimAddInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), a.n())
    }

    let steps: usize = a.size().div_ceil(step);
    let min_steps: usize = res.size().min(steps);

    (0..min_steps).for_each(|j| {
        let limb: usize = offset + j * step;
        if limb < a.size() {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb));
        }
    });
    (min_steps..res.size()).for_each(|j| {
        BE::reim_zero(res.at_mut(res_col, j));
    })
}
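
// Worked example (added): `vec_znx_dft_copy` gathers every `step`-th limb of
// `a` starting at `offset`. With a.size() = 6, step = 2, offset = 1, it copies
// limbs 1, 3, 5 of `a` into limbs 0, 1, 2 of `res` and zeroes any remaining
// limbs of `res`.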

pub fn vec_znx_dft_apply<R, A, BE>(
    table: &ReimFFTTable<f64>,
    step: usize,
    offset: usize,
    res: &mut R,
    res_col: usize,
    a: &A,
    a_col: usize,
) where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + ReimZero,
    R: VecZnxDftToMut<BE>,
    A: VecZnxToRef,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(step > 0);
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let a_size: usize = a.size();
    let res_size: usize = res.size();

    let steps: usize = a_size.div_ceil(step);
    let min_steps: usize = res_size.min(steps);

    for j in 0..min_steps {
        let limb = offset + j * step;
        if limb < a_size {
            BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb));
            BE::reim_dft_execute(table, res.at_mut(res_col, j));
        }
    }

    (min_steps..res.size()).for_each(|j| {
        BE::reim_zero(res.at_mut(res_col, j));
    });
}

pub fn vec_znx_idft_apply<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64>
        + ReimDFTExecute<ReimIFFTTable<f64>, f64>
        + ReimCopy
        + ReimToZnxInplace
        + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let min_size: usize = res_size.min(a.size());

    let divisor: f64 = table.m() as f64;

    for j in 0..min_size {
        let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j));
        BE::reim_copy(res_slice_f64, a.at(a_col, j));
        BE::reim_dft_execute(table, res_slice_f64);
        BE::reim_to_znx_inplace(res_slice_f64, divisor);
    }

    for j in min_size..res_size {
        BE::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_idft_apply_tmpa<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnx + ZnxZero,
    R: VecZnxBigToMut<BE>,
    A: VecZnxDftToMut<BE>,
{
    let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
    let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(table.m() << 1, res.n());
        assert_eq!(a.n(), res.n());
    }

    let res_size = res.size();
    let min_size: usize = res_size.min(a.size());

    let divisor: f64 = table.m() as f64;

    for j in 0..min_size {
        BE::reim_dft_execute(table, a.at_mut(a_col, j));
        BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j));
    }

    for j in min_size..res_size {
        BE::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_idft_apply_consume<D: Data, BE>(table: &ReimIFFTTable<f64>, mut res: VecZnxDft<D, BE>) -> VecZnxBig<D, BE>
where
    BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnxInplace,
    VecZnxDft<D, BE>: VecZnxDftToMut<BE>,
{
    {
        let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

        #[cfg(debug_assertions)]
        {
            assert_eq!(table.m() << 1, res.n());
        }

        let divisor: f64 = table.m() as f64;

        for i in 0..res.cols() {
            for j in 0..res.size() {
                BE::reim_dft_execute(table, res.at_mut(i, j));
                BE::reim_to_znx_inplace(res.at_mut(i, j), divisor);
            }
        }
    }

    res.into_big()
}

pub fn vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSub + ReimNegate + ReimZero + ReimCopy,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    B: VecZnxDftToRef<BE>,
{
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let b: VecZnxDft<&[u8], BE> = b.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            BE::reim_zero(res.at_mut(res_col, j));
        }
    }
}

pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
{
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in sum_size..res_size {
        BE::reim_negate_inplace(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_dft_zero<R, BE>(res: &mut R)
where
    R: VecZnxDftToMut<BE>,
    BE: Backend<ScalarPrep = f64> + ReimZero,
{
    BE::reim_zero(res.to_mut().raw_mut());
}

365
poulpy-hal/src/reference/fft64/vmp.rs
Normal file
@@ -0,0 +1,365 @@
use crate::{
    cast_mut,
    layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
    oep::VecZnxDftAllocBytesImpl,
    reference::fft64::{
        reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
        reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
        vec_znx_dft::vec_znx_dft_apply,
    },
};

use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos};

pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
where
    BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
    R: VmpPMatToMut<BE>,
    A: MatZnxToRef,
{
    let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut();
    let a: MatZnx<&[u8]> = mat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(
            res.cols_in(),
            a.cols_in(),
            "res.cols_in: {} != a.cols_in: {}",
            res.cols_in(),
            a.cols_in()
        );
        assert_eq!(
            res.rows(),
            a.rows(),
            "res.rows: {} != a.rows: {}",
            res.rows(),
            a.rows()
        );
        assert_eq!(
            res.cols_out(),
            a.cols_out(),
            "res.cols_out: {} != a.cols_out: {}",
            res.cols_out(),
            a.cols_out()
        );
        assert_eq!(
            res.size(),
            a.size(),
            "res.size: {} != a.size: {}",
            res.size(),
            a.size()
        );
    }

    let nrows: usize = a.cols_in() * a.rows();
    let ncols: usize = a.cols_out() * a.size();
    vmp_prepare_core::<BE>(table, res.raw_mut(), a.raw(), nrows, ncols, tmp);
}

pub(crate) fn vmp_prepare_core<REIM>(
    table: &ReimFFTTable<f64>,
    pmat: &mut [f64],
    mat: &[i64],
    nrows: usize,
    ncols: usize,
    tmp: &mut [f64],
) where
    REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
{
    let m: usize = table.m();
    let n: usize = m << 1;

    #[cfg(debug_assertions)]
    {
        assert!(n >= 8);
        assert_eq!(mat.len(), n * nrows * ncols);
        assert_eq!(pmat.len(), n * nrows * ncols);
        assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::<i64>())
    }

    let offset: usize = nrows * ncols * 8;

    for row_i in 0..nrows {
        for col_i in 0..ncols {
            let pos: usize = n * (row_i * ncols + col_i);

            REIM::reim_from_znx(tmp, &mat[pos..pos + n]);
            REIM::reim_dft_execute(table, tmp);

            let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) {
                &mut pmat[col_i * nrows * 8 + row_i * 8..]
            } else {
                &mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..]
            };

            for blk_i in 0..m >> 2 {
                REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
            }
        }
    }
}
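
// Layout note (added): `vmp_prepare_core` stores the matrix in blocks of 4
// complex coefficients (8 f64). Columns are interleaved in pairs: the pair
// (2c, 2c+1) of row r starts at c*(nrows*16) + r*16, with the odd column at
// +8; a trailing unpaired column is stored contiguously at col*(nrows*8) + r*8.
// Consecutive 4-coefficient blocks of one entry are then strided by
// `offset = nrows * ncols * 8` f64, which is what the `blk_i * offset`
// indexing above implements.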

pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize {
    let row_max: usize = (a_size).min(prows);
    (16 + (n + 8) * row_max * pcols_in) * size_of::<f64>()
}
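
// Example (added): for n = 2048, a_size = 4, prows = 4, pcols_in = 2 this is
// (16 + (2048 + 8) * 4 * 2) * 8 = 131_712 bytes: 16 f64 for the two-column
// accumulator plus, per input column and retained row, n f64 for the DFT of
// `a` and 8 f64 for one extracted block (matching the split performed in
// `vmp_apply_dft` below).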

pub fn vmp_apply_dft<R, A, M, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + VecZnxDftAllocBytesImpl<BE>
        + ReimDFTExecute<ReimFFTTable<f64>, f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk
        + ReimFromZnx,
    R: VecZnxDftToMut<BE>,
    A: VecZnxToRef,
    M: VmpPMatToRef<BE>,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    let n: usize = a.n();
    let cols: usize = pmat.cols_in();
    let size: usize = a.size().min(pmat.rows());

    #[cfg(debug_assertions)]
    {
        assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols));
        assert!(a.cols() <= cols);
    }

    let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size));

    let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size);

    let offset: usize = cols - a.cols();
    for j in 0..cols {
        vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j);
    }

    vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes);
}

pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize {
    let row_max: usize = (a_size).min(prows);
    (16 + 8 * row_max * pcols_in) * size_of::<f64>()
}

pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    M: VmpPMatToRef<BE>,
{
    use crate::layouts::{ZnxView, ZnxViewMut};

    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), pmat.n());
        assert_eq!(a.n(), pmat.n());
        assert_eq!(res.cols(), pmat.cols_out());
        assert_eq!(a.cols(), pmat.cols_in());
    }

    let n: usize = res.n();
    let nrows: usize = pmat.cols_in() * pmat.rows();
    let ncols: usize = pmat.cols_out() * pmat.size();

    let pmat_raw: &[f64] = pmat.raw();
    let a_raw: &[f64] = a.raw();
    let res_raw: &mut [f64] = res.raw_mut();

    vmp_apply_dft_to_dft_core::<true, BE>(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes)
}

pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64])
where
    BE: Backend<ScalarPrep = f64>
        + ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
    R: VecZnxDftToMut<BE>,
    A: VecZnxDftToRef<BE>,
    M: VmpPMatToRef<BE>,
{
    use crate::layouts::{ZnxView, ZnxViewMut};

    let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
    let a: VecZnxDft<&[u8], BE> = a.to_ref();
    let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), pmat.n());
        assert_eq!(a.n(), pmat.n());
        assert_eq!(res.cols(), pmat.cols_out());
        assert_eq!(a.cols(), pmat.cols_in());
    }

    let n: usize = res.n();
    let nrows: usize = pmat.cols_in() * pmat.rows();
    let ncols: usize = pmat.cols_out() * pmat.size();

    let pmat_raw: &[f64] = pmat.raw();
    let a_raw: &[f64] = a.raw();
    let res_raw: &mut [f64] = res.raw_mut();

    vmp_apply_dft_to_dft_core::<false, BE>(
        n,
        res_raw,
        a_raw,
        pmat_raw,
        limb_offset,
        nrows,
        ncols,
        tmp_bytes,
    )
}

#[allow(clippy::too_many_arguments)]
fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
    n: usize,
    res: &mut [f64],
    a: &[f64],
    pmat: &[f64],
    limb_offset: usize,
    nrows: usize,
    ncols: usize,
    tmp_bytes: &mut [f64],
) where
    REIM: ReimZero
        + Reim4Extract1Blk
        + Reim4Mat1ColProd
        + Reim4Mat2Cols2ndColProd
        + Reim4Mat2ColsProd
        + Reim4Save2Blks
        + Reim4Save1Blk,
{
    #[cfg(debug_assertions)]
    {
        assert!(n >= 8);
        assert!(n.is_power_of_two());
        assert_eq!(pmat.len(), n * nrows * ncols);
        assert!(res.len() & (n - 1) == 0);
        assert!(a.len() & (n - 1) == 0);
    }

    let a_size: usize = a.len() / n;
    let res_size: usize = res.len() / n;

    let m: usize = n >> 1;

    let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16);

    let row_max: usize = nrows.min(a_size);
    let col_max: usize = ncols.min(res_size);

    if limb_offset >= col_max {
        if OVERWRITE {
            REIM::reim_zero(res);
        }
        return;
    }

    for blk_i in 0..(m >> 2) {
        let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];

        REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);

        if limb_offset.is_multiple_of(2) {
            for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
                let col_offset: usize = col_pmat * (8 * nrows);
                REIM::reim4_mat2cols_prod(
                    row_max,
                    mat2cols_output,
                    extracted_blk,
                    &mat_blk_start[col_offset..],
                );
                REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
            }
        } else {
            let col_offset: usize = (limb_offset - 1) * (8 * nrows);
            REIM::reim4_mat2cols_2ndcol_prod(
                row_max,
                mat2cols_output,
                extracted_blk,
                &mat_blk_start[col_offset..],
            );

            REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);

            for (col_res, col_pmat) in (1..)
                .step_by(2)
                .zip((limb_offset + 1..col_max - 1).step_by(2))
            {
                let col_offset: usize = col_pmat * (8 * nrows);
                REIM::reim4_mat2cols_prod(
                    row_max,
                    mat2cols_output,
                    extracted_blk,
                    &mat_blk_start[col_offset..],
                );
                REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
            }
        }

        if !col_max.is_multiple_of(2) {
            let last_col: usize = col_max - 1;
            let col_offset: usize = last_col * (8 * nrows);

            if last_col >= limb_offset {
                if ncols == col_max {
                    REIM::reim4_mat1col_prod(
                        row_max,
                        mat2cols_output,
                        extracted_blk,
                        &mat_blk_start[col_offset..],
                    );
                } else {
                    REIM::reim4_mat2cols_prod(
                        row_max,
                        mat2cols_output,
                        extracted_blk,
                        &mat_blk_start[col_offset..],
                    );
                }
                REIM::reim4_save_1blk::<OVERWRITE>(
                    m,
                    blk_i,
                    &mut res[(last_col - limb_offset) * n..],
                    mat2cols_output,
                );
            }
        }
    }

    REIM::reim_zero(&mut res[col_max * n..]);
}

4
poulpy-hal/src/reference/mod.rs
Normal file
@@ -0,0 +1,4 @@
pub mod fft64;
pub mod vec_znx;
pub mod zn;
pub mod znx;

177
poulpy-hal/src/reference/vec_znx/add.rs
Normal file
@@ -0,0 +1,177 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, VecZnxAdd, VecZnxAddInplace},
    layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
    source::Source,
};

pub fn vec_znx_add<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    B: VecZnxToRef,
    ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let b: VecZnx<&[u8]> = b.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    }
}

pub fn vec_znx_add_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxAddInplace,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        ZNXARI::znx_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxAdd + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_add::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxAdd + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill a and b with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_add(&mut c, i, &a, i, &b, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_add_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxAddInplace + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill a and b with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_add_inplace(&mut b, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
57
poulpy-hal/src/reference/vec_znx/add_scalar.rs
Normal file
@@ -0,0 +1,57 @@
use crate::{
    layouts::{ScalarZnx, ScalarZnxToRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
};

pub fn vec_znx_add_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
    R: VecZnxToMut,
    A: ScalarZnxToRef,
    B: VecZnxToRef,
    ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
    let a: ScalarZnx<&[u8]> = a.to_ref();
    let b: VecZnx<&[u8]> = b.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let min_size: usize = b.size().min(res.size());

    #[cfg(debug_assertions)]
    {
        assert!(
            b_limb < min_size,
            "b_limb: {} >= min_size: {}",
            b_limb,
            min_size
        );
    }

    for j in 0..min_size {
        if j == b_limb {
            ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, 0), b.at(b_col, j));
        } else {
            ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
        }
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_add_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: ScalarZnxToRef,
    ZNXARI: ZnxAddInplace,
{
    let a: ScalarZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert!(res_limb < res.size());
    }

    ZNXARI::znx_add_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

150
poulpy-hal/src/reference/vec_znx/automorphism.rs
Normal file
@@ -0,0 +1,150 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{
        ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxAutomorphismInplace,
        VecZnxAutomorphismInplaceTmpBytes,
    },
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxAutomorphism, ZnxCopy, ZnxZero},
    source::Source,
};

pub fn vec_znx_automorphism_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_automorphism<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxAutomorphism + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        use crate::layouts::ZnxInfos;

        assert_eq!(a.n(), res.n());
    }

    let min_size: usize = res.size().min(a.size());

    for j in 0..min_size {
        ZNXARI::znx_automorphism(p, res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}
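
// Example (added): the Galois automorphism maps X -> X^p in Z[X]/(X^n + 1).
// For n = 4 and p = 3, (a0, a1, a2, a3) maps to
// a0 + a1*X^3 + a2*X^6 + a3*X^9 = a0 + a3*X - a2*X^2 + a1*X^3,
// using X^4 = -1 (so X^6 = -X^2 and X^9 = X). Limbs of `res` beyond a.size()
// are zeroed.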

pub fn vec_znx_automorphism_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxAutomorphism + ZnxCopy,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), tmp.len());
    }
    for j in 0..res.size() {
        ZNXARI::znx_automorphism(p, tmp, res.at(res_col, j));
        ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
    }
}

pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_automorphism::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxAutomorphism + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill a and res with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_automorphism(-7, &mut res, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxAutomorphismInplace<B> + ModuleNew<B> + VecZnxAutomorphismInplaceTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch = ScratchOwned::alloc(module.vec_znx_automorphism_inplace_tmp_bytes());

        // Fill res with random i64
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_automorphism_inplace(-7, &mut res, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
32
poulpy-hal/src/reference/vec_znx/copy.rs
Normal file
@@ -0,0 +1,32 @@
use crate::{
    layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxCopy, ZnxZero},
};

pub fn vec_znx_copy<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxCopy + ZnxZero,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), a.n())
    }

    let res_size = res.size();
    let a_size = a.size();

    let min_size = res_size.min(a_size);

    for j in 0..min_size {
        ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in min_size..res_size {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

49
poulpy-hal/src/reference/vec_znx/merge_rings.rs
Normal file
@@ -0,0 +1,49 @@
use crate::{
    layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
    reference::{
        vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
        znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
    },
};

pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let (n_out, n_in) = (res.n(), a[0].to_ref().n());

    #[cfg(debug_assertions)]
    {
        assert_eq!(tmp.len(), res.n());

        debug_assert!(
            n_out > n_in,
            "invalid a: output ring degree should be greater"
        );
        a[1..].iter().for_each(|ai| {
            debug_assert_eq!(
                ai.to_ref().n(),
                n_in,
                "invalid input a: all VecZnx must have the same degree"
            )
        });

        assert!(n_out.is_multiple_of(n_in));
        assert_eq!(a.len(), n_out / n_in);
    }

    a.iter().for_each(|ai| {
        vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
        vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
    });

    vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
}
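
// Note (added, hedged): this appears to implement the usual ring-merging trick.
// Each input of degree n_in is embedded into the degree-n_out ring via
// `znx_switch_ring`, and the accumulator is rotated by X^{-1} between inputs so
// that the k-th input lands on the k-th coset of the subring; the final rotation
// by a.len() compensates the a.len() intermediate X^{-1} shifts. The exact
// embedding is defined by `ZnxSwitchRing`, which is not shown in this diff.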

31
poulpy-hal/src/reference/vec_znx/mod.rs
Normal file
@@ -0,0 +1,31 @@
mod add;
mod add_scalar;
mod automorphism;
mod copy;
mod merge_rings;
mod mul_xp_minus_one;
mod negate;
mod normalize;
mod rotate;
mod sampling;
mod shift;
mod split_ring;
mod sub;
mod sub_scalar;
mod switch_ring;

pub use add::*;
pub use add_scalar::*;
pub use automorphism::*;
pub use copy::*;
pub use merge_rings::*;
pub use mul_xp_minus_one::*;
pub use negate::*;
pub use normalize::*;
pub use rotate::*;
pub use sampling::*;
pub use shift::*;
pub use split_ring::*;
pub use sub::*;
pub use sub_scalar::*;
pub use switch_ring::*;

136
poulpy-hal/src/reference/vec_znx/mul_xp_minus_one.rs
Normal file
@@ -0,0 +1,136 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{
        ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace,
        VecZnxMulXpMinusOneInplaceTmpBytes,
    },
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::{
        vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
        znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
    },
    source::Source,
};

pub fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
{
    vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
    vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
|
||||
|
||||
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), tmp.len());
|
||||
}
|
||||
for j in 0..res.size() {
|
||||
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
|
||||
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
|
||||
}
|
||||
}
|
||||
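
// Worked example (illustration only, not part of the original patch):
// `vec_znx_mul_xp_minus_one` computes (X^p - 1) * a as a negacyclic rotation
// followed by a subtraction. In Z[X]/(X^4+1) with a = 1 + 2X and p = 1:
//
//     X * a       = X + 2X^2          (rotation by p = 1)
//     (X - 1) * a = -1 - X + 2X^2     (subtract a limb-wise)
//
// i.e. coefficients [1, 2, 0, 0] -> rotate -> [0, 1, 2, 0] -> sub -> [-1, -1, 2, 0].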

pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_mul_xp_minus_one(-7, &mut res, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_mul_xp_minus_one_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxMulXpMinusOneInplace<B> + ModuleNew<B> + VecZnxMulXpMinusOneInplaceTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch = ScratchOwned::alloc(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes());

        // Fill res with random i64
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_mul_xp_minus_one_inplace(-7, &mut res, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
131
poulpy-hal/src/reference/vec_znx/negate.rs
Normal file
@@ -0,0 +1,131 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, VecZnxNegate, VecZnxNegateInplace},
    layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxNegate, ZnxNegateInplace, ZnxZero},
    source::Source,
};

pub fn vec_znx_negate<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxNegate + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let min_size: usize = res.size().min(a.size());

    for j in 0..min_size {
        ZNXARI::znx_negate(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_negate_inplace<R, ZNXARI>(res: &mut R, res_col: usize)
where
    R: VecZnxToMut,
    ZNXARI: ZnxNegateInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    for j in 0..res.size() {
        ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
    }
}

pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNegate + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_negate::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNegate + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_negate(&mut b, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_negate_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNegateInplace + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill a with random i64
        a.fill_uniform(50, &mut source);
        move || {
            for i in 0..cols {
                module.vec_znx_negate_inplace(&mut a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
193
poulpy-hal/src/reference/vec_znx/normalize.rs
Normal file
@@ -0,0 +1,193 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{
        ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
        ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
        ZnxZero,
    },
    source::Source,
};

pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= res.n());
    }

    let res_size: usize = res.size();
    let a_size = a.size();

    if a_size > res_size {
        for j in (res_size..a_size).rev() {
            if j == a_size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
            }
        }

        for j in (1..res_size).rev() {
            ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
        }

        ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
    } else {
        for j in (0..a_size).rev() {
            if j == a_size - 1 {
                ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
            } else if j == 0 {
                ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
            }
        }

        for j in a_size..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    }
}

pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= res.n());
    }

    let res_size: usize = res.size();

    for j in (0..res_size).rev() {
        if j == res_size - 1 {
            ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
        } else if j == 0 {
            ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
        } else {
            ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
        }
    }
}
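
// Worked example (illustration only): normalization rewrites each coefficient
// as balanced base-2^basek digits, propagating a carry from the least
// significant limb (largest j) upward. With basek = 2 (digits in [-2, 2)),
// a coefficient split across two limbs as [hi, lo] = [0, 3] becomes:
//
//     lo = 3            ->  digit -1, carry 1     (3 = 1 * 4 - 1)
//     hi = 0 + carry 1  ->  digit  1
//
// so [0, 3] normalizes to [1, -1], and 1 * 4 - 1 = 3 is preserved.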

pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill a with random i64
        a.fill_uniform(50, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
148
poulpy-hal/src/reference/vec_znx/rotate.rs
Normal file
@@ -0,0 +1,148 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxCopy, ZnxRotate, ZnxZero},
    source::Source,
};

pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_rotate<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxRotate + ZnxZero,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), a.n())
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let min_size: usize = res_size.min(a_size);

    for j in 0..min_size {
        ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j))
    }

    for j in min_size..res_size {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_rotate_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxRotate + ZnxCopy,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.n(), tmp.len());
    }
    for j in 0..res.size() {
        ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
        ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
    }
}
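
// Worked example (illustration only): rotation multiplies by X^p in the
// negacyclic ring Z[X]/(X^n+1), so coefficients that wrap around change sign.
// In Z[X]/(X^4+1) with a = [a0, a1, a2, a3] and p = 1:
//
//     X * (a0 + a1 X + a2 X^2 + a3 X^3) = -a3 + a0 X + a1 X^2 + a2 X^3
//
// i.e. [a0, a1, a2, a3] -> [-a3, a0, a1, a2]; p = -1 rotates the other way.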

pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRotate + ModuleNew<B>,
{
    let group_name: String = format!("vec_znx_rotate::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRotate + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rotate(-7, &mut res, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_rotate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_rotate_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRotateInplace<B> + ModuleNew<B> + VecZnxRotateInplaceTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());

        // Fill res with random i64
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
64
poulpy-hal/src/reference/vec_znx/sampling.rs
Normal file
@@ -0,0 +1,64 @@
use crate::{
    layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut},
    reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
    source::Source,
};

pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
    R: VecZnxToMut,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    for j in 0..res.size() {
        znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
    }
}

pub fn vec_znx_fill_normal_ref<R>(
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    sigma: f64,
    bound: f64,
    source: &mut Source,
) where
    R: VecZnxToMut,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_fill_normal_f64_ref(
        res.at_mut(res_col, limb),
        sigma * scale,
        bound * scale,
        source,
    )
}
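
// Worked example (illustration only) for the limb/scale arithmetic above:
// with basek = 50 and k = 120, the noise targets precision 2^-120, so
//
//     limb  = ceil(120 / 50) - 1 = 2               (third limb, 0-indexed)
//     scale = 2^((2 + 1) * 50 - 120) = 2^30
//
// i.e. sigma is scaled by 2^30 so that, once the limb is interpreted at its
// weight of 2^-150, the sampled noise has standard deviation sigma * 2^-120.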

pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
    R: VecZnxToMut,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_add_normal_f64_ref(
        res.at_mut(res_col, limb),
        sigma * scale,
        bound * scale,
        source,
    )
}
672
poulpy-hal/src/reference/vec_znx/shift.rs
Normal file
@@ -0,0 +1,672 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::{
        vec_znx::vec_znx_copy,
        znx::{
            ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
            ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
            ZnxZero,
        },
    },
    source::Source,
};

pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let n: usize = res.n();
    let cols: usize = res.cols();
    let size: usize = res.size();
    let steps: usize = k / basek;
    let k_rem: usize = k % basek;

    if steps >= size {
        for j in 0..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    // Inplace shift of limbs by k/basek
    if steps > 0 {
        let start: usize = n * res_col;
        let end: usize = start + n;
        let slice_size: usize = n * cols;
        let res_raw: &mut [i64] = res.raw_mut();

        (0..size - steps).for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
            ZNXARI::znx_copy(
                &mut lhs[start + j * slice_size..end + j * slice_size],
                &rhs[start..end],
            );
        });

        for j in size - steps..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    }

    // Inplace normalization with a left shift of k % basek
    if !k.is_multiple_of(basek) {
        for j in (0..size - steps).rev() {
            if j == size - steps - 1 {
                ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
            } else if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
            }
        }
    }
}
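
// Worked example (illustration only): a left shift by k = q * basek + r is a
// limb shift by q plus an intra-limb shift by r. With basek = 50 and k = 120:
//
//     steps = 120 / 50 = 2
//     k_rem = 120 % 50 = 20
//
// Limbs move up by two positions (the two most significant limbs overflow and
// are discarded, the two least significant limbs are zeroed), then a shifted
// normalization pass applies the remaining factor of 2^20.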

pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxCopy + ZnxNormalizeFinalStep,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    let res_size: usize = res.size();
    let a_size = a.size();
    let steps: usize = k / basek;
    let k_rem: usize = k % basek;

    if steps >= res_size.min(a_size) {
        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    let min_size: usize = a_size.min(res_size) - steps;

    // Simply a left-shifted normalization: limbs shift by k/basek and the
    // intra-limb shift of k % basek is folded into the normalization.
    if !k.is_multiple_of(basek) {
        for j in (0..min_size).rev() {
            if j == min_size - 1 {
                ZNXARI::znx_normalize_first_step(
                    basek,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            } else if j == 0 {
                ZNXARI::znx_normalize_final_step(
                    basek,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            } else {
                ZNXARI::znx_normalize_middle_step(
                    basek,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            }
        }
    } else {
        // If k % basek == 0, then this is simply a shifted copy.
        for j in (0..min_size).rev() {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
        }
    }

    // Zeroes the bottom limbs
    for j in min_size..res_size {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let n: usize = res.n();
    let cols: usize = res.cols();
    let size: usize = res.size();

    let mut steps: usize = k / basek;
    let k_rem: usize = k % basek;

    if k == 0 {
        return;
    }

    if steps >= size {
        for j in 0..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    let start: usize = n * res_col;
    let end: usize = start + n;
    let slice_size: usize = n * cols;

    if !k.is_multiple_of(basek) {
        // We rsh by an additional basek and then lsh by basek - k_rem.
        // This reuses the efficient normalization code, avoids overflows,
        // and produces an output that is normalized.
        steps += 1;

        // All limbs that would fall outside of the limbs of res are discarded,
        // but the carry still needs to be computed.
        (size - steps..size).rev().for_each(|j| {
            if j == size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
            }
        });

        // Continues with a shifted normalization
        let res_raw: &mut [i64] = res.raw_mut();
        (steps..size).rev().for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
            let rhs_slice: &mut [i64] = &mut rhs[start..end];
            let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
            ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
        });

        // Propagates the carry on the remaining limbs of res
        for j in (0..steps).rev() {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
            if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
            }
        }
    } else {
        // Shift by multiples of basek
        let res_raw: &mut [i64] = res.raw_mut();
        (steps..size).rev().for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
            ZNXARI::znx_copy(
                &mut rhs[start..end],
                &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
            );
        });

        // Zeroes the top limbs
        (0..steps).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });
    }
}
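
// Arithmetic behind the trick above (illustration only): for k = q * basek + r
// with r != 0, a right shift by k is decomposed as
//
//     x * 2^-k = (x * 2^-((q + 1) * basek)) * 2^(basek - r)
//
// i.e. shift right by q + 1 whole limbs, then left by basek - r bits. With
// basek = 50 and k = 120: q + 1 = 3 limbs down, then 2^30 back up, which lands
// on the same 2^-120 while keeping every intermediate digit in range.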

pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeFirstStep
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let mut steps: usize = k / basek;
    let k_rem: usize = k % basek;

    if k == 0 {
        vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
        return;
    }

    if steps >= res_size {
        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    if !k.is_multiple_of(basek) {
        // We rsh by an additional basek and then lsh by basek - k_rem.
        // This reuses the efficient normalization code, avoids overflows,
        // and produces an output that is normalized.
        steps += 1;

        // All limbs of a that are moved outside of the limbs of res are discarded,
        // but the carry still needs to be computed.
        for j in (res_size..a_size + steps).rev() {
            if j == a_size + steps - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
            }
        }

        // Avoids overflowing the limbs of res
        let min_size: usize = res_size.min(a_size + steps);

        // Zeroes the lower limbs of res if a_size + steps < res_size
        (min_size..res_size).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });

        // Continues with a shifted normalization
        for j in (steps..min_size).rev() {
            // Case where no limb of a was previously discarded
            if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
                ZNXARI::znx_normalize_first_step(
                    basek,
                    basek - k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j - steps),
                    carry,
                );
            } else {
                ZNXARI::znx_normalize_middle_step(
                    basek,
                    basek - k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j - steps),
                    carry,
                );
            }
        }

        // Propagates the carry on the remaining limbs of res
        for j in (0..steps).rev() {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
            if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
            }
        }
    } else {
        let min_size: usize = res_size.min(a_size + steps);

        // Zeroes the top limbs
        (0..steps).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });

        // Shifts a into res, up to the maximum
        for j in (steps..min_size).rev() {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
        }

        // Zeroes the bottom limbs if a_size + steps < res_size
        (min_size..res_size).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });
    }
}

pub fn bench_vec_znx_lsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_lsh_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxLshInplace<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_lsh<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxLsh<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_lsh::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxLsh<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_rsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_rsh_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_rsh<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRsh<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_rsh::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRsh<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

#[cfg(test)]
mod tests {
    use crate::{
        layouts::{FillUniform, VecZnx, ZnxView},
        reference::{
            vec_znx::{
                vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
                vec_znx_sub_ab_inplace,
            },
            znx::ZnxRef,
        },
        source::Source,
    };

    #[test]
    fn test_vec_znx_lsh() {
        let n: usize = 8;
        let cols: usize = 2;
        let size: usize = 7;

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut source: Source = Source::new([0u8; 32]);

        let mut carry: Vec<i64> = vec![0i64; n];

        let basek: usize = 50;

        for k in 0..256 {
            a.fill_uniform(50, &mut source);

            for i in 0..cols {
                vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
                vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
            }

            for i in 0..cols {
                vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
                vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
                vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
            }

            assert_eq!(res_ref, res_test);
        }
    }

    #[test]
    fn test_vec_znx_rsh() {
        let n: usize = 8;
        let cols: usize = 2;

        let res_size: usize = 7;

        let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
        let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);

        let mut carry: Vec<i64> = vec![0i64; n];

        let basek: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let zero: Vec<i64> = vec![0i64; n];

        for a_size in [res_size - 1, res_size, res_size + 1] {
            let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);

            for k in 0..res_size * basek {
                a.fill_uniform(50, &mut source);

                for i in 0..cols {
                    vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
                    vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
                }

                res_test.fill_uniform(50, &mut source);

                for j in 0..cols {
                    vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
                    vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
                }

                for j in 0..cols {
                    vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
                    vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
                }

                // Case where res has enough limbs to fully store a right-shifted
                // without any loss. In this case we can check exact equality.
                if a_size + k.div_ceil(basek) <= res_size {
                    assert_eq!(res_ref, res_test);

                    for i in 0..cols {
                        for j in 0..a_size {
                            assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j);
                            assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j);
                        }

                        for j in a_size..res_size {
                            assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j);
                            assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j);
                        }
                    }
                // Some loss occurs, either because a initially has more precision
                // than res, or because storing the right shift of a requires more
                // precision than res offers.
                } else {
                    for j in 0..cols {
                        vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
                        vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);

                        vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
                        vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);

                        assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
                        assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
                    }
                }
            }
        }
    }
}
62
poulpy-hal/src/reference/vec_znx/split_ring.rs
Normal file
@@ -0,0 +1,62 @@
use crate::{
    layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero},
};

pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn vec_znx_split_ring<R, A, ZNXARI>(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let a_size = a.size();

    let (n_in, n_out) = (a.n(), res[0].to_mut().n());

    #[cfg(debug_assertions)]
    {
        assert_eq!(tmp.len(), a.n());

        assert!(
            n_out < n_in,
            "invalid a: output ring degree should be smaller"
        );

        res[1..].iter_mut().for_each(|bi| {
            assert_eq!(
                bi.to_mut().n(),
                n_out,
                "invalid res: all VecZnx must have the same degree"
            )
        });

        assert!(n_in.is_multiple_of(n_out));
        assert_eq!(res.len(), n_in / n_out);
    }

    res.iter_mut().enumerate().for_each(|(i, bi)| {
        let mut bi: VecZnx<&mut [u8]> = bi.to_mut();

        let min_size = bi.size().min(a_size);

        if i == 0 {
            for j in 0..min_size {
                ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j));
            }
        } else {
            for j in 0..min_size {
                ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j));
                ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp);
            }
        }

        for j in min_size..bi.size() {
            ZNXARI::znx_zero(bi.at_mut(res_col, j));
        }
    })
}
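
// Illustration (not part of the original patch): the i-th output is obtained
// by first rotating a by -i (dividing by X^i) and then folding it into the
// smaller ring, so output i plausibly collects the coefficients of a whose
// index is congruent to i modulo n_in / n_out. For n_in = 8 and n_out = 4,
// the two outputs would see the even-indexed and odd-indexed coefficients of
// a, respectively (up to the negacyclic sign handling inside znx_rotate and
// znx_switch_ring).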
250
poulpy-hal/src/reference/vec_znx/sub.rs
Normal file
@@ -0,0 +1,250 @@
use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
    layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
    reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
    source::Source,
};

pub fn vec_znx_sub<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    B: VecZnxToRef,
    ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let b: VecZnx<&[u8]> = b.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
        assert_eq!(b.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let b_size: usize = b.size();

    if a_size <= b_size {
        let sum_size: usize = a_size.min(res_size);
        let cpy_size: usize = b_size.min(res_size);

        for j in 0..sum_size {
            ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j));
        }

        for j in cpy_size..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    } else {
        let sum_size: usize = b_size.min(res_size);
        let cpy_size: usize = a_size.min(res_size);

        for j in 0..sum_size {
            ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
        }

        for j in sum_size..cpy_size {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
        }

        for j in cpy_size..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    }
}

pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxSubABInplace,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }
}

pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(a.n(), res.n());
    }

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let sum_size: usize = a_size.min(res_size);

    for j in 0..sum_size {
        ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in sum_size..res_size {
        ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
    }
}
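
// Limb-count handling in `vec_znx_sub`, summarized (illustration only): with
// res_size = 4, a_size = 2, b_size = 3, the three loops of the a_size <= b_size
// branch produce
//
//     limb 0..2: res[j] = a[j] - b[j]     (both operands present)
//     limb 2..3: res[j] = -b[j]           (a exhausted, negate b)
//     limb 3..4: res[j] = 0               (both exhausted)
//
// The symmetric branch copies a[j] instead of negating where b is exhausted.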

pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
    B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
    let group_name: String = format!("vec_znx_sub::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxSub + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_sub(&mut c, i, &a, i, &b, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
where
    B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
{
    let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxSubABInplace + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
where
    B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
{
    let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fill the inputs with random i64
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
58
poulpy-hal/src/reference/vec_znx/sub_scalar.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
|
||||
use crate::{
|
||||
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
|
||||
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
|
||||
};
|
||||
|
||||
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
ZNXARI: ZnxSub + ZnxZero,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let min_size: usize = b.size().min(res.size());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
b_limb < min_size,
|
||||
"b_limb: {} > min_size: {}",
|
||||
b_limb,
|
||||
min_size
|
||||
);
|
||||
}
|
||||
|
||||
for j in 0..min_size {
|
||||
if j == b_limb {
|
||||
ZNXARI::znx_sub(res.at_mut(res_col, j), b.at(b_col, j), a.at(a_col, 0));
|
||||
} else {
|
||||
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
for j in min_size..res.size() {
|
||||
ZNXARI::znx_zero(res.at_mut(res_col, j));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
ZNXARI: ZnxSubABInplace,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(res_limb < res.size());
|
||||
}
|
||||
|
||||
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
|
||||
}
|
||||
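A usage sketch for orientation (allocation calls follow the layouts used elsewhere in this commit; `ScalarZnx::alloc(n, cols)` is assumed to mirror `VecZnx::alloc`, and the sizes are illustrative only):

// res[res_col] = b[b_col] limb-wise, except at limb b_limb where the
// scalar polynomial is subtracted; limbs of res past b.size() are zeroed.
let a: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(64, 1); // assumed constructor
let b: VecZnx<Vec<u8>> = VecZnx::alloc(64, 1, 4);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(64, 1, 4);
vec_znx_sub_scalar::<_, _, _, ZnxRef>(&mut res, 0, &a, 0, &b, 0, 0);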
37
poulpy-hal/src/reference/vec_znx/switch_ring.rs
Normal file
@@ -0,0 +1,37 @@
use crate::{
    layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::{
        vec_znx::vec_znx_copy,
        znx::{ZnxCopy, ZnxSwitchRing, ZnxZero},
    },
};

/// Maps between negacyclic rings by changing the polynomial degree.
/// Up: Z[X]/(X^N+1) -> Z[X]/(X^{2^d N}+1) via X ↦ X^{2^d}
/// Down: Z[X]/(X^N+1) -> Z[X]/(X^{N/2^d}+1) by folding indices.
pub fn vec_znx_switch_ring<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxZero,
{
    let a: VecZnx<&[u8]> = a.to_ref();
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let (n_in, n_out) = (a.n(), res.n());

    if n_in == n_out {
        vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
        return;
    }

    let min_size: usize = a.size().min(res.size());

    for j in 0..min_size {
        ZNXARI::znx_switch_ring(res.at_mut(res_col, j), a.at(a_col, j));
    }

    for j in min_size..res.size() {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}
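A small worked example of the per-limb coefficient mapping performed by znx_switch_ring (values illustrative):

// Up-switch, n_in = 4 -> n_out = 8 (X -> X^2): gaps are (1, 2), so
//   [a0, a1, a2, a3]  ->  [a0, 0, a1, 0, a2, 0, a3, 0]
// Down-switch, n_in = 8 -> n_out = 4: gaps are (2, 1), keeping every
// second coefficient:
//   [a0, a1, ..., a7]  ->  [a0, a2, a4, a6]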
5
poulpy-hal/src/reference/zn/mod.rs
Normal file
@@ -0,0 +1,5 @@
mod normalization;
mod sampling;

pub use normalization::*;
pub use sampling::*;
72
poulpy-hal/src/reference/zn/normalization.rs
Normal file
@@ -0,0 +1,72 @@
use crate::{
    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
    layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef},
    source::Source,
};

pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    R: ZnToMut,
    ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
{
    let mut res: Zn<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert_eq!(carry.len(), res.n());
    }

    let res_size: usize = res.size();

    // Limbs are processed from least significant (j = res_size - 1) to most
    // significant (j = 0), threading the carry slice between steps.
    for j in (0..res_size).rev() {
        let out = &mut res.at_mut(res_col, j)[..n];

        if j == res_size - 1 {
            ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
        } else if j == 0 {
            ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
        } else {
            ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
        }
    }
}

pub fn test_zn_normalize_inplace<B: Backend>(module: &Module<B>)
where
    Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let mut source: Source = Source::new([0u8; 32]);
    let cols: usize = 2;
    let basek: usize = 12;

    let n = 33;

    // zn_normalize_tmp_bytes returns a size in bytes; the carry buffer
    // needs n i64 entries (zn_normalize_inplace asserts this in debug).
    let mut carry: Vec<i64> = vec![0i64; n];

    let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n()));

    for res_size in [1, 2, 6, 11] {
        let mut res_0: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
        let mut res_1: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);

        res_0
            .raw_mut()
            .iter_mut()
            .for_each(|x| *x = source.next_i32() as i64);
        res_1.raw_mut().copy_from_slice(res_0.raw());

        // Reference implementation vs. backend implementation
        for i in 0..cols {
            zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
            module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
        }

        assert_eq!(res_0.raw(), res_1.raw());
    }
}
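A short numeric trace of the limb pass, assuming basek = 4 (balanced digits in [-8, 8)) and two limbs [x_hi, x_lo] = [2, 13]:

// first step (j = 1, least significant): 13 -> digit -3, carry 1
// final step (j = 0): x_hi = get_digit(4, 2 + 1) = 3
// Result [3, -3] represents the same value: 3*16 - 3 = 45 = 2*16 + 13.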
75
poulpy-hal/src/reference/zn/sampling.rs
Normal file
@@ -0,0 +1,75 @@
use crate::{
    layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
    reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
    source::Source,
};

pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
    R: ZnToMut,
{
    let mut res: Zn<&mut [u8]> = res.to_mut();
    for j in 0..res.size() {
        znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
    }
}

#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
    n: usize,
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    source: &mut Source,
    sigma: f64,
    bound: f64,
) where
    R: ZnToMut,
{
    let mut res: Zn<&mut [u8]> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    // The noise targets precision 2^-k: it lands in the limb covering bit k,
    // scaled so that it sits at the correct position within that limb.
    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_fill_normal_f64_ref(
        &mut res.at_mut(res_col, limb)[..n],
        sigma * scale,
        bound * scale,
        source,
    )
}

#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
    n: usize,
    basek: usize,
    res: &mut R,
    res_col: usize,
    k: usize,
    source: &mut Source,
    sigma: f64,
    bound: f64,
) where
    R: ZnToMut,
{
    let mut res: Zn<&mut [u8]> = res.to_mut();
    assert!(
        (bound.log2().ceil() as i64) < 64,
        "invalid bound: ceil(log2(bound))={} > 63",
        (bound.log2().ceil() as i64)
    );

    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
    znx_add_normal_f64_ref(
        &mut res.at_mut(res_col, limb)[..n],
        sigma * scale,
        bound * scale,
        source,
    )
}
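A numeric example of the limb/scale computation, with basek = 12 and k = 20: the noise targets precision 2^-20, so it lands in limb ceil(20/12) - 1 = 1, whose least significant bit sits at 2^-24; sigma and bound are therefore scaled by 2^((limb+1)*basek - k) = 2^4 = 16 before sampling.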
25
poulpy-hal/src/reference/znx/add.rs
Normal file
@@ -0,0 +1,25 @@
#[inline(always)]
pub fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
        assert_eq!(res.len(), b.len());
    }

    let n: usize = res.len();
    for i in 0..n {
        res[i] = a[i] + b[i];
    }
}

pub fn znx_add_inplace_ref(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    let n: usize = res.len();
    for i in 0..n {
        res[i] += a[i];
    }
}
153
poulpy-hal/src/reference/znx/arithmetic_ref.rs
Normal file
@@ -0,0 +1,153 @@
use crate::reference::znx::{
    ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
    ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
    ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
    ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
    add::{znx_add_inplace_ref, znx_add_ref},
    automorphism::znx_automorphism_ref,
    copy::znx_copy_ref,
    neg::{znx_negate_inplace_ref, znx_negate_ref},
    normalization::{
        znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref,
        znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
        znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
    },
    sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
    switch_ring::znx_switch_ring_ref,
    zero::znx_zero_ref,
};

/// Zero-sized dispatch type: implements every scalar `Znx*` trait by
/// forwarding to the reference kernels.
pub struct ZnxRef {}

impl ZnxAdd for ZnxRef {
    #[inline(always)]
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
        znx_add_ref(res, a, b);
    }
}

impl ZnxAddInplace for ZnxRef {
    #[inline(always)]
    fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
        znx_add_inplace_ref(res, a);
    }
}

impl ZnxSub for ZnxRef {
    #[inline(always)]
    fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
        znx_sub_ref(res, a, b);
    }
}

impl ZnxSubABInplace for ZnxRef {
    #[inline(always)]
    fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
        znx_sub_ab_inplace_ref(res, a);
    }
}

impl ZnxSubBAInplace for ZnxRef {
    #[inline(always)]
    fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
        znx_sub_ba_inplace_ref(res, a);
    }
}

impl ZnxAutomorphism for ZnxRef {
    #[inline(always)]
    fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
        znx_automorphism_ref(p, res, a);
    }
}

impl ZnxCopy for ZnxRef {
    #[inline(always)]
    fn znx_copy(res: &mut [i64], a: &[i64]) {
        znx_copy_ref(res, a);
    }
}

impl ZnxNegate for ZnxRef {
    #[inline(always)]
    fn znx_negate(res: &mut [i64], src: &[i64]) {
        znx_negate_ref(res, src);
    }
}

impl ZnxNegateInplace for ZnxRef {
    #[inline(always)]
    fn znx_negate_inplace(res: &mut [i64]) {
        znx_negate_inplace_ref(res);
    }
}

impl ZnxZero for ZnxRef {
    #[inline(always)]
    fn znx_zero(res: &mut [i64]) {
        znx_zero_ref(res);
    }
}

impl ZnxSwitchRing for ZnxRef {
    #[inline(always)]
    fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
        znx_switch_ring_ref(res, a);
    }
}

impl ZnxNormalizeFinalStep for ZnxRef {
    #[inline(always)]
    fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_final_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeFinalStepInplace for ZnxRef {
    #[inline(always)]
    fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeFirstStep for ZnxRef {
    #[inline(always)]
    fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_first_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
    #[inline(always)]
    fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
        znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeFirstStepInplace for ZnxRef {
    #[inline(always)]
    fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeMiddleStep for ZnxRef {
    #[inline(always)]
    fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
        znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
    }
}

impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
    #[inline(always)]
    fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
        znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
    }
}

impl ZnxNormalizeMiddleStepInplace for ZnxRef {
    #[inline(always)]
    fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
        znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
    }
}
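ZnxRef is never constructed; generic kernels name it purely as a type parameter. For instance, the rotation kernel introduced later in this commit can be instantiated with the reference arithmetic as (values illustrative):

let src: Vec<i64> = vec![1, 2, 3, 4];
let mut res: Vec<i64> = vec![0; 4];
znx_rotate::<ZnxRef>(1, &mut res, &src); // res = [-4, 1, 2, 3]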
21
poulpy-hal/src/reference/znx/automorphism.rs
Normal file
@@ -0,0 +1,21 @@
/// Applies the Galois automorphism a(X) -> a(X^p) in Z[X]/(X^n+1).
pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    let n: usize = res.len();
    let mut k: usize = 0usize;
    let mask: usize = 2 * n - 1;
    let p_2n = (p & mask as i64) as usize;

    res[0] = a[0];
    for ai in a.iter().take(n).skip(1) {
        // k = i * p (mod 2n); X^k = -X^(k - n) when k >= n.
        k = (k + p_2n) & mask;
        if k < n {
            res[k] = *ai
        } else {
            res[k - n] = -*ai
        }
    }
}
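A worked example with n = 4 and p = 3, computing a(X) -> a(X^3) in Z[X]/(X^4+1):

// Powers map as X -> X^3, X^2 -> X^6 = -X^2, X^3 -> X^9 = X, so
//   [a0, a1, a2, a3]  ->  [a0, a3, -a2, a1]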
8
poulpy-hal/src/reference/znx/copy.rs
Normal file
@@ -0,0 +1,8 @@
#[inline(always)]
pub fn znx_copy_ref(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len())
    }
    res.copy_from_slice(a);
}
104
poulpy-hal/src/reference/znx/mod.rs
Normal file
@@ -0,0 +1,104 @@
mod add;
mod arithmetic_ref;
mod automorphism;
mod copy;
mod neg;
mod normalization;
mod rotate;
mod sampling;
mod sub;
mod switch_ring;
mod zero;

pub use add::*;
pub use arithmetic_ref::*;
pub use automorphism::*;
pub use copy::*;
pub use neg::*;
pub use normalization::*;
pub use rotate::*;
pub use sampling::*;
pub use sub::*;
pub use switch_ring::*;
pub use zero::*;

pub trait ZnxAdd {
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]);
}

pub trait ZnxAddInplace {
    fn znx_add_inplace(res: &mut [i64], a: &[i64]);
}

pub trait ZnxSub {
    fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
}

pub trait ZnxSubABInplace {
    fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
}

pub trait ZnxSubBAInplace {
    fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
}

pub trait ZnxAutomorphism {
    fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]);
}

pub trait ZnxCopy {
    fn znx_copy(res: &mut [i64], a: &[i64]);
}

pub trait ZnxNegate {
    fn znx_negate(res: &mut [i64], src: &[i64]);
}

pub trait ZnxNegateInplace {
    fn znx_negate_inplace(res: &mut [i64]);
}

pub trait ZnxRotate {
    fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]);
}

pub trait ZnxZero {
    fn znx_zero(res: &mut [i64]);
}

pub trait ZnxSwitchRing {
    fn znx_switch_ring(res: &mut [i64], a: &[i64]);
}

pub trait ZnxNormalizeFirstStepCarryOnly {
    fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeFirstStepInplace {
    fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeFirstStep {
    fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeMiddleStepCarryOnly {
    fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeMiddleStepInplace {
    fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeMiddleStep {
    fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeFinalStepInplace {
    fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}

pub trait ZnxNormalizeFinalStep {
    fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
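Splitting each operation into its own trait lets a backend be assembled piecemeal: an accelerated backend only has to provide the kernels it actually specializes. As a sketch (ZnxAvx is a hypothetical type invented here, and its body just falls back to the reference kernel):

pub struct ZnxAvx;

impl ZnxAdd for ZnxAvx {
    #[inline(always)]
    fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
        // A real backend would dispatch to a vectorized kernel here.
        znx_add_ref(res, a, b);
    }
}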
18
poulpy-hal/src/reference/znx/neg.rs
Normal file
@@ -0,0 +1,18 @@
#[inline(always)]
pub fn znx_negate_ref(res: &mut [i64], src: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), src.len())
    }

    for i in 0..res.len() {
        res[i] = -src[i]
    }
}

#[inline(always)]
pub fn znx_negate_inplace_ref(res: &mut [i64]) {
    for value in res {
        *value = -*value
    }
}
199
poulpy-hal/src/reference/znx/normalization.rs
Normal file
@@ -0,0 +1,199 @@
use itertools::izip;

/// Extracts the balanced base-2^basek digit of x, in [-2^(basek-1), 2^(basek-1)).
#[inline(always)]
pub fn get_digit(basek: usize, x: i64) -> i64 {
    (x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
}

/// Carry left over after removing `digit` from x: (x - digit) / 2^basek.
#[inline(always)]
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
    (x.wrapping_sub(digit)) >> basek
}
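A numeric example with basek = 4, where digits are balanced in [-8, 8):

// x = 13: get_digit(4, 13) = -3   (13 = 1*16 - 3)
//         get_carry(4, 13, -3) = (13 + 3) >> 4 = 1
// so 13 splits into digit -3 plus carry 1 for the next limb.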
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }

    if lsh == 0 {
        x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
            *c = get_carry(basek, *x, get_digit(basek, *x));
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
            *c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
        });
    }
}

#[inline(always)]
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }

    if lsh == 0 {
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek, *x);
            *c = get_carry(basek, *x, digit);
            *x = digit;
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek_lsh, *x);
            *c = get_carry(basek_lsh, *x, digit);
            *x = digit << lsh;
        });
    }
}

#[inline(always)]
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }

    if lsh == 0 {
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            let digit: i64 = get_digit(basek, *a);
            *c = get_carry(basek, *a, digit);
            *x = digit;
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            let digit: i64 = get_digit(basek_lsh, *a);
            *c = get_carry(basek_lsh, *a, digit);
            *x = digit << lsh;
        });
    }
}

#[inline(always)]
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }
    if lsh == 0 {
        x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek, *x);
            let carry: i64 = get_carry(basek, *x, digit);
            let digit_plus_c: i64 = digit + *c;
            *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek_lsh, *x);
            let carry: i64 = get_carry(basek_lsh, *x, digit);
            let digit_plus_c: i64 = (digit << lsh) + *c;
            *c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
        });
    }
}

#[inline(always)]
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }
    if lsh == 0 {
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek, *x);
            let carry: i64 = get_carry(basek, *x, digit);
            let digit_plus_c: i64 = digit + *c;
            *x = get_digit(basek, digit_plus_c);
            *c = carry + get_carry(basek, digit_plus_c, *x);
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            let digit: i64 = get_digit(basek_lsh, *x);
            let carry: i64 = get_carry(basek_lsh, *x, digit);
            let digit_plus_c: i64 = (digit << lsh) + *c;
            *x = get_digit(basek, digit_plus_c);
            *c = carry + get_carry(basek, digit_plus_c, *x);
        });
    }
}

#[inline(always)]
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }
    if lsh == 0 {
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            let digit: i64 = get_digit(basek, *a);
            let carry: i64 = get_carry(basek, *a, digit);
            let digit_plus_c: i64 = digit + *c;
            *x = get_digit(basek, digit_plus_c);
            *c = carry + get_carry(basek, digit_plus_c, *x);
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            let digit: i64 = get_digit(basek_lsh, *a);
            let carry: i64 = get_carry(basek_lsh, *a, digit);
            let digit_plus_c: i64 = (digit << lsh) + *c;
            *x = get_digit(basek, digit_plus_c);
            *c = carry + get_carry(basek, digit_plus_c, *x);
        });
    }
}

#[inline(always)]
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }

    if lsh == 0 {
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            *x = get_digit(basek, get_digit(basek, *x) + *c);
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
            *x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
        });
    }
}

#[inline(always)]
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(x.len(), a.len());
        assert!(x.len() <= carry.len());
        assert!(lsh < basek);
    }
    if lsh == 0 {
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            *x = get_digit(basek, get_digit(basek, *a) + *c);
        });
    } else {
        let basek_lsh: usize = basek - lsh;
        izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
            *x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
        });
    }
}
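Note on the middle steps: after the incoming carry is added, digit + c can itself leave the balanced range, which is why a second get_digit/get_carry pass extracts the outgoing carry. E.g. with basek = 4: digit = 7, c = 3 gives digit_plus_c = 10, which re-normalizes to digit -6 with carry 1.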
26
poulpy-hal/src/reference/znx/rotate.rs
Normal file
@@ -0,0 +1,26 @@
use crate::reference::znx::{ZnxCopy, ZnxNegate};

/// Multiplies by X^p in Z[X]/(X^n+1): res(X) = X^p * src(X).
pub fn znx_rotate<ZNXARI: ZnxNegate + ZnxCopy>(p: i64, res: &mut [i64], src: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), src.len());
    }

    let n: usize = res.len();

    let mp_2n: usize = (p & (2 * n as i64 - 1)) as usize; // p % 2n
    let mp_1n: usize = mp_2n & (n - 1); // p % n
    let mp_1n_neg: usize = n - mp_1n; // n - (p % n)
    let neg_first: bool = mp_2n < n;

    let (dst1, dst2) = res.split_at_mut(mp_1n);
    let (src1, src2) = src.split_at(mp_1n_neg);

    if neg_first {
        ZNXARI::znx_negate(dst1, src2);
        ZNXARI::znx_copy(dst2, src1);
    } else {
        ZNXARI::znx_copy(dst1, src2);
        ZNXARI::znx_negate(dst2, src1);
    }
}
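A worked example with n = 4 and p = 1, i.e. multiplication by X:

// src = [s0, s1, s2, s3]  ->  res = [-s3, s0, s1, s2]
// (the top coefficient wraps around with a sign flip, since X^4 = -1).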
53
poulpy-hal/src/reference/znx/sampling.rs
Normal file
@@ -0,0 +1,53 @@
use rand_distr::{Distribution, Normal};

use crate::source::Source;

/// Fills res with uniform values in [-2^(basek-1), 2^(basek-1)).
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
    let pow2k: u64 = 1 << basek;
    let mask: u64 = pow2k - 1;
    let pow2k_half: i64 = (pow2k >> 1) as i64;
    res.iter_mut()
        .for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half)
}

/// Fills res with rounded samples from dist, rejecting |sample| > bound.
pub fn znx_fill_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
    res.iter_mut().for_each(|xi| {
        let mut dist_f64: f64 = dist.sample(source);
        while dist_f64.abs() > bound {
            dist_f64 = dist.sample(source)
        }
        *xi = dist_f64.round() as i64
    })
}

/// Adds rounded samples from dist to res, rejecting |sample| > bound.
pub fn znx_add_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
    res.iter_mut().for_each(|xi| {
        let mut dist_f64: f64 = dist.sample(source);
        while dist_f64.abs() > bound {
            dist_f64 = dist.sample(source)
        }
        *xi += dist_f64.round() as i64
    })
}

pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
    let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
    res.iter_mut().for_each(|xi| {
        let mut dist_f64: f64 = normal.sample(source);
        while dist_f64.abs() > bound {
            dist_f64 = normal.sample(source)
        }
        *xi = dist_f64.round() as i64
    })
}

pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
    let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
    res.iter_mut().for_each(|xi| {
        let mut dist_f64: f64 = normal.sample(source);
        while dist_f64.abs() > bound {
            dist_f64 = normal.sample(source)
        }
        *xi += dist_f64.round() as i64
    })
}
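A small usage sketch: fill a degree-8 slice with discrete Gaussian noise of standard deviation 3.2, truncated at 6 sigma (values illustrative):

let mut source: Source = Source::new([0u8; 32]);
let mut noise: [i64; 8] = [0i64; 8];
znx_fill_normal_f64_ref(&mut noise, 3.2, 6.0 * 3.2, &mut source);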
36
poulpy-hal/src/reference/znx/sub.rs
Normal file
@@ -0,0 +1,36 @@
pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
        assert_eq!(res.len(), b.len());
    }

    let n: usize = res.len();
    for i in 0..n {
        res[i] = a[i] - b[i];
    }
}

/// res -= a
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    let n: usize = res.len();
    for i in 0..n {
        res[i] -= a[i];
    }
}

/// res = a - res
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    let n: usize = res.len();
    for i in 0..n {
        res[i] = a[i] - res[i];
    }
}
29
poulpy-hal/src/reference/znx/switch_ring.rs
Normal file
@@ -0,0 +1,29 @@
use crate::reference::znx::{copy::znx_copy_ref, zero::znx_zero_ref};

pub fn znx_switch_ring_ref(res: &mut [i64], a: &[i64]) {
    let (n_in, n_out) = (a.len(), res.len());

    #[cfg(debug_assertions)]
    {
        assert!(n_in.is_power_of_two());
        assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
    }

    if n_in == n_out {
        znx_copy_ref(res, a);
        return;
    }

    let (gap_in, gap_out): (usize, usize);
    if n_in > n_out {
        // Down-switch: keep every (n_in / n_out)-th coefficient.
        (gap_in, gap_out) = (n_in / n_out, 1)
    } else {
        // Up-switch: spread coefficients out with stride n_out / n_in.
        (gap_in, gap_out) = (1, n_out / n_in);
        znx_zero_ref(res);
    }

    res.iter_mut()
        .step_by(gap_out)
        .zip(a.iter().step_by(gap_in))
        .for_each(|(x_out, x_in)| *x_out = *x_in);
}
3
poulpy-hal/src/reference/znx/zero.rs
Normal file
@@ -0,0 +1,3 @@
pub fn znx_zero_ref(res: &mut [i64]) {
    res.fill(0);
}