Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)
Support for bivariate convolution & normalization with offset (#126)
* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
commit 4e90e08a71 (parent 76424d0ab5), committed via GitHub
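A reading aid for the "bivariate convolution" items above, assuming the usual tensoring convention in which each column of a VecZnx holds a ring element in X and the convolution runs over the column index Y (this is an interpretation of the diff below, not a specification stated by the commit):

% Sketch of the assumed (X, Y) convolution; a, b, c are illustrative names.
\[
a(X, Y) = \sum_{i} a_i(X)\,Y^i, \qquad
b(X, Y) = \sum_{j} b_j(X)\,Y^j, \qquad
a(X, Y)\,b(X, Y) = \sum_{k} c_k(X)\,Y^k
\quad\text{with}\quad
c_k(X) = \sum_{i+j=k} a_i(X)\,b_j(X).
\]

Each product a_i(X) b_j(X) is an ordinary polynomial product over X, which is consistent with the new code preparing both operands in the FFT domain (CnvPVecL / CnvPVecR) and writing each output column res_col separately.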
poulpy-cpu-ref/src/convolution.rs (new file, 166 lines added)
@@ -0,0 +1,166 @@
use poulpy_hal::{
    api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
    layouts::{
        Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx,
        VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
    },
    oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
    reference::fft64::convolution::{
        convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
        convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
        convolution_prepare_right,
    },
};

use crate::{FFT64Ref, module::FFT64ModuleHandle};

unsafe impl CnvPVecLAllocImpl<Self> for FFT64Ref {
    fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, Self> {
        CnvPVecL::alloc(n, cols, size)
    }

    fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, Self> {
        CnvPVecR::alloc(n, cols, size)
    }
}

unsafe impl CnvPVecBytesOfImpl for FFT64Ref {
    fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize {
        Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Ref as Backend>::ScalarPrep>()
    }

    fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize {
        Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Ref as Backend>::ScalarPrep>()
    }
}

unsafe impl ConvolutionImpl<Self> for FFT64Ref
where
    Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
    fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecLToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecL<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
    }

    fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecRToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecR<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
    }

    fn cnv_apply_dft_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_by_const_apply_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_by_const_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &[i64],
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (tmp, _) =
            scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::<i64>());
        convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
    }

    fn cnv_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref();
        let (tmp, _) =
            scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
    }

    fn cnv_pairwise_apply_dft_tmp_bytes(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_pairwise_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        b: &B,
        col_0: usize,
        col_1: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref();
        let (tmp, _) = scratch
            .take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, col_0, col_1, tmp);
    }
}
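To make the column arithmetic concrete, here is a self-contained toy model of the accumulation that cnv_apply_dft performs, with the ring-element columns replaced by plain i64 scalars; convolution_naive and the sample values are illustrative only and not part of the poulpy API:

// Toy model: res[k] = sum over i + j == k of a[i] * b[j],
// truncated to res.len() output columns. The real backend performs the
// same accumulation column-wise in the FFT domain on prepared
// CnvPVecL/CnvPVecR operands instead of on raw scalar slices.
fn convolution_naive(res: &mut [i64], a: &[i64], b: &[i64]) {
    res.iter_mut().for_each(|r| *r = 0);
    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            if let Some(r) = res.get_mut(i + j) {
                *r += ai * bj;
            }
        }
    }
}

fn main() {
    // (1 + 2Y + 3Y^2)(4 + 5Y) = 4 + 13Y + 22Y^2 + 15Y^3
    let (a, b) = ([1i64, 2, 3], [4i64, 5]);
    let mut res = [0i64; 4];
    convolution_naive(&mut res, &a, &b);
    assert_eq!(res, [4, 13, 22, 15]);
}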
@@ -1,3 +1,4 @@
+mod convolution;
 mod module;
 mod reim;
 mod scratch;
@@ -1,3 +0,0 @@
-fn main() {
-    println!("Hello, world!");
-}
@@ -1,4 +1,9 @@
 use poulpy_hal::reference::fft64::{
+    convolution::{
+        I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous,
+        i64_convolution_by_const_1coeff_ref, i64_convolution_by_const_2coeffs_ref, i64_extract_1blk_contiguous_ref,
+        i64_save_1blk_contiguous_ref,
+    },
     reim::{
         ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
         ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
@@ -8,9 +13,13 @@ use poulpy_hal::reference::fft64::{
         reim_zero_ref,
     },
     reim4::{
-        Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
-        reim4_extract_1blk_from_reim_ref, reim4_save_1blk_to_reim_ref, reim4_save_2blk_to_reim_ref,
-        reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, reim4_vec_mat2cols_product_ref,
+        Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff, Reim4ConvolutionByRealConst2Coeffs,
+        Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk,
+        Reim4Save1BlkContiguous, Reim4Save2Blks, reim4_convolution_1coeff_ref, reim4_convolution_2coeffs_ref,
+        reim4_convolution_by_real_const_1coeff_ref, reim4_convolution_by_real_const_2coeffs_ref,
+        reim4_extract_1blk_from_reim_contiguous_ref, reim4_save_1blk_to_reim_contiguous_ref, reim4_save_1blk_to_reim_ref,
+        reim4_save_2blk_to_reim_ref, reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref,
+        reim4_vec_mat2cols_product_ref,
     },
 };
@@ -133,10 +142,29 @@ impl ReimZero for FFT64Ref {
     }
 }
 
-impl Reim4Extract1Blk for FFT64Ref {
+impl Reim4Convolution1Coeff for FFT64Ref {
+    fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
+        reim4_convolution_1coeff_ref(k, dst, a, a_size, b, b_size);
+    }
+}
+
+impl Reim4Convolution2Coeffs for FFT64Ref {
+    fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
+        reim4_convolution_2coeffs_ref(k, dst, a, a_size, b, b_size);
+    }
+}
+
+impl Reim4Extract1BlkContiguous for FFT64Ref {
     #[inline(always)]
-    fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
-        reim4_extract_1blk_from_reim_ref(m, rows, blk, dst, src);
+    fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
+        reim4_extract_1blk_from_reim_contiguous_ref(m, rows, blk, dst, src);
     }
 }
 
+impl Reim4Save1BlkContiguous for FFT64Ref {
+    #[inline(always)]
+    fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
+        reim4_save_1blk_to_reim_contiguous_ref(m, rows, blk, dst, src);
+    }
+}
+
@@ -174,3 +202,45 @@ impl Reim4Mat2Cols2ndColProd for FFT64Ref {
         reim4_vec_mat2cols_2ndcol_product_ref(nrows, dst, u, v);
     }
 }
+
+impl Reim4ConvolutionByRealConst1Coeff for FFT64Ref {
+    #[inline(always)]
+    fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
+        reim4_convolution_by_real_const_1coeff_ref(k, dst, a, a_size, b);
+    }
+}
+
+impl Reim4ConvolutionByRealConst2Coeffs for FFT64Ref {
+    #[inline(always)]
+    fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
+        reim4_convolution_by_real_const_2coeffs_ref(k, dst, a, a_size, b);
+    }
+}
+
+impl I64ConvolutionByConst1Coeff for FFT64Ref {
+    #[inline(always)]
+    fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
+        i64_convolution_by_const_1coeff_ref(k, dst, a, a_size, b);
+    }
+}
+
+impl I64ConvolutionByConst2Coeffs for FFT64Ref {
+    #[inline(always)]
+    fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
+        i64_convolution_by_const_2coeffs_ref(k, dst, a, a_size, b);
+    }
+}
+
+impl I64Save1BlkContiguous for FFT64Ref {
+    #[inline(always)]
+    fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
+        i64_save_1blk_contiguous_ref(n, offset, rows, blk, dst, src);
+    }
+}
+
+impl I64Extract1BlkContiguous for FFT64Ref {
+    #[inline(always)]
+    fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
+        i64_extract_1blk_contiguous_ref(n, offset, rows, blk, dst, src);
+    }
+}
@@ -1,9 +1,25 @@
-use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
+use poulpy_hal::{
+    api::ModuleNew,
+    layouts::Module,
+    test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise},
+};
 
 use crate::FFT64Ref;
 
 #[test]
+fn test_convolution_by_const_fft64_ref() {
+    let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
+    test_convolution_by_const(&module);
+}
+
+#[test]
 fn test_convolution_fft64_ref() {
     let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
-    test_bivariate_tensoring(&module);
+    test_convolution(&module);
 }
+
+#[test]
+fn test_convolution_pairwise_fft64_ref() {
+    let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
+    test_convolution_pairwise(&module);
+}
@@ -53,11 +53,12 @@ where
 {
     fn vec_znx_normalize_impl<R, A>(
         module: &Module<Self>,
-        res_basek: usize,
         res: &mut R,
+        res_base2k: usize,
+        res_offset: i64,
         res_col: usize,
-        a_basek: usize,
         a: &A,
+        a_base2k: usize,
        a_col: usize,
         scratch: &mut Scratch<Self>,
     ) where
@@ -65,7 +66,7 @@ where
         A: VecZnxToRef,
     {
         let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
-        vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
+        vec_znx_normalize::<R, A, Self>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
     }
 }
 
@@ -26,7 +26,7 @@ use poulpy_hal::{
     source::Source,
 };
 
-unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Ref {
+unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref {
     fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
         Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
     }
@@ -280,11 +280,12 @@ where
 {
     fn vec_znx_big_normalize_impl<R, A>(
         module: &Module<Self>,
-        res_basek: usize,
         res: &mut R,
+        res_base2k: usize,
+        res_offset: i64,
         res_col: usize,
-        a_basek: usize,
         a: &A,
+        a_base2k: usize,
         a_col: usize,
         scratch: &mut Scratch<Self>,
     ) where
@@ -292,7 +293,7 @@ where
         A: VecZnxBigToRef<Self>,
     {
         let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
-        vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
+        vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
     }
 }
 
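The widened normalization signatures above replace the single res_basek/a_basek pair with independent res_base2k and a_base2k plus a signed res_offset, so input and output may use different limb bases. One plausible reading of the preserved relation, assuming the usual convention where limb j of a VecZnx carries weight 2^{-(j+1)k} and the offset is a power-of-two scaling (the exact digit range and the sign convention of the offset may differ in poulpy):

% Assumed invariant of the cross-base2k normalization (illustrative only).
\[
\sum_{j} \mathrm{res}_j(X)\,2^{-(j+1)\,k_{\mathrm{res}}} \;\equiv\;
2^{\,\mathrm{offset}} \sum_{j} a_j(X)\,2^{-(j+1)\,k_a} \pmod{1},
\qquad
\mathrm{res}_j \in \bigl[-2^{k_{\mathrm{res}}-1},\, 2^{k_{\mathrm{res}}-1}\bigr),
\]

i.e., under this reading the output limbs re-express the same value in base 2^{k_res}, with carries propagated so that every output limb lies in the balanced digit range.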