Ref. + AVX code & generic tests + benches (#85)

Jean-Philippe Bossuat
2025-09-15 16:16:11 +02:00
committed by GitHub
parent 99b9e3e10e
commit 56dbd29c59
286 changed files with 27797 additions and 7270 deletions

View File

@@ -0,0 +1,24 @@
pub mod reim;
pub mod reim4;
pub mod svp;
pub mod vec_znx_big;
pub mod vec_znx_dft;
pub mod vmp;
pub(crate) fn assert_approx_eq_slice(a: &[f64], b: &[f64], tol: f64) {
assert_eq!(a.len(), b.len(), "Slices have different lengths");
for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
let diff: f64 = (x - y).abs();
let scale: f64 = x.abs().max(y.abs()).max(1.0);
assert!(
diff <= tol * scale,
"Difference at index {}: left={} right={} rel_diff={} > tol={}",
i,
x,
y,
diff / scale,
tol
);
}
}
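A minimal usage sketch for the helper above (the values and tolerance are illustrative, not part of the commit): the comparison is relative to the larger magnitude of each pair, clamped below at 1.0.

#[cfg(test)]
mod assert_approx_eq_slice_example {
    use super::*;

    #[test]
    fn accepts_small_relative_error() {
        let a = [1.0f64, 1_000_000.0, -3.5];
        let b = [1.0f64, 1_000_000.001, -3.5]; // ~1e-9 relative error on the middle entry
        assert_approx_eq_slice(&a, &b, 1e-6);
    }
}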

View File

@@ -0,0 +1,31 @@
#[inline(always)]
pub fn reim_from_znx_i64_ref(res: &mut [f64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
for i in 0..res.len() {
res[i] = a[i] as f64
}
}
#[inline(always)]
pub fn reim_to_znx_i64_ref(res: &mut [i64], divisor: f64, a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
let inv_div = 1. / divisor;
for i in 0..res.len() {
res[i] = (a[i] * inv_div).round() as i64
}
}
/// In-place variant of [`reim_to_znx_i64_ref`]: each `f64` slot is overwritten with the
/// bit pattern of the rounded `i64` value, so the caller can reinterpret the same buffer
/// as `&[i64]` afterwards (see `vec_znx_idft_apply`) without an extra copy.
#[inline(always)]
pub fn reim_to_znx_i64_inplace_ref(res: &mut [f64], divisor: f64) {
let inv_div = 1. / divisor;
for ri in res {
*ri = f64::from_bits(((*ri * inv_div).round() as i64) as u64)
}
}
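A small round-trip sketch for the two out-of-place conversions above (illustrative values): with divisor 1.0 the i64 -> f64 -> i64 trip is exact for inputs well below 2^53.

#[cfg(test)]
mod conversion_roundtrip_example {
    use super::*;

    #[test]
    fn znx_f64_roundtrip_is_exact_for_small_values() {
        let a: [i64; 4] = [0, 1, -7, 123_456];
        let mut tmp = [0.0f64; 4];
        let mut back = [0i64; 4];
        reim_from_znx_i64_ref(&mut tmp, &a);
        // divisor = 1.0: plain rounding back to i64.
        reim_to_znx_i64_ref(&mut back, 1.0, &tmp);
        assert_eq!(a, back);
    }
}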

View File

@@ -0,0 +1,327 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
#[inline(always)]
pub fn fft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m <= 16 {
match m {
1 => {}
2 => fft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => fft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => fft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => fft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
_ => {}
}
} else if m <= 2048 {
fft_bfs_16_ref(m, re, im, omg, 0);
} else {
fft_rec_16_ref(m, re, im, omg, 0);
}
}
#[inline(always)]
fn fft_rec_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
if m <= 2048 {
return fft_bfs_16_ref(m, re, im, omg, pos);
};
let h = m >> 1;
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
pos = fft_rec_16_ref(h, re, im, omg, pos);
pos = fft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
pos
}
#[inline(always)]
fn cplx_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let dr: R = *rb * omg_re - *ib * omg_im;
let di: R = *rb * omg_im + *ib * omg_re;
*rb = *ra - dr;
*ib = *ia - di;
*ra = *ra + dr;
*ia = *ia + di;
}
#[inline(always)]
fn cplx_i_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let dr: R = *rb * omg_im + *ib * omg_re;
let di: R = *rb * omg_re - *ib * omg_im;
*rb = *ra + dr;
*ib = *ia - di;
*ra = *ra - dr;
*ia = *ia + di;
}
#[inline(always)]
fn fft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
let [ra, rb] = re;
let [ia, ib] = im;
let [romg, iomg] = omg;
cplx_twiddle(ra, ia, rb, ib, romg, iomg);
}
#[inline(always)]
fn fft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
let [re_0, re_1, re_2, re_3] = re;
let [im_0, im_1, im_2, im_3] = im;
{
let omg_0 = omg[0];
let omg_1 = omg[1];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
}
{
let omg_0 = omg[2];
let omg_1 = omg[3];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
}
}
#[inline(always)]
fn fft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
{
let omg_0 = omg[0];
let omg_1 = omg[1];
cplx_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
cplx_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
cplx_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
}
{
let omg_2 = omg[2];
let omg_3 = omg[3];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
}
{
let omg_4 = omg[4];
let omg_5 = omg[5];
let omg_6 = omg[6];
let omg_7 = omg[7];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
cplx_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
}
}
#[inline(always)]
fn fft16_ref<R: Float + FloatConst + Debug>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
let [
re_0,
re_1,
re_2,
re_3,
re_4,
re_5,
re_6,
re_7,
re_8,
re_9,
re_10,
re_11,
re_12,
re_13,
re_14,
re_15,
] = re;
let [
im_0,
im_1,
im_2,
im_3,
im_4,
im_5,
im_6,
im_7,
im_8,
im_9,
im_10,
im_11,
im_12,
im_13,
im_14,
im_15,
] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
cplx_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
cplx_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
cplx_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
cplx_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
cplx_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
cplx_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
cplx_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
}
{
let omg_2: R = omg[2];
let omg_3: R = omg[3];
cplx_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
cplx_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
cplx_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
cplx_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
cplx_i_twiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
cplx_i_twiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
cplx_i_twiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
cplx_i_twiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[4];
let omg_1: R = omg[5];
let omg_2: R = omg[6];
let omg_3: R = omg[7];
cplx_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
cplx_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
cplx_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
cplx_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
cplx_i_twiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
cplx_i_twiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
cplx_i_twiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
cplx_i_twiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[8];
let omg_1: R = omg[9];
let omg_2: R = omg[10];
let omg_3: R = omg[11];
let omg_4: R = omg[12];
let omg_5: R = omg[13];
let omg_6: R = omg[14];
let omg_7: R = omg[15];
cplx_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
cplx_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
cplx_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
cplx_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
cplx_i_twiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
cplx_i_twiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
cplx_i_twiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
cplx_i_twiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
}
}
#[inline(always)]
fn fft_bfs_16_ref<R: Float + FloatConst + Debug>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
if !log_m.is_multiple_of(2) {
let h: usize = mm >> 1;
twiddle_fft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
mm = h
}
while mm > 16 {
let h: usize = mm >> 2;
for off in (0..m).step_by(mm) {
bitwiddle_fft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
pos += 4;
}
mm = h
}
for off in (0..m).step_by(16) {
fft16_ref(
as_arr_mut::<16, R>(&mut re[off..]),
as_arr_mut::<16, R>(&mut im[off..]),
*as_arr::<16, R>(&omg[pos..]),
);
pos += 16;
}
pos
}
#[inline(always)]
fn twiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
let romg = omg[0];
let iomg = omg[1];
let (re_lhs, re_rhs) = re.split_at_mut(h);
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
cplx_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
}
}
#[inline(always)]
fn bitwiddle_fft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
let (r0, r2) = re.split_at_mut(2 * h);
let (r0, r1) = r0.split_at_mut(h);
let (r2, r3) = r2.split_at_mut(h);
let (i0, i2) = im.split_at_mut(2 * h);
let (i0, i1) = i0.split_at_mut(h);
let (i2, i3) = i2.split_at_mut(h);
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
for i in 0..h {
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_0, omg_1);
cplx_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
}
for i in 0..h {
cplx_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_2, omg_3);
cplx_i_twiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
}
}

View File

@@ -0,0 +1,156 @@
#[inline(always)]
pub fn reim_add_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] + b[i]
}
}
#[inline(always)]
pub fn reim_add_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] += a[i]
}
}
#[inline(always)]
pub fn reim_sub_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] - b[i]
}
}
#[inline(always)]
pub fn reim_sub_ab_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] -= a[i]
}
}
#[inline(always)]
pub fn reim_sub_ba_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] = a[i] - res[i]
}
}
#[inline(always)]
pub fn reim_negate_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
for i in 0..res.len() {
res[i] = -a[i]
}
}
#[inline(always)]
pub fn reim_negate_inplace_ref(res: &mut [f64]) {
for ri in res {
*ri = -*ri
}
}
#[inline(always)]
pub fn reim_addmul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = br[i];
let _bi: f64 = bi[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] += _rr;
ri[i] += _ri;
}
}
#[inline(always)]
pub fn reim_mul_inplace_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = rr[i];
let _bi: f64 = ri[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] = _rr;
ri[i] = _ri;
}
}
#[inline(always)]
pub fn reim_mul_ref(res: &mut [f64], a: &[f64], b: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), res.len());
assert_eq!(b.len(), res.len());
}
let m: usize = res.len() >> 1;
let (rr, ri) = res.split_at_mut(m);
let (ar, ai) = a.split_at(m);
let (br, bi) = b.split_at(m);
for i in 0..m {
let _ar: f64 = ar[i];
let _ai: f64 = ai[i];
let _br: f64 = br[i];
let _bi: f64 = bi[i];
let _rr: f64 = _ar * _br - _ai * _bi;
let _ri: f64 = _ar * _bi + _ai * _br;
rr[i] = _rr;
ri[i] = _ri;
}
}
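The reim layout packs a buffer of length 2m as m real parts followed by m imaginary parts. A quick check of reim_mul_ref under that convention (illustrative values): (1 + 2i)(3 + 4i) = -5 + 10i.

#[cfg(test)]
mod reim_mul_example {
    use super::*;

    #[test]
    fn single_complex_product() {
        // m = 1, so a = 1 + 2i and b = 3 + 4i in the [re..., im...] packing.
        let a = [1.0f64, 2.0];
        let b = [3.0f64, 4.0];
        let mut res = [0.0f64; 2];
        reim_mul_ref(&mut res, &a, &b);
        assert_eq!(res, [-5.0, 10.0]);
    }
}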

View File

@@ -0,0 +1,322 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::reference::fft64::reim::{as_arr, as_arr_mut};
pub fn ifft_ref<R: Float + FloatConst + Debug>(m: usize, omg: &[R], data: &mut [R]) {
assert!(data.len() == 2 * m);
let (re, im) = data.split_at_mut(m);
if m <= 16 {
match m {
1 => {}
2 => ifft2_ref(
as_arr_mut::<2, R>(re),
as_arr_mut::<2, R>(im),
*as_arr::<2, R>(omg),
),
4 => ifft4_ref(
as_arr_mut::<4, R>(re),
as_arr_mut::<4, R>(im),
*as_arr::<4, R>(omg),
),
8 => ifft8_ref(
as_arr_mut::<8, R>(re),
as_arr_mut::<8, R>(im),
*as_arr::<8, R>(omg),
),
16 => ifft16_ref(
as_arr_mut::<16, R>(re),
as_arr_mut::<16, R>(im),
*as_arr::<16, R>(omg),
),
_ => {}
}
} else if m <= 2048 {
ifft_bfs_16_ref(m, re, im, omg, 0);
} else {
ifft_rec_16_ref(m, re, im, omg, 0);
}
}
#[inline(always)]
fn ifft_rec_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
if m <= 2048 {
return ifft_bfs_16_ref(m, re, im, omg, pos);
};
let h: usize = m >> 1;
pos = ifft_rec_16_ref(h, re, im, omg, pos);
pos = ifft_rec_16_ref(h, &mut re[h..], &mut im[h..], omg, pos);
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
pos
}
#[inline(always)]
fn ifft_bfs_16_ref<R: Float + FloatConst>(m: usize, re: &mut [R], im: &mut [R], omg: &[R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
for off in (0..m).step_by(16) {
ifft16_ref(
as_arr_mut::<16, R>(&mut re[off..]),
as_arr_mut::<16, R>(&mut im[off..]),
*as_arr::<16, R>(&omg[pos..]),
);
pos += 16;
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
while h < m_half {
let mm: usize = h << 2;
for off in (0..m).step_by(mm) {
inv_bitwiddle_ifft_ref(
h,
&mut re[off..],
&mut im[off..],
as_arr::<4, R>(&omg[pos..]),
);
pos += 4;
}
h = mm;
}
if !log_m.is_multiple_of(2) {
inv_twiddle_ifft_ref(h, re, im, as_arr::<2, R>(&omg[pos..]));
pos += 2;
}
pos
}
#[inline(always)]
fn inv_twiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let r_diff: R = *ra - *rb;
let i_diff: R = *ia - *ib;
*ra = *ra + *rb;
*ia = *ia + *ib;
*rb = r_diff * omg_re - i_diff * omg_im;
*ib = r_diff * omg_im + i_diff * omg_re;
}
#[inline(always)]
fn inv_itwiddle<R: Float + FloatConst>(ra: &mut R, ia: &mut R, rb: &mut R, ib: &mut R, omg_re: R, omg_im: R) {
let r_diff: R = *ra - *rb;
let i_diff: R = *ia - *ib;
*ra = *ra + *rb;
*ia = *ia + *ib;
*rb = r_diff * omg_im + i_diff * omg_re;
*ib = -r_diff * omg_re + i_diff * omg_im;
}
#[inline(always)]
fn ifft2_ref<R: Float + FloatConst>(re: &mut [R; 2], im: &mut [R; 2], omg: [R; 2]) {
let [ra, rb] = re;
let [ia, ib] = im;
let [romg, iomg] = omg;
inv_twiddle(ra, ia, rb, ib, romg, iomg);
}
#[inline(always)]
fn ifft4_ref<R: Float + FloatConst>(re: &mut [R; 4], im: &mut [R; 4], omg: [R; 4]) {
let [re_0, re_1, re_2, re_3] = re;
let [im_0, im_1, im_2, im_3] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_1);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_1);
}
{
let omg_0: R = omg[2];
let omg_1: R = omg[3];
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
}
}
#[inline(always)]
fn ifft8_ref<R: Float + FloatConst>(re: &mut [R; 8], im: &mut [R; 8], omg: [R; 8]) {
let [re_0, re_1, re_2, re_3, re_4, re_5, re_6, re_7] = re;
let [im_0, im_1, im_2, im_3, im_4, im_5, im_6, im_7] = im;
{
let omg_4: R = omg[0];
let omg_5: R = omg[1];
let omg_6: R = omg[2];
let omg_7: R = omg[3];
inv_twiddle(re_0, im_0, re_1, im_1, omg_4, omg_6);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_4, omg_6);
inv_twiddle(re_4, im_4, re_5, im_5, omg_5, omg_7);
inv_itwiddle(re_6, im_6, re_7, im_7, omg_5, omg_7);
}
{
let omg_2: R = omg[4];
let omg_3: R = omg[5];
inv_twiddle(re_0, im_0, re_2, im_2, omg_2, omg_3);
inv_twiddle(re_1, im_1, re_3, im_3, omg_2, omg_3);
inv_itwiddle(re_4, im_4, re_6, im_6, omg_2, omg_3);
inv_itwiddle(re_5, im_5, re_7, im_7, omg_2, omg_3);
}
{
let omg_0: R = omg[6];
let omg_1: R = omg[7];
inv_twiddle(re_0, im_0, re_4, im_4, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_5, im_5, omg_0, omg_1);
inv_twiddle(re_2, im_2, re_6, im_6, omg_0, omg_1);
inv_twiddle(re_3, im_3, re_7, im_7, omg_0, omg_1);
}
}
#[inline(always)]
fn ifft16_ref<R: Float + FloatConst>(re: &mut [R; 16], im: &mut [R; 16], omg: [R; 16]) {
let [
re_0,
re_1,
re_2,
re_3,
re_4,
re_5,
re_6,
re_7,
re_8,
re_9,
re_10,
re_11,
re_12,
re_13,
re_14,
re_15,
] = re;
let [
im_0,
im_1,
im_2,
im_3,
im_4,
im_5,
im_6,
im_7,
im_8,
im_9,
im_10,
im_11,
im_12,
im_13,
im_14,
im_15,
] = im;
{
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
let omg_4: R = omg[4];
let omg_5: R = omg[5];
let omg_6: R = omg[6];
let omg_7: R = omg[7];
inv_twiddle(re_0, im_0, re_1, im_1, omg_0, omg_4);
inv_itwiddle(re_2, im_2, re_3, im_3, omg_0, omg_4);
inv_twiddle(re_4, im_4, re_5, im_5, omg_1, omg_5);
inv_itwiddle(re_6, im_6, re_7, im_7, omg_1, omg_5);
inv_twiddle(re_8, im_8, re_9, im_9, omg_2, omg_6);
inv_itwiddle(re_10, im_10, re_11, im_11, omg_2, omg_6);
inv_twiddle(re_12, im_12, re_13, im_13, omg_3, omg_7);
inv_itwiddle(re_14, im_14, re_15, im_15, omg_3, omg_7);
}
{
let omg_0: R = omg[8];
let omg_1: R = omg[9];
let omg_2: R = omg[10];
let omg_3: R = omg[11];
inv_twiddle(re_0, im_0, re_2, im_2, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_3, im_3, omg_0, omg_1);
inv_itwiddle(re_4, im_4, re_6, im_6, omg_0, omg_1);
inv_itwiddle(re_5, im_5, re_7, im_7, omg_0, omg_1);
inv_twiddle(re_8, im_8, re_10, im_10, omg_2, omg_3);
inv_twiddle(re_9, im_9, re_11, im_11, omg_2, omg_3);
inv_itwiddle(re_12, im_12, re_14, im_14, omg_2, omg_3);
inv_itwiddle(re_13, im_13, re_15, im_15, omg_2, omg_3);
}
{
let omg_2: R = omg[12];
let omg_3: R = omg[13];
inv_twiddle(re_0, im_0, re_4, im_4, omg_2, omg_3);
inv_twiddle(re_1, im_1, re_5, im_5, omg_2, omg_3);
inv_twiddle(re_2, im_2, re_6, im_6, omg_2, omg_3);
inv_twiddle(re_3, im_3, re_7, im_7, omg_2, omg_3);
inv_itwiddle(re_8, im_8, re_12, im_12, omg_2, omg_3);
inv_itwiddle(re_9, im_9, re_13, im_13, omg_2, omg_3);
inv_itwiddle(re_10, im_10, re_14, im_14, omg_2, omg_3);
inv_itwiddle(re_11, im_11, re_15, im_15, omg_2, omg_3);
}
{
let omg_0: R = omg[14];
let omg_1: R = omg[15];
inv_twiddle(re_0, im_0, re_8, im_8, omg_0, omg_1);
inv_twiddle(re_1, im_1, re_9, im_9, omg_0, omg_1);
inv_twiddle(re_2, im_2, re_10, im_10, omg_0, omg_1);
inv_twiddle(re_3, im_3, re_11, im_11, omg_0, omg_1);
inv_twiddle(re_4, im_4, re_12, im_12, omg_0, omg_1);
inv_twiddle(re_5, im_5, re_13, im_13, omg_0, omg_1);
inv_twiddle(re_6, im_6, re_14, im_14, omg_0, omg_1);
inv_twiddle(re_7, im_7, re_15, im_15, omg_0, omg_1);
}
}
#[inline(always)]
fn inv_twiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 2]) {
let romg = omg[0];
let iomg = omg[1];
let (re_lhs, re_rhs) = re.split_at_mut(h);
let (im_lhs, im_rhs) = im.split_at_mut(h);
for i in 0..h {
inv_twiddle(
&mut re_lhs[i],
&mut im_lhs[i],
&mut re_rhs[i],
&mut im_rhs[i],
romg,
iomg,
);
}
}
#[inline(always)]
fn inv_bitwiddle_ifft_ref<R: Float + FloatConst>(h: usize, re: &mut [R], im: &mut [R], omg: &[R; 4]) {
let (r0, r2) = re.split_at_mut(2 * h);
let (r0, r1) = r0.split_at_mut(h);
let (r2, r3) = r2.split_at_mut(h);
let (i0, i2) = im.split_at_mut(2 * h);
let (i0, i1) = i0.split_at_mut(h);
let (i2, i3) = i2.split_at_mut(h);
let omg_0: R = omg[0];
let omg_1: R = omg[1];
let omg_2: R = omg[2];
let omg_3: R = omg[3];
for i in 0..h {
inv_twiddle(&mut r0[i], &mut i0[i], &mut r1[i], &mut i1[i], omg_0, omg_1);
inv_itwiddle(&mut r2[i], &mut i2[i], &mut r3[i], &mut i3[i], omg_0, omg_1);
}
for i in 0..h {
inv_twiddle(&mut r0[i], &mut i0[i], &mut r2[i], &mut i2[i], omg_2, omg_3);
inv_twiddle(&mut r1[i], &mut i1[i], &mut r3[i], &mut i3[i], omg_2, omg_3);
}
}

View File

@@ -0,0 +1,128 @@
// ----------------------------------------------------------------------
// DISCLAIMER
//
// This module contains code that has been directly ported from the
// spqlios-arithmetic library
// (https://github.com/tfhe/spqlios-arithmetic), which is licensed
// under the Apache License, Version 2.0.
//
// The porting process from C to Rust was done with minimal changes
// in order to preserve the semantics and performance characteristics
// of the original implementation.
//
// Both Poulpy and spqlios-arithmetic are distributed under the terms
// of the Apache License, Version 2.0. See the LICENSE file for details.
//
// ----------------------------------------------------------------------
#![allow(bad_asm_style)]
mod conversion;
mod fft_ref;
mod fft_vec;
mod ifft_ref;
mod table_fft;
mod table_ifft;
mod zero;
pub use conversion::*;
pub use fft_ref::*;
pub use fft_vec::*;
pub use ifft_ref::*;
pub use table_fft::*;
pub use table_ifft::*;
pub use zero::*;
use rand_distr::num_traits::{Float, FloatConst};
#[inline(always)]
pub(crate) fn as_arr<const SIZE: usize, R: Float + FloatConst>(x: &[R]) -> &[R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &*(x.as_ptr() as *const [R; SIZE]) }
}
#[inline(always)]
pub(crate) fn as_arr_mut<const SIZE: usize, R: Float + FloatConst>(x: &mut [R]) -> &mut [R; SIZE] {
debug_assert!(x.len() >= SIZE);
unsafe { &mut *(x.as_mut_ptr() as *mut [R; SIZE]) }
}
#[inline(always)]
pub(crate) fn frac_rev_bits<R: Float + FloatConst>(x: usize) -> R {
let half: R = R::from(0.5).unwrap();
match x {
0 => R::zero(),
1 => half,
_ => {
if x.is_multiple_of(2) {
frac_rev_bits::<R>(x >> 1) * half
} else {
frac_rev_bits::<R>(x >> 1) * half + half
}
}
}
}
pub trait ReimDFTExecute<D, T> {
fn reim_dft_execute(table: &D, data: &mut [T]);
}
pub trait ReimFromZnx {
fn reim_from_znx(res: &mut [f64], a: &[i64]);
}
pub trait ReimToZnx {
fn reim_to_znx(res: &mut [i64], divisor: f64, a: &[f64]);
}
pub trait ReimToZnxInplace {
fn reim_to_znx_inplace(res: &mut [f64], divisor: f64);
}
pub trait ReimAdd {
fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimAddInplace {
fn reim_add_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSub {
fn reim_sub(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimSubABInplace {
fn reim_sub_ab_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimSubBAInplace {
fn reim_sub_ba_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimNegate {
fn reim_negate(res: &mut [f64], a: &[f64]);
}
pub trait ReimNegateInplace {
fn reim_negate_inplace(res: &mut [f64]);
}
pub trait ReimMul {
fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimMulInplace {
fn reim_mul_inplace(res: &mut [f64], a: &[f64]);
}
pub trait ReimAddMul {
fn reim_addmul(res: &mut [f64], a: &[f64], b: &[f64]);
}
pub trait ReimCopy {
fn reim_copy(res: &mut [f64], a: &[f64]);
}
pub trait ReimZero {
fn reim_zero(res: &mut [f64]);
}
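The traits above only name the slice-level operations; a backend ties them to a concrete implementation. A minimal sketch of how the reference kernels from this module could back them (ReimRefBackend is a hypothetical name chosen for illustration):

pub struct ReimRefBackend;

impl ReimAdd for ReimRefBackend {
    fn reim_add(res: &mut [f64], a: &[f64], b: &[f64]) {
        reim_add_ref(res, a, b);
    }
}

impl ReimMul for ReimRefBackend {
    fn reim_mul(res: &mut [f64], a: &[f64], b: &[f64]) {
        reim_mul_ref(res, a, b);
    }
}

An AVX backend would provide the same impls over its vectorized kernels, which is what keeps the generic svp and vec_znx_dft code further down backend-agnostic.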

View File

@@ -0,0 +1,207 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::{
alloc_aligned,
reference::fft64::reim::{ReimDFTExecute, fft_ref, frac_rev_bits},
};
pub struct ReimFFTRef;
impl ReimDFTExecute<ReimFFTTable<f64>, f64> for ReimFFTRef {
fn reim_dft_execute(table: &ReimFFTTable<f64>, data: &mut [f64]) {
fft_ref(table.m, &table.omg, data);
}
}
pub struct ReimFFTTable<R: Float + FloatConst + Debug> {
m: usize,
omg: Vec<R>,
}
impl<R: Float + FloatConst + Debug + 'static> ReimFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::from(1. / 4.).unwrap();
if m <= 16 {
match m {
1 => {}
2 => {
fill_fft2_omegas(quarter, &mut omg, 0);
}
4 => {
fill_fft4_omegas(quarter, &mut omg, 0);
}
8 => {
fill_fft8_omegas(quarter, &mut omg, 0);
}
16 => {
fill_fft16_omegas(quarter, &mut omg, 0);
}
_ => {}
}
} else if m <= 2048 {
fill_fft_bfs_16_omegas(m, quarter, &mut omg, 0);
} else {
fill_fft_rec_16_omegas(m, quarter, &mut omg, 0);
}
Self { m, omg }
}
pub fn m(&self) -> usize {
self.m
}
pub fn omg(&self) -> &[R] {
&self.omg
}
}
#[inline(always)]
fn fill_fft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 2);
let angle: R = j / R::from(2).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle);
omg_pos[1] = R::sin(two_pi * angle);
pos + 2
}
#[inline(always)]
fn fill_fft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 4);
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
pos + 4
}
#[inline(always)]
fn fill_fft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 8);
let _8th: R = R::from(1. / 8.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
omg_pos[4] = R::cos(two_pi * angle_4);
omg_pos[5] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[6] = R::sin(two_pi * angle_4);
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
pos + 8
}
#[inline(always)]
fn fill_fft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 16);
let _8th: R = R::from(1. / 8.).unwrap();
let _16th: R = R::from(1. / 16.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let angle_8: R = j / R::from(16).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_1);
omg_pos[1] = R::sin(two_pi * angle_1);
omg_pos[2] = R::cos(two_pi * angle_2);
omg_pos[3] = R::sin(two_pi * angle_2);
omg_pos[4] = R::cos(two_pi * angle_4);
omg_pos[5] = R::sin(two_pi * angle_4);
omg_pos[6] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[7] = R::sin(two_pi * (angle_4 + _8th));
omg_pos[8] = R::cos(two_pi * angle_8);
omg_pos[9] = R::cos(two_pi * (angle_8 + _8th));
omg_pos[10] = R::cos(two_pi * (angle_8 + _16th));
omg_pos[11] = R::cos(two_pi * (angle_8 + _8th + _16th));
omg_pos[12] = R::sin(two_pi * angle_8);
omg_pos[13] = R::sin(two_pi * (angle_8 + _8th));
omg_pos[14] = R::sin(two_pi * (angle_8 + _16th));
omg_pos[15] = R::sin(two_pi * (angle_8 + _8th + _16th));
pos + 16
}
#[inline(always)]
fn fill_fft_bfs_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut mm: usize = m;
let mut jj: R = j;
let two_pi: R = R::from(2).unwrap() * R::PI();
if !log_m.is_multiple_of(2) {
let h = mm >> 1;
let j: R = jj * R::from(0.5).unwrap();
omg[pos] = R::cos(two_pi * j);
omg[pos + 1] = R::sin(two_pi * j);
pos += 2;
mm = h;
jj = j
}
while mm > 16 {
let h: usize = mm >> 2;
let j: R = jj * R::from(1. / 4.).unwrap();
for i in (0..m).step_by(mm) {
let rs_0 = j + frac_rev_bits::<R>(i / mm) * R::from(1. / 4.).unwrap();
let rs_1 = R::from(2).unwrap() * rs_0;
omg[pos] = R::cos(two_pi * rs_1);
omg[pos + 1] = R::sin(two_pi * rs_1);
omg[pos + 2] = R::cos(two_pi * rs_0);
omg[pos + 3] = R::sin(two_pi * rs_0);
pos += 4;
}
mm = h;
jj = j;
}
for i in (0..m).step_by(16) {
let j = jj + frac_rev_bits(i >> 4);
fill_fft16_omegas(j, omg, pos);
pos += 16
}
pos
}
#[inline(always)]
fn fill_fft_rec_16_omegas<R: Float + FloatConst>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
if m <= 2048 {
return fill_fft_bfs_16_omegas(m, j, omg, pos);
}
let h: usize = m >> 1;
let s: R = j * R::from(0.5).unwrap();
let _2pi = R::from(2).unwrap() * R::PI();
omg[pos] = R::cos(_2pi * s);
omg[pos + 1] = R::sin(_2pi * s);
pos += 2;
pos = fill_fft_rec_16_omegas(h, s, omg, pos);
pos = fill_fft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
pos
}
#[inline(always)]
fn ctwiddle_ref(ra: &mut f64, ia: &mut f64, rb: &mut f64, ib: &mut f64, omg_re: f64, omg_im: f64) {
let dr: f64 = *rb * omg_re - *ib * omg_im;
let di: f64 = *rb * omg_im + *ib * omg_re;
*rb = *ra - dr;
*ib = *ia - di;
*ra += dr;
*ia += di;
}
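A short usage sketch of the table and executor defined above (the size 64 is illustrative): the omega table is built once per m and then drives fft_ref through ReimDFTExecute. Since the transform is linear, an all-zero input is a cheap sanity check.

#[cfg(test)]
mod table_fft_example {
    use super::*;
    use crate::reference::fft64::reim::ReimDFTExecute;

    #[test]
    fn fft_table_smoke() {
        // m must be a power of two; data holds m real parts followed by m imaginary parts.
        let table: ReimFFTTable<f64> = ReimFFTTable::new(64);
        let mut data: Vec<f64> = vec![0.0; 2 * table.m()];
        ReimFFTRef::reim_dft_execute(&table, &mut data);
        assert!(data.iter().all(|&x| x == 0.0));
    }
}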

View File

@@ -0,0 +1,201 @@
use std::fmt::Debug;
use rand_distr::num_traits::{Float, FloatConst};
use crate::{
alloc_aligned,
reference::fft64::reim::{ReimDFTExecute, frac_rev_bits, ifft_ref::ifft_ref},
};
pub struct ReimIFFTRef;
impl ReimDFTExecute<ReimIFFTTable<f64>, f64> for ReimIFFTRef {
fn reim_dft_execute(table: &ReimIFFTTable<f64>, data: &mut [f64]) {
ifft_ref(table.m, &table.omg, data);
}
}
pub struct ReimIFFTTable<R: Float + FloatConst + Debug> {
m: usize,
omg: Vec<R>,
}
impl<R: Float + FloatConst + Debug> ReimIFFTTable<R> {
pub fn new(m: usize) -> Self {
assert!(m & (m - 1) == 0, "m must be a power of two but is {}", m);
let mut omg: Vec<R> = alloc_aligned::<R>(2 * m);
let quarter: R = R::exp2(R::from(-2).unwrap());
if m <= 16 {
match m {
1 => {}
2 => {
fill_ifft2_omegas::<R>(quarter, &mut omg, 0);
}
4 => {
fill_ifft4_omegas(quarter, &mut omg, 0);
}
8 => {
fill_ifft8_omegas(quarter, &mut omg, 0);
}
16 => {
fill_ifft16_omegas(quarter, &mut omg, 0);
}
_ => {}
}
} else if m <= 2048 {
fill_ifft_bfs_16_omegas(m, quarter, &mut omg, 0);
} else {
fill_ifft_rec_16_omegas(m, quarter, &mut omg, 0);
}
Self { m, omg }
}
pub fn execute(&self, data: &mut [R]) {
ifft_ref(self.m, &self.omg, data);
}
pub fn m(&self) -> usize {
self.m
}
pub fn omg(&self) -> &[R] {
&self.omg
}
}
#[inline(always)]
fn fill_ifft2_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 2);
let angle: R = j / R::from(2).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle);
omg_pos[1] = -R::sin(two_pi * angle);
pos + 2
}
#[inline(always)]
fn fill_ifft4_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 4);
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_2);
omg_pos[1] = -R::sin(two_pi * angle_2);
omg_pos[2] = R::cos(two_pi * angle_1);
omg_pos[3] = -R::sin(two_pi * angle_1);
pos + 4
}
#[inline(always)]
fn fill_ifft8_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 8);
let _8th: R = R::from(1. / 8.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_4);
omg_pos[1] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[2] = -R::sin(two_pi * angle_4);
omg_pos[3] = -R::sin(two_pi * (angle_4 + _8th));
omg_pos[4] = R::cos(two_pi * angle_2);
omg_pos[5] = -R::sin(two_pi * angle_2);
omg_pos[6] = R::cos(two_pi * angle_1);
omg_pos[7] = -R::sin(two_pi * angle_1);
pos + 8
}
#[inline(always)]
fn fill_ifft16_omegas<R: Float + FloatConst>(j: R, omg: &mut [R], pos: usize) -> usize {
let omg_pos: &mut [R] = &mut omg[pos..];
assert!(omg_pos.len() >= 16);
let _8th: R = R::from(1. / 8.).unwrap();
let _16th: R = R::from(1. / 16.).unwrap();
let angle_1: R = j / R::from(2).unwrap();
let angle_2: R = j / R::from(4).unwrap();
let angle_4: R = j / R::from(8).unwrap();
let angle_8: R = j / R::from(16).unwrap();
let two_pi: R = R::from(2).unwrap() * R::PI();
omg_pos[0] = R::cos(two_pi * angle_8);
omg_pos[1] = R::cos(two_pi * (angle_8 + _8th));
omg_pos[2] = R::cos(two_pi * (angle_8 + _16th));
omg_pos[3] = R::cos(two_pi * (angle_8 + _8th + _16th));
omg_pos[4] = -R::sin(two_pi * angle_8);
omg_pos[5] = -R::sin(two_pi * (angle_8 + _8th));
omg_pos[6] = -R::sin(two_pi * (angle_8 + _16th));
omg_pos[7] = -R::sin(two_pi * (angle_8 + _8th + _16th));
omg_pos[8] = R::cos(two_pi * angle_4);
omg_pos[9] = -R::sin(two_pi * angle_4);
omg_pos[10] = R::cos(two_pi * (angle_4 + _8th));
omg_pos[11] = -R::sin(two_pi * (angle_4 + _8th));
omg_pos[12] = R::cos(two_pi * angle_2);
omg_pos[13] = -R::sin(two_pi * angle_2);
omg_pos[14] = R::cos(two_pi * angle_1);
omg_pos[15] = -R::sin(two_pi * angle_1);
pos + 16
}
#[inline(always)]
fn fill_ifft_bfs_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
let log_m: usize = (usize::BITS - (m - 1).leading_zeros()) as usize;
let mut jj: R = j * R::from(16).unwrap() / R::from(m).unwrap();
for i in (0..m).step_by(16) {
let j = jj + frac_rev_bits(i >> 4);
fill_ifft16_omegas(j, omg, pos);
pos += 16
}
let mut h: usize = 16;
let m_half: usize = m >> 1;
let two_pi: R = R::from(2).unwrap() * R::PI();
while h < m_half {
let mm: usize = h << 2;
for i in (0..m).step_by(mm) {
let rs_0 = jj + frac_rev_bits::<R>(i / mm) / R::from(4).unwrap();
let rs_1 = R::from(2).unwrap() * rs_0;
omg[pos] = R::cos(two_pi * rs_0);
omg[pos + 1] = -R::sin(two_pi * rs_0);
omg[pos + 2] = R::cos(two_pi * rs_1);
omg[pos + 3] = -R::sin(two_pi * rs_1);
pos += 4;
}
h = mm;
jj = jj * R::from(4).unwrap();
}
if !log_m.is_multiple_of(2) {
omg[pos] = R::cos(two_pi * jj);
omg[pos + 1] = -R::sin(two_pi * jj);
pos += 2;
jj = jj * R::from(2).unwrap();
}
assert_eq!(jj, j);
pos
}
#[inline(always)]
fn fill_ifft_rec_16_omegas<R: Float + FloatConst + Debug>(m: usize, j: R, omg: &mut [R], mut pos: usize) -> usize {
if m <= 2048 {
return fill_ifft_bfs_16_omegas(m, j, omg, pos);
}
let h: usize = m >> 1;
let s: R = j / R::from(2).unwrap();
pos = fill_ifft_rec_16_omegas(h, s, omg, pos);
pos = fill_ifft_rec_16_omegas(h, s + R::from(0.5).unwrap(), omg, pos);
let _2pi = R::from(2).unwrap() * R::PI();
omg[pos] = R::cos(_2pi * s);
omg[pos + 1] = -R::sin(_2pi * s);
pos += 2;
pos
}
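A round-trip sketch combining the forward table from table_fft with the inverse table above (size, input, and tolerance are illustrative): following the normalization used by vec_znx_idft_apply, the inverse of the forward transform returns m times the input, so the result is rescaled by 1/m before comparison.

#[cfg(test)]
mod fft_ifft_roundtrip_example {
    use crate::reference::fft64::{
        assert_approx_eq_slice,
        reim::{ReimDFTExecute, ReimFFTRef, ReimFFTTable, ReimIFFTRef, ReimIFFTTable},
    };

    #[test]
    fn roundtrip_recovers_input_up_to_m() {
        let m: usize = 64;
        let fft: ReimFFTTable<f64> = ReimFFTTable::new(m);
        let ifft: ReimIFFTTable<f64> = ReimIFFTTable::new(m);
        // Small-integer input so the only error left is FFT round-off.
        let input: Vec<f64> = (0..2 * m).map(|i| ((i % 17) as f64) - 8.0).collect();
        let mut data: Vec<f64> = input.clone();
        ReimFFTRef::reim_dft_execute(&fft, &mut data);
        ReimIFFTRef::reim_dft_execute(&ifft, &mut data);
        data.iter_mut().for_each(|x| *x /= m as f64);
        assert_approx_eq_slice(&input, &data, 1e-10);
    }
}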

View File

@@ -0,0 +1,11 @@
pub fn reim_zero_ref(res: &mut [f64]) {
res.fill(0.);
}
pub fn reim_copy_ref(res: &mut [f64], a: &[f64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
res.copy_from_slice(a);
}

View File

@@ -0,0 +1,209 @@
use crate::reference::fft64::reim::as_arr;
#[inline(always)]
pub fn reim4_extract_1blk_from_reim_ref(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= 2 * rows * 4);
for chunk in dst.chunks_exact_mut(4).take(2 * rows) {
chunk.copy_from_slice(&src[offset..offset + 4]);
offset += m
}
}
#[inline(always)]
pub fn reim4_save_1blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= offset + m + 4);
debug_assert!(src.len() >= 8);
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[0..4]);
} else {
dst_off[0] += src[0];
dst_off[1] += src[1];
dst_off[2] += src[2];
dst_off[3] += src[3];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[4..8]);
} else {
dst_off[0] += src[4];
dst_off[1] += src[5];
dst_off[2] += src[6];
dst_off[3] += src[7];
}
}
#[inline(always)]
pub fn reim4_save_2blk_to_reim_ref<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
let mut offset: usize = blk << 2;
debug_assert!(blk < (m >> 2));
debug_assert!(dst.len() >= offset + 3 * m + 4);
debug_assert!(src.len() >= 16);
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[0..4]);
} else {
dst_off[0] += src[0];
dst_off[1] += src[1];
dst_off[2] += src[2];
dst_off[3] += src[3];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[4..8]);
} else {
dst_off[0] += src[4];
dst_off[1] += src[5];
dst_off[2] += src[6];
dst_off[3] += src[7];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[8..12]);
} else {
dst_off[0] += src[8];
dst_off[1] += src[9];
dst_off[2] += src[10];
dst_off[3] += src[11];
}
offset += m;
let dst_off = &mut dst[offset..offset + 4];
if OVERWRITE {
dst_off.copy_from_slice(&src[12..16]);
} else {
dst_off[0] += src[12];
dst_off[1] += src[13];
dst_off[2] += src[14];
dst_off[3] += src[15];
}
}
#[inline(always)]
pub fn reim4_vec_mat1col_product_ref(
nrows: usize,
dst: &mut [f64], // 8 doubles: [re1(4), im1(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 8 doubles: [ar(4) | ai(4)] per row
) {
#[cfg(debug_assertions)]
{
assert!(dst.len() >= 8, "dst must have at least 8 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(v.len() >= nrows * 8, "v must be at least nrows * 8 doubles");
}
println!("u_ref: {:?}", &u[..nrows * 8]);
println!("v_ref: {:?}", &v[..nrows * 8]);
let mut acc: [f64; 8] = [0f64; 8];
let mut j = 0;
for _ in 0..nrows {
reim4_add_mul(&mut acc, as_arr(&u[j..]), as_arr(&v[j..]));
j += 8;
}
dst[0..8].copy_from_slice(&acc);
println!("dst_ref: {:?}", &dst[..8]);
println!();
}
#[inline(always)]
pub fn reim4_vec_mat2cols_product_ref(
nrows: usize,
dst: &mut [f64], // 16 doubles: [re1(4), im1(4), re2(4), im2(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 16 doubles: [ar(4) | ai(4) | br(4) | bi(4)] per row
) {
#[cfg(debug_assertions)]
{
assert_eq!(dst.len(), 16, "dst must have 16 doubles");
assert!(u.len() >= nrows * 8, "u must be at least nrows * 8 doubles");
assert!(
v.len() >= nrows * 16,
"v must be at least nrows * 16 doubles"
);
}
// zero accumulators
let mut acc_0: [f64; 8] = [0f64; 8];
let mut acc_1: [f64; 8] = [0f64; 8];
for i in 0..nrows {
let _1j: usize = i << 3;
let _2j: usize = i << 4;
let u_j: &[f64; 8] = as_arr(&u[_1j..]);
reim4_add_mul(&mut acc_0, u_j, as_arr(&v[_2j..]));
reim4_add_mul(&mut acc_1, u_j, as_arr(&v[_2j + 8..]));
}
dst[0..8].copy_from_slice(&acc_0);
dst[8..16].copy_from_slice(&acc_1);
}
#[inline(always)]
pub fn reim4_vec_mat2cols_2ndcol_product_ref(
nrows: usize,
dst: &mut [f64], // 8 doubles: [re2(4), im2(4)]
u: &[f64], // nrows * 8 doubles: [ur(4) | ui(4)] per row
v: &[f64], // nrows * 16 doubles: [x | x | br(4) | bi(4)] per row
) {
#[cfg(debug_assertions)]
{
assert!(
dst.len() >= 8,
"dst must be at least 8 doubles but is {}",
dst.len()
);
assert!(
u.len() >= nrows * 8,
"u must be at least nrows={} * 8 doubles but is {}",
nrows,
u.len()
);
assert!(
v.len() >= nrows * 16,
"v must be at least nrows={} * 16 doubles but is {}",
nrows,
v.len()
);
}
// zero accumulators
let mut acc: [f64; 8] = [0f64; 8];
for i in 0..nrows {
let _1j: usize = i << 3;
let _2j: usize = i << 4;
reim4_add_mul(&mut acc, as_arr(&u[_1j..]), as_arr(&v[_2j + 8..]));
}
dst[0..8].copy_from_slice(&acc);
}
#[inline(always)]
pub fn reim4_add_mul(dst: &mut [f64; 8], a: &[f64; 8], b: &[f64; 8]) {
for k in 0..4 {
let ar: f64 = a[k];
let br: f64 = b[k];
let ai: f64 = a[k + 4];
let bi: f64 = b[k + 4];
dst[k] += ar * br - ai * bi;
dst[k + 4] += ar * bi + ai * br;
}
}
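A small numeric check of the 4-wide complex multiply-accumulate above (illustrative values): lane k multiplies a[k] + i*a[k+4] by b[k] + i*b[k+4] and accumulates into dst[k] and dst[k+4].

#[cfg(test)]
mod reim4_add_mul_example {
    use super::*;

    #[test]
    fn single_lane_product() {
        // Lane 0 computes (1 + 2i)(3 + 4i) = -5 + 10i; the other lanes multiply zeros.
        let a: [f64; 8] = [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0];
        let b: [f64; 8] = [3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0];
        let mut acc: [f64; 8] = [0.0; 8];
        reim4_add_mul(&mut acc, &a, &b);
        assert_eq!(acc, [-5.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0]);
    }
}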

View File

@@ -0,0 +1,27 @@
mod arithmetic_ref;
pub use arithmetic_ref::*;
pub trait Reim4Extract1Blk {
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save1Blk {
fn reim4_save_1blk<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Save2Blks {
fn reim4_save_2blks<const OVERWRITE: bool>(m: usize, blk: usize, dst: &mut [f64], src: &[f64]);
}
pub trait Reim4Mat1ColProd {
fn reim4_mat1col_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
pub trait Reim4Mat2ColsProd {
fn reim4_mat2cols_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
pub trait Reim4Mat2Cols2ndColProd {
fn reim4_mat2cols_2ndcol_prod(nrows: usize, dst: &mut [f64], u: &[f64], v: &[f64]);
}
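A block-layout sketch for the extract/save pair re-exported above (m = 8, one row, block 1; all values illustrative): a block is 4 consecutive complex coefficients, stored as 4 reals followed by 4 imaginaries, and extracting then saving with OVERWRITE = true writes it back to the same positions.

#[cfg(test)]
mod reim4_block_layout_example {
    use super::*;

    #[test]
    fn extract_then_save_roundtrips_one_block() {
        let m: usize = 8;
        let src: Vec<f64> = (0..2 * m).map(|i| i as f64).collect(); // one reim row
        let mut blk: [f64; 8] = [0.0; 8];
        reim4_extract_1blk_from_reim_ref(m, 1, 1, &mut blk, &src);
        assert_eq!(&blk[..4], &src[4..8]); // real part of block 1
        assert_eq!(&blk[4..], &src[12..16]); // imaginary part of block 1
        let mut dst: Vec<f64> = vec![0.0; 2 * m];
        reim4_save_1blk_to_reim_ref::<true>(m, 1, &mut dst, &blk);
        assert_eq!(&dst[4..8], &src[4..8]);
        assert_eq!(&dst[12..16], &src[12..16]);
    }
}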

View File

@@ -0,0 +1,119 @@
use crate::{
layouts::{
Backend, ScalarZnx, ScalarZnxToRef, SvpPPol, SvpPPolToMut, SvpPPolToRef, VecZnx, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut,
},
reference::fft64::reim::{ReimAddMul, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimMul, ReimMulInplace, ReimZero},
};
pub fn svp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx,
R: SvpPPolToMut<BE>,
A: ScalarZnxToRef,
{
let mut res: SvpPPol<&mut [u8], BE> = res.to_mut();
let a: ScalarZnx<&[u8]> = a.to_ref();
BE::reim_from_znx(res.at_mut(res_col, 0), a.at(a_col, 0));
BE::reim_dft_execute(table, res.at_mut(res_col, 0));
}
pub fn svp_apply_dft<R, A, B, BE>(
table: &ReimFFTTable<f64>,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
) where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimZero + ReimFromZnx + ReimMulInplace,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxToRef,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
let out: &mut [f64] = res.at_mut(res_col, j);
BE::reim_from_znx(out, b.at(b_col, j));
BE::reim_dft_execute(table, out);
BE::reim_mul_inplace(out, ppol);
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimMul + ReimZero,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
BE::reim_mul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAddMul + ReimZero,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
let res_size: usize = res.size();
let b_size: usize = b.size();
let min_size: usize = res_size.min(b_size);
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..min_size {
BE::reim_addmul(res.at_mut(res_col, j), ppol, b.at(b_col, j));
}
for j in min_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
pub fn svp_apply_dft_to_dft_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimMulInplace,
R: VecZnxDftToMut<BE>,
A: SvpPPolToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: SvpPPol<&[u8], BE> = a.to_ref();
let ppol: &[f64] = a.at(a_col, 0);
for j in 0..res.size() {
BE::reim_mul_inplace(res.at_mut(res_col, j), ppol);
}
}

View File

@@ -0,0 +1,521 @@
use std::f64::consts::SQRT_2;
use crate::{
api::VecZnxBigAddNormal,
layouts::{
Backend, Module, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxBigToRef, VecZnxToMut, VecZnxToRef, ZnxView, ZnxViewMut,
},
oep::VecZnxBigAllocBytesImpl,
reference::{
vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_automorphism, vec_znx_automorphism_inplace, vec_znx_negate,
vec_znx_negate_inplace, vec_znx_normalize, vec_znx_sub, vec_znx_sub_ab_inplace, vec_znx_sub_ba_inplace,
},
znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly,
ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero, znx_add_normal_f64_ref,
},
},
source::Source,
};
pub fn vec_znx_big_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}
pub fn vec_znx_big_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_add_small<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAdd + ZnxCopy + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_add::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}
pub fn vec_znx_big_add_small_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAddInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_add_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
pub fn vec_znx_big_automorphism_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_big_automorphism<R, A, BE>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let a: VecZnxBig<&[u8], _> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_automorphism::<_, _, BE>(p, &mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_automorphism_inplace<R, BE>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
BE: Backend<ScalarBig = i64> + ZnxAutomorphism + ZnxCopy,
R: VecZnxBigToMut<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_automorphism_inplace::<_, BE>(p, &mut res_vznx, res_col, tmp);
}
pub fn vec_znx_big_negate<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxNegate + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let a: VecZnxBig<&[u8], _> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_negate::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
pub fn vec_znx_big_negate_inplace<R, BE>(res: &mut R, res_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
{
let res: VecZnxBig<&mut [u8], _> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_negate_inplace::<_, BE>(&mut res_vznx, res_col);
}
pub fn vec_znx_big_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_big_normalize<R, A, BE>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxBigToRef<BE>,
BE: Backend<ScalarBig = i64>
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep
+ ZnxZero,
{
let a: VecZnxBig<&[u8], _> = a.to_ref();
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_normalize::<_, _, BE>(basek, res, res_col, &a_vznx, a_col, carry);
}
pub fn vec_znx_big_add_normal_ref<R, B: Backend<ScalarBig = i64>>(
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxBigToMut<B>,
{
let mut res: VecZnxBig<&mut [u8], B> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}
pub fn test_vec_znx_big_add_normal<B>(module: &Module<B>)
where
Module<B>: VecZnxBigAddNormal<B>,
B: Backend<ScalarBig = i64> + VecZnxBigAllocBytesImpl<B>,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
let sqrt2: f64 = SQRT_2;
(0..cols).for_each(|col_i| {
let mut a: VecZnxBig<Vec<u8>, B> = VecZnxBig::alloc(n, cols, size);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
module.vec_znx_big_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(basek, col_i) * k_f64;
assert!(
(std - sigma * sqrt2).abs() < 0.1,
"std={} ~!= {}",
std,
sigma * sqrt2
);
}
})
});
}
/// R <- A - B
pub fn vec_znx_big_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, &b_vznx, b_col);
}
/// R <- R - A
pub fn vec_znx_big_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col);
}
/// R <- A - B
pub fn vec_znx_big_sub_small_a<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
B: VecZnxBigToRef<BE>,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let b: VecZnxBig<&[u8], BE> = b.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let b_vznx: VecZnx<&[u8]> = VecZnx {
data: b.data,
n: b.n,
cols: b.cols,
size: b.size,
max_size: b.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, a, a_col, &b_vznx, b_col);
}
/// R <- A - B
pub fn vec_znx_big_sub_small_b<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
R: VecZnxBigToMut<BE>,
A: VecZnxBigToRef<BE>,
B: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxBig<&[u8], BE> = a.to_ref();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
let a_vznx: VecZnx<&[u8]> = VecZnx {
data: a.data,
n: a.n,
cols: a.cols,
size: a.size,
max_size: a.max_size,
};
vec_znx_sub::<_, _, _, BE>(&mut res_vznx, res_col, &a_vznx, a_col, b, b_col);
}
/// R <- R - A
pub fn vec_znx_big_sub_small_a_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubABInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_sub_ab_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
/// R <- A - R
pub fn vec_znx_big_sub_small_b_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarBig = i64> + ZnxSubBAInplace + ZnxNegateInplace,
R: VecZnxBigToMut<BE>,
A: VecZnxToRef,
{
let res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut res_vznx: VecZnx<&mut [u8]> = VecZnx {
data: res.data,
n: res.n,
cols: res.cols,
size: res.size,
max_size: res.max_size,
};
vec_znx_sub_ba_inplace::<_, _, BE>(&mut res_vznx, res_col, a, a_col);
}
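Each wrapper above re-builds a VecZnx view from the VecZnxBig fields by hand. A possible de-duplication, sketched here only (not part of this commit, and assuming VecZnxBig has no Drop impl and no required fields beyond those used above), would be a small private helper:

fn big_as_vec_znx_mut<BE: Backend<ScalarBig = i64>>(a: VecZnxBig<&mut [u8], BE>) -> VecZnx<&mut [u8]> {
    // Reinterpret the big-coefficient view as a plain VecZnx so the generic
    // vec_znx_* routines can be reused without copying any data.
    VecZnx {
        data: a.data,
        n: a.n,
        cols: a.cols,
        size: a.size,
        max_size: a.max_size,
    }
}

With it, vec_znx_big_negate_inplace, for example, would reduce to a single call: vec_znx_negate_inplace::<_, BE>(&mut big_as_vec_znx_mut(res.to_mut()), res_col).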

View File

@@ -0,0 +1,369 @@
use bytemuck::cast_slice_mut;
use crate::{
layouts::{
Backend, Data, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos,
ZnxView, ZnxViewMut,
},
reference::{
fft64::reim::{
ReimAdd, ReimAddInplace, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimNegate,
ReimNegateInplace, ReimSub, ReimSubABInplace, ReimSubBAInplace, ReimToZnx, ReimToZnxInplace, ReimZero,
},
znx::ZnxZero,
},
};
pub fn vec_znx_dft_add<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAdd + ReimCopy + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
}
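// Worked example of the limb handling above (illustrative comment only):
// with a_size = 2, b_size = 4, res_size = 5: sum_size = 2 and cpy_size = 4,
// so limbs 0..2 of res receive a + b, limbs 2..4 are copies of b, and limb 4
// is zeroed since neither input carries data there. The symmetric branch
// handles a_size > b_size with the roles of a and b swapped.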
pub fn vec_znx_dft_add_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimAddInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_copy<R, A, BE>(step: usize, offset: usize, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimCopy + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let steps: usize = a.size().div_ceil(step);
let min_steps: usize = res.size().min(steps);
(0..min_steps).for_each(|j| {
let limb: usize = offset + j * step;
if limb < a.size() {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, limb));
}
});
(min_steps..res.size()).for_each(|j| {
BE::reim_zero(res.at_mut(res_col, j));
})
}
pub fn vec_znx_dft_apply<R, A, BE>(
table: &ReimFFTTable<f64>,
step: usize,
offset: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
) where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + ReimZero,
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(step > 0);
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let a_size: usize = a.size();
let res_size: usize = res.size();
let steps: usize = a_size.div_ceil(step);
let min_steps: usize = res_size.min(steps);
for j in 0..min_steps {
let limb = offset + j * step;
if limb < a_size {
BE::reim_from_znx(res.at_mut(res_col, j), a.at(a_col, limb));
BE::reim_dft_execute(table, res.at_mut(res_col, j));
}
}
(min_steps..res.size()).for_each(|j| {
BE::reim_zero(res.at_mut(res_col, j));
});
}
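// Illustrative note (comment only): the (step, offset) parameters gather a
// strided subset of a's limbs into consecutive limbs of res. For example,
// step = 2, offset = 1 and a.size() = 6 maps a's limbs 1, 3, 5 onto res limbs
// 0, 1, 2, applying the forward FFT to each; res limbs from
// min(res.size(), a.size().div_ceil(step)) onwards are zeroed.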
pub fn vec_znx_idft_apply<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64>
+ ReimDFTExecute<ReimIFFTTable<f64>, f64>
+ ReimCopy
+ ReimToZnxInplace
+ ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let min_size: usize = res_size.min(a.size());
let divisor: f64 = table.m() as f64;
for j in 0..min_size {
let res_slice_f64: &mut [f64] = cast_slice_mut(res.at_mut(res_col, j));
BE::reim_copy(res_slice_f64, a.at(a_col, j));
BE::reim_dft_execute(table, res_slice_f64);
BE::reim_to_znx_inplace(res_slice_f64, divisor);
}
for j in min_size..res_size {
BE::znx_zero(res.at_mut(res_col, j));
}
}
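// Note (comment only): the inverse transform above reuses res's i64 limbs as
// f64 scratch through cast_slice_mut (ScalarBig = i64 has the same width as
// f64), and reim_to_znx_inplace rounds back to integers while dividing by
// m = table.m(), the scaling factor of this un-normalized IFFT.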
pub fn vec_znx_idft_apply_tmpa<R, A, BE>(table: &ReimIFFTTable<f64>, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnx + ZnxZero,
R: VecZnxBigToMut<BE>,
A: VecZnxDftToMut<BE>,
{
let mut res: VecZnxBig<&mut [u8], BE> = res.to_mut();
let mut a: VecZnxDft<&mut [u8], BE> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
assert_eq!(a.n(), res.n());
}
let res_size = res.size();
let min_size: usize = res_size.min(a.size());
let divisor: f64 = table.m() as f64;
for j in 0..min_size {
BE::reim_dft_execute(table, a.at_mut(a_col, j));
BE::reim_to_znx(res.at_mut(res_col, j), divisor, a.at(a_col, j));
}
for j in min_size..res_size {
BE::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_idft_apply_consume<D: Data, BE>(table: &ReimIFFTTable<f64>, mut res: VecZnxDft<D, BE>) -> VecZnxBig<D, BE>
where
BE: Backend<ScalarPrep = f64, ScalarBig = i64> + ReimDFTExecute<ReimIFFTTable<f64>, f64> + ReimToZnxInplace,
VecZnxDft<D, BE>: VecZnxDftToMut<BE>,
{
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(table.m() << 1, res.n());
}
let divisor: f64 = table.m() as f64;
for i in 0..res.cols() {
for j in 0..res.size() {
BE::reim_dft_execute(table, res.at_mut(i, j));
BE::reim_to_znx_inplace(res.at_mut(i, j), divisor);
}
}
}
res.into_big()
}
pub fn vec_znx_dft_sub<R, A, B, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSub + ReimNegate + ReimZero + ReimCopy,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
B: VecZnxDftToRef<BE>,
{
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let b: VecZnxDft<&[u8], BE> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_negate(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
BE::reim_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
BE::reim_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_dft_sub_ab_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubABInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_dft_sub_ba_inplace<R, A, BE>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
BE: Backend<ScalarPrep = f64> + ReimSubBAInplace + ReimNegateInplace,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
{
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
BE::reim_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
BE::reim_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn vec_znx_dft_zero<R, BE>(res: &mut R)
where
R: VecZnxDftToMut<BE>,
BE: Backend<ScalarPrep = f64> + ReimZero,
{
BE::reim_zero(res.to_mut().raw_mut());
}

View File

@@ -0,0 +1,365 @@
use crate::{
cast_mut,
layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
oep::VecZnxDftAllocBytesImpl,
reference::fft64::{
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
reim4::{Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks},
vec_znx_dft::vec_znx_dft_apply,
},
};
use crate::layouts::{Backend, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, VmpPMat, VmpPMatToRef, ZnxInfos};
pub fn vmp_prepare_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vmp_prepare<R, A, BE>(table: &ReimFFTTable<f64>, pmat: &mut R, mat: &A, tmp: &mut [f64])
where
BE: Backend<ScalarPrep = f64> + ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
R: VmpPMatToMut<BE>,
A: MatZnxToRef,
{
let mut res: crate::layouts::VmpPMat<&mut [u8], BE> = pmat.to_mut();
let a: MatZnx<&[u8]> = mat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols_in(),
a.cols_in(),
"res.cols_in: {} != a.cols_in: {}",
res.cols_in(),
a.cols_in()
);
assert_eq!(
res.rows(),
a.rows(),
"res.rows: {} != a.rows: {}",
res.rows(),
a.rows()
);
assert_eq!(
res.cols_out(),
a.cols_out(),
"res.cols_out: {} != a.cols_out: {}",
res.cols_out(),
a.cols_out()
);
assert_eq!(
res.size(),
a.size(),
"res.size: {} != a.size: {}",
res.size(),
a.size()
);
}
let nrows: usize = a.cols_in() * a.rows();
let ncols: usize = a.cols_out() * a.size();
vmp_prepare_core::<BE>(table, res.raw_mut(), a.raw(), nrows, ncols, tmp);
}
pub(crate) fn vmp_prepare_core<REIM>(
table: &ReimFFTTable<f64>,
pmat: &mut [f64],
mat: &[i64],
nrows: usize,
ncols: usize,
tmp: &mut [f64],
) where
REIM: ReimDFTExecute<ReimFFTTable<f64>, f64> + ReimFromZnx + Reim4Extract1Blk,
{
let m: usize = table.m();
let n: usize = m << 1;
#[cfg(debug_assertions)]
{
assert!(n >= 8);
assert_eq!(mat.len(), n * nrows * ncols);
assert_eq!(pmat.len(), n * nrows * ncols);
assert_eq!(tmp.len(), vmp_prepare_tmp_bytes(n) / size_of::<i64>())
}
let offset: usize = nrows * ncols * 8;
for row_i in 0..nrows {
for col_i in 0..ncols {
let pos: usize = n * (row_i * ncols + col_i);
REIM::reim_from_znx(tmp, &mat[pos..pos + n]);
REIM::reim_dft_execute(table, tmp);
let dst: &mut [f64] = if col_i == (ncols - 1) && !ncols.is_multiple_of(2) {
&mut pmat[col_i * nrows * 8 + row_i * 8..]
} else {
&mut pmat[(col_i / 2) * (nrows * 16) + row_i * 16 + (col_i % 2) * 8..]
};
for blk_i in 0..m >> 2 {
REIM::reim4_extract_1blk(m, 1, blk_i, &mut dst[blk_i * offset..], tmp);
}
}
}
}
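// Layout note (comment only): pmat is split into m/4 "reim4" blocks of
// 8 * nrows * ncols f64 each (the `offset` stride). Inside a block, columns
// are stored in pairs: pair p = col/2 spans nrows rows of 16 f64 (8 for the
// even column, 8 for the odd one), and a trailing unpaired column (when ncols
// is odd) is appended as nrows rows of 8 f64. E.g. nrows = 2, ncols = 3:
// [row0 col0 | row0 col1 | row1 col0 | row1 col1 | row0 col2 | row1 col2],
// each cell being 8 consecutive f64. This matches the col_offset indexing used
// by vmp_apply_dft_to_dft_core below.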
pub fn vmp_apply_dft_tmp_bytes(n: usize, a_size: usize, prows: usize, pcols_in: usize) -> usize {
let row_max: usize = (a_size).min(prows);
(16 + (n + 8) * row_max * pcols_in) * size_of::<f64>()
}
pub fn vmp_apply_dft<R, A, M, BE>(table: &ReimFFTTable<f64>, res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ VecZnxDftAllocBytesImpl<BE>
+ ReimDFTExecute<ReimFFTTable<f64>, f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk
+ ReimFromZnx,
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
M: VmpPMatToRef<BE>,
{
let a: VecZnx<&[u8]> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
let n: usize = a.n();
let cols: usize = pmat.cols_in();
let size: usize = a.size().min(pmat.rows());
#[cfg(debug_assertions)]
{
assert!(tmp_bytes.len() >= vmp_apply_dft_tmp_bytes(n, size, pmat.rows(), cols));
assert!(a.cols() <= cols);
}
let (data, tmp_bytes) = tmp_bytes.split_at_mut(BE::vec_znx_dft_alloc_bytes_impl(n, cols, size));
let mut a_dft: VecZnxDft<&mut [u8], BE> = VecZnxDft::from_data(cast_mut(data), n, cols, size);
let offset: usize = cols - a.cols();
for j in 0..cols {
vec_znx_dft_apply(table, 1, 0, &mut a_dft, j, &a, offset + j);
}
vmp_apply_dft_to_dft(res, &a_dft, &pmat, tmp_bytes);
}
pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usize) -> usize {
let row_max: usize = (a_size).min(prows);
(16 + 8 * row_max * pcols_in) * size_of::<f64>()
}
pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
M: VmpPMatToRef<BE>,
{
use crate::layouts::{ZnxView, ZnxViewMut};
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), pmat.n());
assert_eq!(a.n(), pmat.n());
assert_eq!(res.cols(), pmat.cols_out());
assert_eq!(a.cols(), pmat.cols_in());
}
let n: usize = res.n();
let nrows: usize = pmat.cols_in() * pmat.rows();
let ncols: usize = pmat.cols_out() * pmat.size();
let pmat_raw: &[f64] = pmat.raw();
let a_raw: &[f64] = a.raw();
let res_raw: &mut [f64] = res.raw_mut();
vmp_apply_dft_to_dft_core::<true, BE>(n, res_raw, a_raw, pmat_raw, 0, nrows, ncols, tmp_bytes)
}
pub fn vmp_apply_dft_to_dft_add<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, limb_offset: usize, tmp_bytes: &mut [f64])
where
BE: Backend<ScalarPrep = f64>
+ ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
R: VecZnxDftToMut<BE>,
A: VecZnxDftToRef<BE>,
M: VmpPMatToRef<BE>,
{
use crate::layouts::{ZnxView, ZnxViewMut};
let mut res: VecZnxDft<&mut [u8], BE> = res.to_mut();
let a: VecZnxDft<&[u8], BE> = a.to_ref();
let pmat: VmpPMat<&[u8], BE> = pmat.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), pmat.n());
assert_eq!(a.n(), pmat.n());
assert_eq!(res.cols(), pmat.cols_out());
assert_eq!(a.cols(), pmat.cols_in());
}
let n: usize = res.n();
let nrows: usize = pmat.cols_in() * pmat.rows();
let ncols: usize = pmat.cols_out() * pmat.size();
let pmat_raw: &[f64] = pmat.raw();
let a_raw: &[f64] = a.raw();
let res_raw: &mut [f64] = res.raw_mut();
vmp_apply_dft_to_dft_core::<false, BE>(
n,
res_raw,
a_raw,
pmat_raw,
limb_offset,
nrows,
ncols,
tmp_bytes,
)
}
#[allow(clippy::too_many_arguments)]
fn vmp_apply_dft_to_dft_core<const OVERWRITE: bool, REIM>(
n: usize,
res: &mut [f64],
a: &[f64],
pmat: &[f64],
limb_offset: usize,
nrows: usize,
ncols: usize,
tmp_bytes: &mut [f64],
) where
REIM: ReimZero
+ Reim4Extract1Blk
+ Reim4Mat1ColProd
+ Reim4Mat2Cols2ndColProd
+ Reim4Mat2ColsProd
+ Reim4Save2Blks
+ Reim4Save1Blk,
{
#[cfg(debug_assertions)]
{
assert!(n >= 8);
assert!(n.is_power_of_two());
assert_eq!(pmat.len(), n * nrows * ncols);
assert!(res.len() & (n - 1) == 0);
assert!(a.len() & (n - 1) == 0);
}
let a_size: usize = a.len() / n;
let res_size: usize = res.len() / n;
let m: usize = n >> 1;
let (mat2cols_output, extracted_blk) = tmp_bytes.split_at_mut(16);
let row_max: usize = nrows.min(a_size);
let col_max: usize = ncols.min(res_size);
if limb_offset >= col_max {
if OVERWRITE {
REIM::reim_zero(res);
}
return;
}
for blk_i in 0..(m >> 2) {
let mat_blk_start: &[f64] = &pmat[blk_i * (8 * nrows * ncols)..];
REIM::reim4_extract_1blk(m, row_max, blk_i, extracted_blk, a);
if limb_offset.is_multiple_of(2) {
for (col_res, col_pmat) in (0..).step_by(2).zip((limb_offset..col_max - 1).step_by(2)) {
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
} else {
let col_offset: usize = (limb_offset - 1) * (8 * nrows);
REIM::reim4_mat2cols_2ndcol_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_1blk::<OVERWRITE>(m, blk_i, res, mat2cols_output);
for (col_res, col_pmat) in (1..)
.step_by(2)
.zip((limb_offset + 1..col_max - 1).step_by(2))
{
let col_offset: usize = col_pmat * (8 * nrows);
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
REIM::reim4_save_2blks::<OVERWRITE>(m, blk_i, &mut res[col_res * n..], mat2cols_output);
}
}
if !col_max.is_multiple_of(2) {
let last_col: usize = col_max - 1;
let col_offset: usize = last_col * (8 * nrows);
if last_col >= limb_offset {
if ncols == col_max {
REIM::reim4_mat1col_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
} else {
REIM::reim4_mat2cols_prod(
row_max,
mat2cols_output,
extracted_blk,
&mat_blk_start[col_offset..],
);
}
REIM::reim4_save_1blk::<OVERWRITE>(
m,
blk_i,
&mut res[(last_col - limb_offset) * n..],
mat2cols_output,
);
}
}
}
REIM::reim_zero(&mut res[col_max * n..]);
}
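// Note (comment only): the core above walks the reim4 blocks of pmat and, for
// each block, multiplies the rows extracted from a against the prepared
// columns two at a time (reim4_mat2cols_prod), saving two result blocks per
// pair; an odd trailing output column is saved from a single block, using a
// one-column product when it is also the last, unpaired pmat column.
// `limb_offset` selects which pmat column feeds res limb 0; when it is odd,
// the first output limb comes from the second column of a stored pair
// (reim4_mat2cols_2ndcol_prod) before the remaining columns are paired up.
// With OVERWRITE = true the results replace res, otherwise (the *_add entry
// point) they are accumulated into it.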

View File

@@ -0,0 +1,4 @@
pub mod fft64;
pub mod vec_znx;
pub mod zn;
pub mod znx;

View File

@@ -0,0 +1,177 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxAdd, VecZnxAddInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
source::Source,
};
pub fn vec_znx_add<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_add_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxAddInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_add_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn bench_vec_znx_add<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAdd + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_add(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_add_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_add_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAddInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_add_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,57 @@
use crate::{
layouts::{ScalarZnx, ScalarZnxToRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAdd, ZnxAddInplace, ZnxCopy, ZnxZero},
};
pub fn vec_znx_add_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxAdd + ZnxCopy + ZnxZero,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let min_size: usize = b.size().min(res.size());
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
}
for j in 0..min_size {
if j == b_limb {
ZNXARI::znx_add(res.at_mut(res_col, j), a.at(a_col, 0), b.at(b_col, j));
} else {
ZNXARI::znx_copy(res.at_mut(res_col, j), b.at(b_col, j));
}
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_add_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxAddInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(res_limb < res.size());
}
ZNXARI::znx_add_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -0,0 +1,150 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAutomorphism, VecZnxAutomorphismInplace,
VecZnxAutomorphismInplaceTmpBytes,
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxAutomorphism, ZnxCopy, ZnxZero},
source::Source,
};
pub fn vec_znx_automorphism_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_automorphism<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxAutomorphism + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
use crate::layouts::ZnxInfos;
assert_eq!(a.n(), res.n());
}
let min_size: usize = res.size().min(a.size());
for j in 0..min_size {
ZNXARI::znx_automorphism(p, res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_automorphism_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxAutomorphism + ZnxCopy,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_automorphism(p, tmp, res.at(res_col, j));
ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
}
}
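// Note (comment only, assuming znx_automorphism implements the Galois map
// a(X) -> a(X^p) in Z[X]/(X^N + 1), with p odd): the out-of-place variant
// applies the map limb by limb and zeroes any extra limbs of res, while the
// in-place variant needs an n-coefficient scratch buffer because the
// coefficient permutation cannot be done in place limb by limb.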
pub fn bench_vec_znx_automorphism<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_automorphism::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAutomorphism + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_automorphism(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_automorphism_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxAutomorphismInplace<B> + VecZnxAutomorphismInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_automorphism_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxAutomorphismInplace<B> + ModuleNew<B> + VecZnxAutomorphismInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_automorphism_inplace_tmp_bytes());
// Fill res with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_automorphism_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,32 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxCopy, ZnxZero},
};
pub fn vec_znx_copy<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let res_size = res.size();
let a_size = a.size();
let min_size = res_size.min(a_size);
for j in 0..min_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}

View File

@@ -0,0 +1,49 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
reference::{
vec_znx::{vec_znx_rotate_inplace, vec_znx_switch_ring},
znx::{ZnxCopy, ZnxRotate, ZnxSwitchRing, ZnxZero},
},
};
pub fn vec_znx_merge_rings_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_merge_rings<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &[A], a_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxRotate + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_out, n_in) = (res.n(), a[0].to_ref().n());
#[cfg(debug_assertions)]
{
assert_eq!(tmp.len(), res.n());
debug_assert!(
n_out > n_in,
"invalid a: output ring degree should be greater"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),
n_in,
"invalid input a: all VecZnx must have the same degree"
)
});
assert!(n_out.is_multiple_of(n_in));
assert_eq!(a.len(), n_out / n_in);
}
a.iter().for_each(|ai| {
vec_znx_switch_ring::<_, _, ZNXARI>(&mut res, res_col, ai, a_col);
vec_znx_rotate_inplace::<_, ZNXARI>(-1, &mut res, res_col, tmp);
});
vec_znx_rotate_inplace::<_, ZNXARI>(a.len() as i64, &mut res, res_col, tmp);
}
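// Note (comment only, hedged interpretation): mechanically, each input ai of
// degree n_in is embedded into the degree-n_out output with
// vec_znx_switch_ring, followed by a rotation by -1; the final rotation by
// a.len() undoes the accumulated rotations. Assuming vec_znx_switch_ring only
// writes the stride-(n_out/n_in) coefficient positions, this interleaves the
// a.len() = n_out / n_in inputs into distinct cosets of the larger ring,
// acting as the intended inverse of vec_znx_split_ring. tmp must hold n_out
// coefficients.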

View File

@@ -0,0 +1,31 @@
mod add;
mod add_scalar;
mod automorphism;
mod copy;
mod merge_rings;
mod mul_xp_minus_one;
mod negate;
mod normalize;
mod rotate;
mod sampling;
mod shift;
mod split_ring;
mod sub;
mod sub_scalar;
mod switch_ring;
pub use add::*;
pub use add_scalar::*;
pub use automorphism::*;
pub use copy::*;
pub use merge_rings::*;
pub use mul_xp_minus_one::*;
pub use negate::*;
pub use normalize::*;
pub use rotate::*;
pub use sampling::*;
pub use shift::*;
pub use split_ring::*;
pub use sub::*;
pub use sub_scalar::*;
pub use switch_ring::*;

View File

@@ -0,0 +1,136 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace,
VecZnxMulXpMinusOneInplaceTmpBytes,
},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::{vec_znx_rotate, vec_znx_sub_ab_inplace},
znx::{ZnxNegate, ZnxRotate, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
},
source::Source,
};
pub fn vec_znx_mul_xp_minus_one_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_mul_xp_minus_one<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero + ZnxSubABInplace,
{
vec_znx_rotate::<_, _, ZNXARI>(p, res, res_col, a, a_col);
vec_znx_sub_ab_inplace::<_, _, ZNXARI>(res, res_col, a, a_col);
}
pub fn vec_znx_mul_xp_minus_one_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxNegate + ZnxSubBAInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), tmp);
}
}
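// Note (comment only): both variants compute res = (X^p - 1) * a in
// Z[X]/(X^N + 1): the out-of-place path as rotate(p, a) followed by
// res -= a, the in-place path as tmp = rotate(p, res) followed by
// res = tmp - res, limb by limb. E.g. with p = 1 and a = 1 + 2X (N = 4),
// (X - 1) * a = -1 - X + 2X^2.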
pub fn bench_vec_znx_mul_xp_minus_one<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxMulXpMinusOne + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_mul_xp_minus_one(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_mul_xp_minus_one_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxMulXpMinusOneInplace<B> + VecZnxMulXpMinusOneInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_mul_xp_minus_one_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxMulXpMinusOneInplace<B> + ModuleNew<B> + VecZnxMulXpMinusOneInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_mul_xp_minus_one_inplace_tmp_bytes());
// Fill res with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_mul_xp_minus_one_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,131 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxNegate, VecZnxNegateInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxNegate, ZnxNegateInplace, ZnxZero},
source::Source,
};
pub fn vec_znx_negate<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxNegate + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let min_size: usize = res.size().min(a.size());
for j in 0..min_size {
ZNXARI::znx_negate(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_negate_inplace<R, ZNXARI>(res: &mut R, res_col: usize)
where
R: VecZnxToMut,
ZNXARI: ZnxNegateInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn bench_vec_znx_negate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNegate + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_negate(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_negate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_negate_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNegateInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_negate_inplace(&mut a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,193 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{
ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
source::Source,
};
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_normalize<R, A, ZNXARI>(basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeFinalStep
+ ZnxNormalizeFirstStep,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
}
let res_size: usize = res.size();
let a_size = a.size();
if a_size > res_size {
for j in (res_size..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, 0, a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, 0, a.at(a_col, j), carry);
}
}
for j in (1..res_size).rev() {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
} else {
for j in (0..a_size).rev() {
if j == a_size - 1 {
ZNXARI::znx_normalize_first_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step(basek, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
}
}
for j in a_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(carry.len() >= res.n());
}
let res_size: usize = res.size();
for j in (0..res_size).rev() {
if j == res_size - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, 0, res.at_mut(res_col, j), carry);
}
}
}
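#[cfg(test)]
mod normalize_digit_example {
    // Hedged, std-only illustration (not part of the reference API): the
    // limb-wise normalization above amounts, per coefficient, to a balanced
    // base-2^basek digit decomposition with the carry flowing from the least
    // significant limb (index size-1) towards limb 0. The exact tie-breaking
    // convention of the reference steps may differ at +/- 2^(basek-1).
    #[test]
    fn balanced_base_2k_digits() {
        let basek: u32 = 4; // base 16, for illustration only
        let x: i64 = 1000;
        let base: i64 = 1 << basek;
        let half: i64 = base >> 1;
        let mut digits = [0i64; 3]; // digits[0] is the most significant limb
        let mut carry: i64 = x;
        for d in digits.iter_mut().rev() {
            let mut r = carry & (base - 1);
            carry >>= basek;
            if r >= half {
                r -= base; // balance the digit into [-2^(basek-1), 2^(basek-1))
                carry += 1;
            }
            *d = r;
        }
        assert_eq!(carry, 0);
        assert_eq!(digits, [4, -1, -8]);
        // Recompose: folding the digits back in base 2^basek recovers x.
        let recomposed = digits.iter().fold(0i64, |acc, &d| (acc << basek) + d);
        assert_eq!(recomposed, x);
    }
}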
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_normalize(basek, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_normalize_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
move || {
for i in 0..cols {
module.vec_znx_normalize_inplace(basek, &mut a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,148 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxCopy, ZnxRotate, ZnxZero},
source::Source,
};
pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rotate<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxRotate + ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let min_size: usize = res_size.min(a_size);
for j in 0..min_size {
ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j))
}
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_rotate_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxRotate + ZnxCopy,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), tmp.len());
}
for j in 0..res.size() {
ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
}
}
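#[cfg(test)]
mod rotate_example {
    // Hedged, std-only sketch (not part of the reference API), assuming
    // znx_rotate(p, res, a) implements res = X^p * a in Z[X]/(X^N + 1); the
    // vec_znx_mul_xp_minus_one helper in this crate is built on that reading.
    fn negacyclic_rotate(p: i64, a: &[i64]) -> Vec<i64> {
        let n = a.len() as i64;
        let mut res = vec![0i64; a.len()];
        for (i, &c) in a.iter().enumerate() {
            let mut k = (i as i64 + p).rem_euclid(2 * n);
            let mut v = c;
            if k >= n {
                k -= n;
                v = -v; // X^N = -1 in the negacyclic ring
            }
            res[k as usize] = v;
        }
        res
    }

    #[test]
    fn rotate_by_one_and_by_n() {
        // (1 + 2X) * X = X + 2X^2 in Z[X]/(X^4 + 1)
        assert_eq!(negacyclic_rotate(1, &[1, 2, 0, 0]), vec![0, 1, 2, 0]);
        // multiplying by X^4 = -1 negates every coefficient
        assert_eq!(negacyclic_rotate(4, &[1, 2, 0, 0]), vec![-1, -2, 0, 0]);
    }
}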
pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let group_name: String = format!("vec_znx_rotate::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRotate + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
// Fill a with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rotate(-7, &mut res, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rotate_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rotate_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRotateInplace<B> + ModuleNew<B> + VecZnxRotateInplaceTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
// Fill res with random i64
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,64 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut},
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
source::Source,
};
pub fn vec_znx_fill_uniform_ref<R>(basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, res.at_mut(res_col, j), source)
}
}
pub fn vec_znx_fill_normal_ref<R>(
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
sigma: f64,
bound: f64,
source: &mut Source,
) where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_fill_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}
pub fn vec_znx_add_normal_ref<R>(basek: usize, res: &mut R, res_col: usize, k: usize, sigma: f64, bound: f64, source: &mut Source)
where
R: VecZnxToMut,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
let scale: f64 = (1 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
res.at_mut(res_col, limb),
sigma * scale,
bound * scale,
source,
)
}
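// Note (comment only): the noise is injected into the limb that carries the
// precision 2^-k, i.e. limb = ceil(k / basek) - 1, and sigma/bound are
// rescaled by 2^((limb + 1) * basek - k) to account for k not being a
// multiple of basek. E.g. basek = 17, k = 40: limb = 2 and the scale is
// 2^(51 - 40) = 2^11.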

View File

@@ -0,0 +1,672 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::vec_znx_copy,
znx::{
ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
ZnxZero,
},
},
source::Source,
};
pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_lsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
if steps >= size {
for j in 0..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
// Inplace shift of the limbs by k/basek
if steps > 0 {
let start: usize = n * res_col;
let end: usize = start + n;
let slice_size: usize = n * cols;
let res_raw: &mut [i64] = res.raw_mut();
(0..size - steps).for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
ZNXARI::znx_copy(
&mut lhs[start + j * slice_size..end + j * slice_size],
&rhs[start..end],
);
});
for j in size - steps..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
// Inplace normalization with left shift of k % basek
if !k.is_multiple_of(basek) {
for j in (0..size - steps).rev() {
if j == size - steps - 1 {
ZNXARI::znx_normalize_first_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
} else if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, k_rem, res.at_mut(res_col, j), carry);
}
}
}
}
pub fn vec_znx_lsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxNormalizeFirstStep + ZnxCopy + ZnxNormalizeFinalStep,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
let res_size: usize = res.size();
let a_size = a.size();
let steps: usize = k / basek;
let k_rem: usize = k % basek;
if steps >= res_size.min(a_size) {
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
let min_size: usize = a_size.min(res_size) - steps;
// Simply a left shifted normalization of the limbs
// by k/basek and intra-limb by k % basek
if !k.is_multiple_of(basek) {
for j in (0..min_size).rev() {
if j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else if j == 0 {
ZNXARI::znx_normalize_final_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
k_rem,
res.at_mut(res_col, j),
a.at(a_col, j + steps),
carry,
);
}
}
} else {
// If k % basek = 0, then this is simply a copy.
for j in (0..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
}
}
// Zeroes bottom
for j in min_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_rsh_inplace<R, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let n: usize = res.n();
let cols: usize = res.cols();
let size: usize = res.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
if k == 0 {
return;
}
if steps >= size {
for j in 0..size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
let start: usize = n * res_col;
let end: usize = start + n;
let slice_size: usize = n * cols;
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek - k % basek.
// This allows re-use of the efficient normalization code, avoids
// overflows and produces an output that is normalized.
steps += 1;
// All limbs of a that would fall outside of the limbs of res are discarded,
// but the carry still needs to be computed.
(size - steps..size).rev().for_each(|j| {
if j == size - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, res.at(res_col, j), carry);
}
});
// Continues with shifted normalization
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
let rhs_slice: &mut [i64] = &mut rhs[start..end];
let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
ZNXARI::znx_normalize_middle_step(basek, basek - k_rem, rhs_slice, lhs_slice, carry);
});
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
// Shift by multiples of basek
let res_raw: &mut [i64] = res.raw_mut();
(steps..size).rev().for_each(|j| {
let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
ZNXARI::znx_copy(
&mut rhs[start..end],
&lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
);
});
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
}
}
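// Worked example of the non-aligned branch above (comment only): with
// basek = 10 and k = 13, steps becomes 1 + 1 = 2 and k_rem = 3, so the limbs
// are conceptually shifted right by 2 * 10 = 20 bits and then re-normalized
// with an intra-limb left shift of basek - k_rem = 7 bits, for a net right
// shift of 20 - 7 = 13 = k bits, without risking overflow in the intermediate
// limbs.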
pub fn vec_znx_rsh<R, A, ZNXARI>(basek: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxZero
+ ZnxCopy
+ ZnxNormalizeFirstStepCarryOnly
+ ZnxNormalizeMiddleStepCarryOnly
+ ZnxNormalizeFirstStep
+ ZnxNormalizeMiddleStep
+ ZnxNormalizeMiddleStepInplace
+ ZnxNormalizeFirstStepInplace
+ ZnxNormalizeFinalStepInplace,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: VecZnx<&[u8]> = a.to_ref();
let res_size: usize = res.size();
let a_size: usize = a.size();
let mut steps: usize = k / basek;
let k_rem: usize = k % basek;
if k == 0 {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
return;
}
if steps >= res_size {
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
return;
}
if !k.is_multiple_of(basek) {
// We rsh by an additional basek and then lsh by basek - k % basek.
// This allows re-use of the efficient normalization code, avoids
// overflows and produces an output that is normalized.
steps += 1;
// All limbs of a that are moved outside of the limbs of res are discarded,
// but the carry still needs to be computed.
for j in (res_size..a_size + steps).rev() {
if j == a_size + steps - 1 {
ZNXARI::znx_normalize_first_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
} else {
ZNXARI::znx_normalize_middle_step_carry_only(basek, basek - k_rem, a.at(a_col, j - steps), carry);
}
}
// Avoids overflowing the limbs of res
let min_size: usize = res_size.min(a_size + steps);
// Zeroes lower limbs of res if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Continues with shifted normalization
for j in (steps..min_size).rev() {
// Case where no limb of a was previously discarded
if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
ZNXARI::znx_normalize_first_step(
basek,
basek - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
} else {
ZNXARI::znx_normalize_middle_step(
basek,
basek - k_rem,
res.at_mut(res_col, j),
a.at(a_col, j - steps),
carry,
);
}
}
// Propagates carry on the rest of the limbs of res
for j in (0..steps).rev() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
if j == 0 {
ZNXARI::znx_normalize_final_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
} else {
ZNXARI::znx_normalize_middle_step_inplace(basek, basek - k_rem, res.at_mut(res_col, j), carry);
}
}
} else {
let min_size: usize = res_size.min(a_size + steps);
// Zeroes the top
(0..steps).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
// Shift a into res, up to the maximum
for j in (steps..min_size).rev() {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
}
// Zeroes bottom if a_size + steps < res_size
(min_size..res_size).for_each(|j| {
ZNXARI::znx_zero(res.at_mut(res_col, j));
});
}
}
pub fn bench_vec_znx_lsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxLshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
        // Fill a and b with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_lsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_lsh<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_lsh::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxLsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
        // Fill a and res with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_lsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
        // Fill a and b with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rsh_inplace(basek, basek - 1, &mut b, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_rsh<B: Backend>(c: &mut Criterion, label: &str)
where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let group_name: String = format!("vec_znx_rsh::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxRsh<B> + ModuleNew<B>,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());
        // Fill a and res with random i64
a.fill_uniform(50, &mut source);
res.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_rsh(basek, basek - 1, &mut res, i, &a, i, scratch.borrow());
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
#[cfg(test)]
mod tests {
use crate::{
layouts::{FillUniform, VecZnx, ZnxView},
reference::{
vec_znx::{
vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
vec_znx_sub_ab_inplace,
},
znx::ZnxRef,
},
source::Source,
};
#[test]
fn test_vec_znx_lsh() {
let n: usize = 8;
let cols: usize = 2;
let size: usize = 7;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut source: Source = Source::new([0u8; 32]);
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
for k in 0..256 {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
for i in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, i, &mut carry);
vec_znx_lsh::<_, _, ZnxRef>(basek, k, &mut res_test, i, &a, i, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, i, &mut carry);
}
assert_eq!(res_ref, res_test);
}
}
#[test]
fn test_vec_znx_rsh() {
let n: usize = 8;
let cols: usize = 2;
let res_size: usize = 7;
let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
let mut carry: Vec<i64> = vec![0i64; n];
let basek: usize = 50;
let mut source: Source = Source::new([0u8; 32]);
let zero: Vec<i64> = vec![0i64; n];
for a_size in [res_size - 1, res_size, res_size + 1] {
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);
for k in 0..res_size * basek {
a.fill_uniform(50, &mut source);
for i in 0..cols {
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut a, i, &mut carry);
vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
}
res_test.fill_uniform(50, &mut source);
for j in 0..cols {
vec_znx_rsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_rsh::<_, _, ZnxRef>(basek, k, &mut res_test, j, &a, j, &mut carry);
}
for j in 0..cols {
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_ref, j, &mut carry);
vec_znx_lsh_inplace::<_, ZnxRef>(basek, k, &mut res_test, j, &mut carry);
}
                // Case where res has enough limbs to fully store the right shift of a without any loss.
                // In this case we can check exact equality.
if a_size + k.div_ceil(basek) <= res_size {
assert_eq!(res_ref, res_test);
for i in 0..cols {
for j in 0..a_size {
assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j);
assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j);
}
for j in a_size..res_size {
assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j);
assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j);
}
}
                // Some loss occurs, either because a initially has more precision than res,
                // or because storing the right shift of a requires more precision than res
                // provides.
} else {
for j in 0..cols {
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
vec_znx_sub_ab_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_ref, j, &mut carry);
vec_znx_normalize_inplace::<_, ZnxRef>(basek, &mut res_test, j, &mut carry);
assert!(res_ref.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
assert!(res_test.std(basek, j).log2() - (k as f64) <= (k * basek) as f64);
}
}
}
}
}
}

View File

@@ -0,0 +1,62 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxRotate, ZnxSwitchRing, ZnxZero},
};
pub fn vec_znx_split_ring_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn vec_znx_split_ring<R, A, ZNXARI>(res: &mut [R], res_col: usize, a: &A, a_col: usize, tmp: &mut [i64])
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSwitchRing + ZnxRotate + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let a_size = a.size();
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
#[cfg(debug_assertions)]
{
assert_eq!(tmp.len(), a.n());
        assert!(
            n_out < n_in,
            "invalid dimensions: output ring degree must be smaller than the input ring degree"
        );
        res[1..].iter_mut().for_each(|bi| {
            assert_eq!(
                bi.to_mut().n(),
                n_out,
                "invalid res: all output VecZnx must have the same degree"
            )
        });
assert!(n_in.is_multiple_of(n_out));
assert_eq!(res.len(), n_in / n_out);
}
res.iter_mut().enumerate().for_each(|(i, bi)| {
let mut bi: VecZnx<&mut [u8]> = bi.to_mut();
let min_size = bi.size().min(a_size);
if i == 0 {
for j in 0..min_size {
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), a.at(a_col, j));
}
} else {
for j in 0..min_size {
ZNXARI::znx_rotate(-(i as i64), tmp, a.at(a_col, j));
ZNXARI::znx_switch_ring(bi.at_mut(res_col, j), tmp);
}
}
for j in min_size..bi.size() {
ZNXARI::znx_zero(bi.at_mut(res_col, j));
}
})
}
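// Coefficient-level sketch of the split (hypothetical test; assumes the reference
// helpers znx_rotate::<ZnxRef> and znx_switch_ring_ref from reference::znx are in
// scope): for n_in = 4 and n_out = 2, piece i collects the coefficients of a
// congruent to i modulo 2, i.e. a(X) = b0(X^2) + X * b1(X^2).
#[test]
fn split_ring_coefficient_sketch() {
    let a = [1i64, 2, 3, 4];
    let mut tmp = [0i64; 4];
    let (mut b0, mut b1) = ([0i64; 2], [0i64; 2]);
    // piece 0: keep every second coefficient of a
    znx_switch_ring_ref(&mut b0, &a);
    // piece 1: rotate by X^{-1}, then keep every second coefficient
    znx_rotate::<ZnxRef>(-1, &mut tmp, &a);
    znx_switch_ring_ref(&mut b1, &tmp);
    assert_eq!(b0, [1, 3]);
    assert_eq!(b1, [2, 4]);
}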

View File

@@ -0,0 +1,250 @@
use std::hint::black_box;
use criterion::{BenchmarkId, Criterion};
use crate::{
api::{ModuleNew, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace},
layouts::{Backend, FillUniform, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
oep::{ModuleNewImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl},
reference::znx::{ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxSub, ZnxSubABInplace, ZnxSubBAInplace, ZnxZero},
source::Source,
};
pub fn vec_znx_sub<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxSub + ZnxNegate + ZnxZero + ZnxCopy,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let b_size: usize = b.size();
if a_size <= b_size {
let sum_size: usize = a_size.min(res_size);
let cpy_size: usize = b_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_negate(res.at_mut(res_col, j), b.at(b_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
} else {
let sum_size: usize = b_size.min(res_size);
let cpy_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub(res.at_mut(res_col, j), a.at(a_col, j), b.at(b_col, j));
}
for j in sum_size..cpy_size {
ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in cpy_size..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
}
pub fn vec_znx_sub_ab_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubABInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
}
pub fn vec_znx_sub_ba_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxSubBAInplace + ZnxNegateInplace,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), res.n());
}
let res_size: usize = res.size();
let a_size: usize = a.size();
let sum_size: usize = a_size.min(res_size);
for j in 0..sum_size {
ZNXARI::znx_sub_ba_inplace(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in sum_size..res_size {
ZNXARI::znx_negate_inplace(res.at_mut(res_col, j));
}
}
pub fn bench_vec_znx_sub<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubImpl<B>,
{
let group_name: String = format!("vec_znx_sub::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSub + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut c: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        // Fill a and b with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub(&mut c, i, &a, i, &b, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_sub_ab_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubABInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ab_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubABInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        // Fill a and b with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub_ab_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
pub fn bench_vec_znx_sub_ba_inplace<B>(c: &mut Criterion, label: &str)
where
B: Backend + ModuleNewImpl<B> + VecZnxSubBAInplaceImpl<B>,
{
let group_name: String = format!("vec_znx_sub_ba_inplace::{}", label);
let mut group = c.benchmark_group(group_name);
fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
where
Module<B>: VecZnxSubBAInplace + ModuleNew<B>,
{
let n: usize = 1 << params[0];
let cols: usize = params[1];
let size: usize = params[2];
let module: Module<B> = Module::<B>::new(n as u64);
let mut source: Source = Source::new([0u8; 32]);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        // Fill a and b with random i64
a.fill_uniform(50, &mut source);
b.fill_uniform(50, &mut source);
move || {
for i in 0..cols {
module.vec_znx_sub_ba_inplace(&mut b, i, &a, i);
}
black_box(());
}
}
for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
let mut runner = runner::<B>(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}

View File

@@ -0,0 +1,58 @@
use crate::layouts::{ScalarZnxToRef, VecZnxToMut, VecZnxToRef};
use crate::{
layouts::{ScalarZnx, VecZnx, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxSub, ZnxSubABInplace, ZnxZero},
};
pub fn vec_znx_sub_scalar<R, A, B, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize, b_limb: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
B: VecZnxToRef,
ZNXARI: ZnxSub + ZnxZero,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let min_size: usize = b.size().min(res.size());
#[cfg(debug_assertions)]
{
assert!(
b_limb < min_size,
"b_limb: {} > min_size: {}",
b_limb,
min_size
);
}
for j in 0..min_size {
if j == b_limb {
ZNXARI::znx_sub(res.at_mut(res_col, j), b.at(b_col, j), a.at(a_col, 0));
} else {
res.at_mut(res_col, j).copy_from_slice(b.at(b_col, j));
}
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
pub fn vec_znx_sub_scalar_inplace<R, A, ZNXARI>(res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
ZNXARI: ZnxSubABInplace,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert!(res_limb < res.size());
}
ZNXARI::znx_sub_ab_inplace(res.at_mut(res_col, res_limb), a.at(a_col, 0));
}

View File

@@ -0,0 +1,37 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::{
vec_znx::vec_znx_copy,
znx::{ZnxCopy, ZnxSwitchRing, ZnxZero},
},
};
/// Maps between negacyclic rings by changing the polynomial degree.
/// Up: Z[X]/(X^N+1) -> Z[X]/(X^{2^d N}+1) via X ↦ X^{2^d}
/// Down: Z[X]/(X^N+1) -> Z[X]/(X^{N/2^d}+1) by keeping every 2^d-th coefficient.
pub fn vec_znx_switch_ring<R, A, ZNXARI>(res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
ZNXARI: ZnxCopy + ZnxSwitchRing + ZnxZero,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (a.n(), res.n());
if n_in == n_out {
vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
return;
}
let min_size: usize = a.size().min(res.size());
for j in 0..min_size {
ZNXARI::znx_switch_ring(res.at_mut(res_col, j), a.at(a_col, j));
}
for j in min_size..res.size() {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}
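// Small worked example of the mapping described in the doc comment (hypothetical
// test; assumes the reference helper znx_switch_ring_ref from reference::znx is in
// scope): going up interleaves zeros (X -> X^2), going down keeps every other
// coefficient.
#[test]
fn switch_ring_mapping_sketch() {
    let a = [1i64, 2, 3, 4];
    let mut up = [0i64; 8];
    znx_switch_ring_ref(&mut up, &a);
    assert_eq!(up, [1, 0, 2, 0, 3, 0, 4, 0]);
    let mut down = [0i64; 2];
    znx_switch_ring_ref(&mut down, &a);
    assert_eq!(down, [1, 3]);
}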

View File

@@ -0,0 +1,5 @@
mod normalization;
mod sampling;
pub use normalization::*;
pub use sampling::*;

View File

@@ -0,0 +1,72 @@
use crate::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef},
source::Source,
};
pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn zn_normalize_inplace<R, ARI>(n: usize, basek: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: ZnToMut,
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(carry.len(), res.n());
}
let res_size: usize = res.size();
for j in (0..res_size).rev() {
let out = &mut res.at_mut(res_col, j)[..n];
if j == res_size - 1 {
ARI::znx_normalize_first_step_inplace(basek, 0, out, carry);
} else if j == 0 {
ARI::znx_normalize_final_step_inplace(basek, 0, out, carry);
} else {
ARI::znx_normalize_middle_step_inplace(basek, 0, out, carry);
}
}
}
pub fn test_zn_normalize_inplace<B: Backend>(module: &Module<B>)
where
Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let basek: usize = 12;
    let n: usize = 33;
    let mut carry: Vec<i64> = vec![0i64; n];
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n()));
for res_size in [1, 2, 6, 11] {
let mut res_0: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
let mut res_1: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
res_0
.raw_mut()
.iter_mut()
.for_each(|x| *x = source.next_i32() as i64);
res_1.raw_mut().copy_from_slice(res_0.raw());
// Reference
for i in 0..cols {
zn_normalize_inplace::<_, ZnxRef>(n, basek, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, basek, &mut res_1, i, scratch.borrow());
}
assert_eq!(res_0.raw(), res_1.raw());
}
}

View File

@@ -0,0 +1,75 @@
use crate::{
layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
source::Source,
};
pub fn zn_fill_uniform<R>(n: usize, basek: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(basek, &mut res.at_mut(res_col, j)[..n], source)
}
}
#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1u64 << ((limb + 1) * basek - k)) as f64;
znx_fill_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}
#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1u64 << ((limb + 1) * basek - k)) as f64;
znx_add_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}
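// Worked example of the limb/scale computation used above (hypothetical test, no
// crate types needed): with basek = 12 and k = 20 the noise lands in limb 1 and is
// scaled by 2^(2*12 - 20) = 16, so that at limb 1's weight of 2^-24 it represents
// noise at precision 2^-20.
#[test]
fn noise_limb_scale_sketch() {
    let (basek, k): (usize, usize) = (12, 20);
    let limb: usize = k.div_ceil(basek) - 1;
    let scale: f64 = (1u64 << ((limb + 1) * basek - k)) as f64;
    assert_eq!(limb, 1);
    assert_eq!(scale, 16.0);
}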

View File

@@ -0,0 +1,25 @@
#[inline(always)]
pub fn znx_add_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] + b[i];
}
}
pub fn znx_add_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] += a[i];
}
}

View File

@@ -0,0 +1,153 @@
use crate::reference::znx::{
ZnxAdd, ZnxAddInplace, ZnxAutomorphism, ZnxCopy, ZnxNegate, ZnxNegateInplace, ZnxNormalizeFinalStep,
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxSub, ZnxSubABInplace,
ZnxSubBAInplace, ZnxSwitchRing, ZnxZero,
add::{znx_add_inplace_ref, znx_add_ref},
automorphism::znx_automorphism_ref,
copy::znx_copy_ref,
neg::{znx_negate_inplace_ref, znx_negate_ref},
normalization::{
znx_normalize_final_step_inplace_ref, znx_normalize_final_step_ref, znx_normalize_first_step_carry_only_ref,
znx_normalize_first_step_inplace_ref, znx_normalize_first_step_ref, znx_normalize_middle_step_carry_only_ref,
znx_normalize_middle_step_inplace_ref, znx_normalize_middle_step_ref,
},
sub::{znx_sub_ab_inplace_ref, znx_sub_ba_inplace_ref, znx_sub_ref},
switch_ring::znx_switch_ring_ref,
zero::znx_zero_ref,
};
pub struct ZnxRef {}
impl ZnxAdd for ZnxRef {
#[inline(always)]
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_add_ref(res, a, b);
}
}
impl ZnxAddInplace for ZnxRef {
#[inline(always)]
fn znx_add_inplace(res: &mut [i64], a: &[i64]) {
znx_add_inplace_ref(res, a);
}
}
impl ZnxSub for ZnxRef {
#[inline(always)]
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]) {
znx_sub_ref(res, a, b);
}
}
impl ZnxSubABInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ab_inplace_ref(res, a);
}
}
impl ZnxSubBAInplace for ZnxRef {
#[inline(always)]
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]) {
znx_sub_ba_inplace_ref(res, a);
}
}
impl ZnxAutomorphism for ZnxRef {
#[inline(always)]
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]) {
znx_automorphism_ref(p, res, a);
}
}
impl ZnxCopy for ZnxRef {
#[inline(always)]
fn znx_copy(res: &mut [i64], a: &[i64]) {
znx_copy_ref(res, a);
}
}
impl ZnxNegate for ZnxRef {
#[inline(always)]
fn znx_negate(res: &mut [i64], src: &[i64]) {
znx_negate_ref(res, src);
}
}
impl ZnxNegateInplace for ZnxRef {
#[inline(always)]
fn znx_negate_inplace(res: &mut [i64]) {
znx_negate_inplace_ref(res);
}
}
impl ZnxZero for ZnxRef {
#[inline(always)]
fn znx_zero(res: &mut [i64]) {
znx_zero_ref(res);
}
}
impl ZnxSwitchRing for ZnxRef {
#[inline(always)]
fn znx_switch_ring(res: &mut [i64], a: &[i64]) {
znx_switch_ring_ref(res, a);
}
}
impl ZnxNormalizeFinalStep for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_final_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFinalStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_final_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStep for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeFirstStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_first_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeFirstStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_first_step_inplace_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStep for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_ref(basek, lsh, x, a, carry);
}
}
impl ZnxNormalizeMiddleStepCarryOnly for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
znx_normalize_middle_step_carry_only_ref(basek, lsh, x, carry);
}
}
impl ZnxNormalizeMiddleStepInplace for ZnxRef {
#[inline(always)]
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
znx_normalize_middle_step_inplace_ref(basek, lsh, x, carry);
}
}

View File

@@ -0,0 +1,21 @@
pub fn znx_automorphism_ref(p: i64, res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
let mut k: usize = 0usize;
let mask: usize = 2 * n - 1;
let p_2n = (p & mask as i64) as usize;
res[0] = a[0];
for ai in a.iter().take(n).skip(1) {
k = (k + p_2n) & mask;
if k < n {
res[k] = *ai
} else {
res[k - n] = -*ai
}
}
}
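// Small worked example (hypothetical test): for n = 4, the automorphism X -> X^3
// maps a0 + a1*X + a2*X^2 + a3*X^3 to a0 + a3*X - a2*X^2 + a1*X^3 modulo X^4 + 1,
// since X^3 -> X^9 = X and X^2 -> X^6 = -X^2.
#[test]
fn automorphism_sketch() {
    let a = [1i64, 2, 3, 4];
    let mut res = [0i64; 4];
    znx_automorphism_ref(3, &mut res, &a);
    assert_eq!(res, [1, 4, -3, 2]);
}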

View File

@@ -0,0 +1,8 @@
#[inline(always)]
pub fn znx_copy_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len())
}
res.copy_from_slice(a);
}

View File

@@ -0,0 +1,104 @@
mod add;
mod arithmetic_ref;
mod automorphism;
mod copy;
mod neg;
mod normalization;
mod rotate;
mod sampling;
mod sub;
mod switch_ring;
mod zero;
pub use add::*;
pub use arithmetic_ref::*;
pub use automorphism::*;
pub use copy::*;
pub use neg::*;
pub use normalization::*;
pub use rotate::*;
pub use sampling::*;
pub use sub::*;
pub use switch_ring::*;
pub use zero::*;
pub trait ZnxAdd {
fn znx_add(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxAddInplace {
fn znx_add_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSub {
fn znx_sub(res: &mut [i64], a: &[i64], b: &[i64]);
}
pub trait ZnxSubABInplace {
fn znx_sub_ab_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxSubBAInplace {
fn znx_sub_ba_inplace(res: &mut [i64], a: &[i64]);
}
pub trait ZnxAutomorphism {
fn znx_automorphism(p: i64, res: &mut [i64], a: &[i64]);
}
pub trait ZnxCopy {
fn znx_copy(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNegate {
fn znx_negate(res: &mut [i64], src: &[i64]);
}
pub trait ZnxNegateInplace {
fn znx_negate_inplace(res: &mut [i64]);
}
pub trait ZnxRotate {
fn znx_rotate(p: i64, res: &mut [i64], src: &[i64]);
}
pub trait ZnxZero {
fn znx_zero(res: &mut [i64]);
}
pub trait ZnxSwitchRing {
fn znx_switch_ring(res: &mut [i64], a: &[i64]);
}
pub trait ZnxNormalizeFirstStepCarryOnly {
fn znx_normalize_first_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStepInplace {
fn znx_normalize_first_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFirstStep {
fn znx_normalize_first_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepCarryOnly {
fn znx_normalize_middle_step_carry_only(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStepInplace {
fn znx_normalize_middle_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeMiddleStep {
fn znx_normalize_middle_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStepInplace {
fn znx_normalize_final_step_inplace(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]);
}
pub trait ZnxNormalizeFinalStep {
fn znx_normalize_final_step(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]);
}

View File

@@ -0,0 +1,18 @@
#[inline(always)]
pub fn znx_negate_ref(res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len())
}
for i in 0..res.len() {
res[i] = -src[i]
}
}
#[inline(always)]
pub fn znx_negate_inplace_ref(res: &mut [i64]) {
for value in res {
*value = -*value
}
}

View File

@@ -0,0 +1,199 @@
use itertools::izip;
#[inline(always)]
pub fn get_digit(basek: usize, x: i64) -> i64 {
(x << (u64::BITS - basek as u32)) >> (u64::BITS - basek as u32)
}
#[inline(always)]
pub fn get_carry(basek: usize, x: i64, digit: i64) -> i64 {
(x.wrapping_sub(digit)) >> basek
}
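// Sanity sketch of the decomposition (hypothetical test): get_digit returns the
// centered remainder of x modulo 2^basek and get_carry the corresponding quotient,
// so that x = carry * 2^basek + digit with digit in [-2^(basek-1), 2^(basek-1)).
#[test]
fn digit_carry_sketch() {
    let basek: usize = 4; // base 16, digits in [-8, 8)
    for x in [-37i64, -8, 0, 7, 9, 123] {
        let digit = get_digit(basek, x);
        let carry = get_carry(basek, x, digit);
        assert_eq!(carry * (1 << basek) + digit, x);
        assert!((-8..8).contains(&digit));
    }
}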
#[inline(always)]
pub fn znx_normalize_first_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek, *x, get_digit(basek, *x));
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
*c = get_carry(basek_lsh, *x, get_digit(basek_lsh, *x));
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
*c = get_carry(basek, *x, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
*c = get_carry(basek_lsh, *x, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_first_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
*c = get_carry(basek, *a, digit);
*x = digit;
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
*c = get_carry(basek_lsh, *a, digit);
*x = digit << lsh;
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_carry_only_ref(basek: usize, lsh: usize, x: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit_plus_c: i64 = digit + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*c = carry + get_carry(basek, digit_plus_c, get_digit(basek, digit_plus_c));
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek, *x);
let carry: i64 = get_carry(basek, *x, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
let digit: i64 = get_digit(basek_lsh, *x);
let carry: i64 = get_carry(basek_lsh, *x, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_middle_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(x.len(), a.len());
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek, *a);
let carry: i64 = get_carry(basek, *a, digit);
let digit_plus_c: i64 = digit + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
let digit: i64 = get_digit(basek_lsh, *a);
let carry: i64 = get_carry(basek_lsh, *a, digit);
let digit_plus_c: i64 = (digit << lsh) + *c;
*x = get_digit(basek, digit_plus_c);
*c = carry + get_carry(basek, digit_plus_c, *x);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_inplace_ref(basek: usize, lsh: usize, x: &mut [i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, get_digit(basek, *x) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
x.iter_mut().zip(carry.iter_mut()).for_each(|(x, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *x) << lsh) + *c);
});
}
}
#[inline(always)]
pub fn znx_normalize_final_step_ref(basek: usize, lsh: usize, x: &mut [i64], a: &[i64], carry: &mut [i64]) {
#[cfg(debug_assertions)]
{
assert!(x.len() <= carry.len());
assert!(lsh < basek);
}
if lsh == 0 {
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, get_digit(basek, *a) + *c);
});
} else {
let basek_lsh: usize = basek - lsh;
izip!(x.iter_mut(), a.iter(), carry.iter_mut()).for_each(|(x, a, c)| {
*x = get_digit(basek, (get_digit(basek_lsh, *a) << lsh) + *c);
});
}
}
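// Minimal two-limb carry propagation sketch (hypothetical test): with basek = 4
// (base B = 16) the limb pair [hi, lo] here encodes hi*B + lo, so normalizing
// [0, 23] should yield [1, 7] since 23 = 1*16 + 7. The least-significant limb is
// processed first, then the carry is folded into the most-significant limb.
#[test]
fn normalize_two_limb_sketch() {
    let mut hi = [0i64];
    let mut lo = [23i64];
    let mut carry = [0i64];
    znx_normalize_first_step_inplace_ref(4, 0, &mut lo, &mut carry);
    znx_normalize_final_step_inplace_ref(4, 0, &mut hi, &mut carry);
    assert_eq!((hi[0], lo[0]), (1, 7));
}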

View File

@@ -0,0 +1,26 @@
use crate::reference::znx::{ZnxCopy, ZnxNegate};
pub fn znx_rotate<ZNXARI: ZnxNegate + ZnxCopy>(p: i64, res: &mut [i64], src: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), src.len());
}
let n: usize = res.len();
    let mp_2n: usize = (p & (2 * n as i64 - 1)) as usize; // p mod 2n
    let mp_1n: usize = mp_2n & (n - 1); // p mod n
    let mp_1n_neg: usize = n - mp_1n; // n - (p mod n)
let neg_first: bool = mp_2n < n;
let (dst1, dst2) = res.split_at_mut(mp_1n);
let (src1, src2) = src.split_at(mp_1n_neg);
if neg_first {
ZNXARI::znx_negate(dst1, src2);
ZNXARI::znx_copy(dst2, src1);
} else {
ZNXARI::znx_copy(dst1, src2);
ZNXARI::znx_negate(dst2, src1);
}
}
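// Small worked example (hypothetical test; assumes the reference backend ZnxRef
// from reference::znx is in scope): rotating by p = 1 multiplies by X modulo
// X^4 + 1, wrapping the top coefficient around with a sign flip.
#[test]
fn rotate_sketch() {
    let src = [1i64, 2, 3, 4];
    let mut res = [0i64; 4];
    znx_rotate::<ZnxRef>(1, &mut res, &src);
    assert_eq!(res, [-4, 1, 2, 3]);
}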

View File

@@ -0,0 +1,53 @@
use rand_distr::{Distribution, Normal};
use crate::source::Source;
pub fn znx_fill_uniform_ref(basek: usize, res: &mut [i64], source: &mut Source) {
let pow2k: u64 = 1 << basek;
let mask: u64 = pow2k - 1;
let pow2k_half: i64 = (pow2k >> 1) as i64;
res.iter_mut()
.for_each(|xi| *xi = (source.next_u64n(pow2k, mask) as i64) - pow2k_half)
}
pub fn znx_fill_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*xi = dist_f64.round() as i64
})
}
pub fn znx_add_dist_f64_ref<D: rand::prelude::Distribution<f64>>(res: &mut [i64], dist: D, bound: f64, source: &mut Source) {
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*xi += dist_f64.round() as i64
})
}
pub fn znx_fill_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = normal.sample(source);
while dist_f64.abs() > bound {
dist_f64 = normal.sample(source)
}
*xi = dist_f64.round() as i64
})
}
pub fn znx_add_normal_f64_ref(res: &mut [i64], sigma: f64, bound: f64, source: &mut Source) {
let normal: Normal<f64> = Normal::new(0.0, sigma).unwrap();
res.iter_mut().for_each(|xi| {
let mut dist_f64: f64 = normal.sample(source);
while dist_f64.abs() > bound {
dist_f64 = normal.sample(source)
}
*xi += dist_f64.round() as i64
})
}
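// Usage sketch (hypothetical test): fill a small buffer with bounded Gaussian
// noise; the rejection sampling above guarantees every rounded sample stays within
// the bound.
#[test]
fn fill_normal_sketch() {
    let mut source: Source = Source::new([0u8; 32]);
    let mut buf = [0i64; 16];
    znx_fill_normal_f64_ref(&mut buf, 3.2, 19.2, &mut source);
    assert!(buf.iter().all(|x| x.abs() <= 19)); // round(19.2) = 19
}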

View File

@@ -0,0 +1,36 @@
pub fn znx_sub_ref(res: &mut [i64], a: &[i64], b: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
assert_eq!(res.len(), b.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] - b[i];
}
}
pub fn znx_sub_ab_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] -= a[i];
}
}
pub fn znx_sub_ba_inplace_ref(res: &mut [i64], a: &[i64]) {
#[cfg(debug_assertions)]
{
assert_eq!(res.len(), a.len());
}
let n: usize = res.len();
for i in 0..n {
res[i] = a[i] - res[i];
}
}

View File

@@ -0,0 +1,29 @@
use crate::reference::znx::{copy::znx_copy_ref, zero::znx_zero_ref};
pub fn znx_switch_ring_ref(res: &mut [i64], a: &[i64]) {
let (n_in, n_out) = (a.len(), res.len());
#[cfg(debug_assertions)]
{
assert!(n_in.is_power_of_two());
assert!(n_in.max(n_out).is_multiple_of(n_in.min(n_out)))
}
if n_in == n_out {
znx_copy_ref(res, a);
return;
}
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
znx_zero_ref(res);
}
res.iter_mut()
.step_by(gap_out)
.zip(a.iter().step_by(gap_in))
.for_each(|(x_out, x_in)| *x_out = *x_in);
}

View File

@@ -0,0 +1,3 @@
pub fn znx_zero_ref(res: &mut [i64]) {
res.fill(0);
}