mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Added API in poulpy for updated vmp_add (+tests)
This commit is contained in:
Submodule backend/spqlios-arithmetic updated: b919282c9b...d6045033e5
@@ -47,6 +47,7 @@ unsafe extern "C" {
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
@@ -79,6 +80,7 @@ unsafe extern "C" {
|
||||
pmat: *const VMP_PMAT,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
pmat_scale: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
B: MatZnxToRef<FFT64>;
|
||||
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
@@ -309,7 +309,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
@@ -358,6 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
scale as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
@@ -368,6 +369,7 @@ mod tests {
|
||||
use crate::{
|
||||
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -409,7 +411,7 @@ mod tests {
|
||||
let basek: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = 5;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
@@ -419,7 +421,6 @@ mod tests {
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
let res_cols: usize = mat_cols_out;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
@@ -435,7 +436,7 @@ mod tests {
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, 2)[i + 1] = 1;
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
@@ -479,7 +480,100 @@ mod tests {
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
});
|
||||
res_have.decode_vec_i64(col_i, basek, basek * 3, &mut res_have_vi64);
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply_add() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let basek: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = a_size;
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
let a_cols: usize = *in_cols;
|
||||
let res_cols: usize = *out_cols;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
c_dft.zero();
|
||||
(0..c_dft.cols()).for_each(|i| {
|
||||
module.vec_znx_dft(&mut c_dft, i, &a, 0);
|
||||
});
|
||||
|
||||
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
});
|
||||
|
||||
res_want_vi64[1] += 1;
|
||||
|
||||
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user