Support for bivariate convolution & normalization with offset (#126)

* Add bivariate-convolution
* Add pair-wise convolution + tests + benches
* Add take_cnv_pvec_[left/right] to Scratch & updated CHANGELOG.md
* cross-base2k normalization with positive offset
* clippy & fix CI doctest avx compile error
* more streamlined bounds derivation for normalization
* Working cross-base2k normalization with pos/neg offset
* Update normalization API & tests
* Add glwe tensoring test
* Add relinearization + preliminary test
* Fix GGLWEToGGSW key infos
* Add (X,Y) convolution by const (1, Y) poly
* Faster normalization test + add bench for cnv_by_const
* Update changelog
This commit is contained in:
Jean-Philippe Bossuat
2025-12-21 16:56:42 +01:00
committed by GitHub
parent 76424d0ab5
commit 4e90e08a71
219 changed files with 6571 additions and 5041 deletions

View File

@@ -27,5 +27,5 @@ rustdoc-args = ["--cfg", "docsrs"]
[[bench]]
name = "vmp"
name = "convolution"
harness = false

View File

@@ -0,0 +1,35 @@
use criterion::{Criterion, criterion_group, criterion_main};
use poulpy_cpu_ref::FFT64Ref;
use poulpy_hal::bench_suite::convolution::{
    bench_cnv_apply_dft, bench_cnv_by_const_apply, bench_cnv_pairwise_apply_dft, bench_cnv_prepare_left, bench_cnv_prepare_right,
};

/// Benchmarks left-operand convolution preparation on the reference FFT64 backend.
fn bench_cnv_prepare_left_cpu_ref_fft64(c: &mut Criterion) {
    bench_cnv_prepare_left::<FFT64Ref>(c, "cpu_ref::fft64");
}

/// Benchmarks right-operand convolution preparation on the reference FFT64 backend.
fn bench_cnv_prepare_right_cpu_ref_fft64(c: &mut Criterion) {
    bench_cnv_prepare_right::<FFT64Ref>(c, "cpu_ref::fft64");
}

/// Benchmarks the DFT-domain convolution apply.
/// (Renamed from `bench_bench_cnv_apply_dft_…`: the duplicated prefix was a copy-paste slip.)
fn bench_cnv_apply_dft_cpu_ref_fft64(c: &mut Criterion) {
    bench_cnv_apply_dft::<FFT64Ref>(c, "cpu_ref::fft64");
}

/// Benchmarks the pair-wise DFT-domain convolution apply.
/// (Renamed from `bench_bench_bench_cnv_pairwise_apply_dft_…`: the triplicated prefix was a copy-paste slip.)
fn bench_cnv_pairwise_apply_dft_cpu_ref_fft64(c: &mut Criterion) {
    bench_cnv_pairwise_apply_dft::<FFT64Ref>(c, "cpu_ref::fft64");
}

/// Benchmarks convolution by a constant (integer) polynomial.
fn bench_cnv_by_const_apply_cpu_ref_fft64(c: &mut Criterion) {
    bench_cnv_by_const_apply::<FFT64Ref>(c, "cpu_ref::fft64");
}

criterion_group!(
    benches,
    bench_cnv_prepare_left_cpu_ref_fft64,
    bench_cnv_prepare_right_cpu_ref_fft64,
    bench_cnv_apply_dft_cpu_ref_fft64,
    bench_cnv_pairwise_apply_dft_cpu_ref_fft64,
    bench_cnv_by_const_apply_cpu_ref_fft64,
);
criterion_main!(benches);

View File

@@ -11,10 +11,7 @@ pub fn bench_fft_ref(c: &mut Criterion) {
fn runner(m: usize) -> impl FnMut() {
let mut values: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / (2 * m) as f64;
values
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let table: ReimFFTTable<f64> = ReimFFTTable::<f64>::new(m);
move || {
ReimFFTRef::reim_dft_execute(&table, &mut values);
@@ -39,10 +36,7 @@ pub fn bench_ifft_ref(c: &mut Criterion) {
fn runner(m: usize) -> impl FnMut() {
let mut values: Vec<f64> = vec![0f64; m << 1];
let scale: f64 = 1.0f64 / (2 * m) as f64;
values
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = (i + 1) as f64 * scale);
values.iter_mut().enumerate().for_each(|(i, x)| *x = (i + 1) as f64 * scale);
let table: ReimIFFTTable<f64> = ReimIFFTTable::<f64>::new(m);
move || {
ReimIFFTRef::reim_dft_execute(&table, &mut values);

View File

@@ -5,7 +5,7 @@ use poulpy_hal::reference::vec_znx::{bench_vec_znx_add, bench_vec_znx_automorphi
#[allow(dead_code)]
fn bench_vec_znx_add_cpu_ref_fft64(c: &mut Criterion) {
bench_vec_znx_add::<FFT64Ref>(c, "cpu_spqlios::fft64");
bench_vec_znx_add::<FFT64Ref>(c, "cpu_ref::fft64");
}
#[allow(dead_code)]

View File

@@ -0,0 +1,166 @@
use poulpy_hal::{
api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
layouts::{
Backend, CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx,
VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
},
oep::{CnvPVecBytesOfImpl, CnvPVecLAllocImpl, ConvolutionImpl},
reference::fft64::convolution::{
convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
convolution_prepare_right,
},
};
use crate::{FFT64Ref, module::FFT64ModuleHandle};
/// Heap allocation of prepared convolution operand vectors for the reference
/// FFT64 backend. Both methods delegate straight to the layout constructors.
unsafe impl CnvPVecLAllocImpl<Self> for FFT64Ref {
    /// Allocates a left prepared-vector with `cols` columns of `size` limbs
    /// over a degree-`n` ring.
    fn cnv_pvec_left_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecL<Vec<u8>, Self> {
        CnvPVecL::alloc(n, cols, size)
    }

    /// Right-operand counterpart of `cnv_pvec_left_alloc_impl`.
    fn cnv_pvec_right_alloc_impl(n: usize, cols: usize, size: usize) -> CnvPVecR<Vec<u8>, Self> {
        CnvPVecR::alloc(n, cols, size)
    }
}
/// Byte-size computation for prepared convolution vectors.
///
/// Left and right operands use the same prepared layout on this backend, so
/// both formulas coincide: words-per-prepared-coefficient × ring degree ×
/// columns × limbs × width of the prepared scalar type.
unsafe impl CnvPVecBytesOfImpl for FFT64Ref {
    fn bytes_of_cnv_pvec_left_impl(n: usize, cols: usize, size: usize) -> usize {
        Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Ref as Backend>::ScalarPrep>()
    }

    fn bytes_of_cnv_pvec_right_impl(n: usize, cols: usize, size: usize) -> usize {
        // Identical to the left variant; kept separate because the OEP trait
        // exposes the two sides independently.
        Self::layout_prep_word_count() * n * cols * size * size_of::<<FFT64Ref as Backend>::ScalarPrep>()
    }
}
/// Convolution entry points for the reference FFT64 backend.
///
/// Each `*_tmp_bytes_impl` method reports the scratch space required by the
/// matching apply/prepare method; the apply/prepare methods carve exactly that
/// space out of `scratch` and delegate to the shared reference kernels in
/// `poulpy_hal::reference::fft64::convolution`.
unsafe impl ConvolutionImpl<Self> for FFT64Ref
where
    Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
    /// Scratch bytes for `cnv_prepare_left_impl`: one DFT column of
    /// `min(res_size, a_size)` limbs.
    fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    /// DFT-transforms `a` into the prepared left operand `res`.
    fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecLToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecL<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        // Temporary DFT buffer, sized exactly as advertised by
        // `cnv_prepare_left_tmp_bytes_impl`.
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
    }

    /// Scratch bytes for `cnv_prepare_right_impl`; same shape as the left variant.
    fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    /// DFT-transforms `a` into the prepared right operand `res`.
    fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecRToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecR<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
    }

    /// Scratch bytes for `cnv_apply_dft_impl`. The offset does not affect the
    /// scratch requirement, only where results land in `res`.
    fn cnv_apply_dft_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    /// Scratch bytes for `cnv_by_const_apply_impl`; offset is likewise ignored.
    fn cnv_by_const_apply_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
    }

    /// Convolution of column `a_col` of `a` by the constant polynomial `b`
    /// (coefficients as `i64`), written into column `res_col` of `res` at
    /// limb offset `res_offset`.
    fn cnv_by_const_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &[i64],
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        // `take_slice` counts elements, not bytes: convert the advertised
        // byte budget into an `i64` element count.
        let (tmp, _) =
            scratch.take_slice(module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len()) / size_of::<i64>());
        convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
    }

    /// DFT-domain convolution of prepared operands: column `a_col` of `a`
    /// with column `b_col` of `b`, accumulated into column `res_col` of
    /// `res` at limb offset `res_offset`.
    fn cnv_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref();
        // Bytes → f64 element count for the DFT-domain scratch slice.
        let (tmp, _) =
            scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
    }

    // NOTE(review): this is the only method here without the `_impl` suffix —
    // presumably it mirrors the trait's declared name; confirm against the
    // `ConvolutionImpl` trait definition in `poulpy_hal::oep`.
    fn cnv_pairwise_apply_dft_tmp_bytes(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    /// Pair-wise DFT-domain convolution over the column pair
    /// (`col_0`, `col_1`) of `a` and `b`, written into column `res_col`
    /// of `res` at limb offset `res_offset`.
    fn cnv_pairwise_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        b: &B,
        col_0: usize,
        col_1: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Ref> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Ref> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Ref> = &b.to_ref();
        let (tmp, _) = scratch
            .take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, col_0, col_1, tmp);
    }
}

View File

@@ -1,3 +1,4 @@
mod convolution;
mod module;
mod reim;
mod scratch;

View File

@@ -1,3 +0,0 @@
/// Entry point: prints the classic greeting to stdout.
fn main() {
    let greeting = "Hello, world!";
    println!("{greeting}");
}

View File

@@ -1,4 +1,9 @@
use poulpy_hal::reference::fft64::{
convolution::{
I64ConvolutionByConst1Coeff, I64ConvolutionByConst2Coeffs, I64Extract1BlkContiguous, I64Save1BlkContiguous,
i64_convolution_by_const_1coeff_ref, i64_convolution_by_const_2coeffs_ref, i64_extract_1blk_contiguous_ref,
i64_save_1blk_contiguous_ref,
},
reim::{
ReimAdd, ReimAddInplace, ReimAddMul, ReimCopy, ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimIFFTTable, ReimMul,
ReimMulInplace, ReimNegate, ReimNegateInplace, ReimSub, ReimSubInplace, ReimSubNegateInplace, ReimToZnx,
@@ -8,9 +13,13 @@ use poulpy_hal::reference::fft64::{
reim_zero_ref,
},
reim4::{
Reim4Extract1Blk, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk, Reim4Save2Blks,
reim4_extract_1blk_from_reim_ref, reim4_save_1blk_to_reim_ref, reim4_save_2blk_to_reim_ref,
reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref, reim4_vec_mat2cols_product_ref,
Reim4Convolution1Coeff, Reim4Convolution2Coeffs, Reim4ConvolutionByRealConst1Coeff, Reim4ConvolutionByRealConst2Coeffs,
Reim4Extract1BlkContiguous, Reim4Mat1ColProd, Reim4Mat2Cols2ndColProd, Reim4Mat2ColsProd, Reim4Save1Blk,
Reim4Save1BlkContiguous, Reim4Save2Blks, reim4_convolution_1coeff_ref, reim4_convolution_2coeffs_ref,
reim4_convolution_by_real_const_1coeff_ref, reim4_convolution_by_real_const_2coeffs_ref,
reim4_extract_1blk_from_reim_contiguous_ref, reim4_save_1blk_to_reim_contiguous_ref, reim4_save_1blk_to_reim_ref,
reim4_save_2blk_to_reim_ref, reim4_vec_mat1col_product_ref, reim4_vec_mat2cols_2ndcol_product_ref,
reim4_vec_mat2cols_product_ref,
},
};
@@ -133,10 +142,29 @@ impl ReimZero for FFT64Ref {
}
}
impl Reim4Extract1Blk for FFT64Ref {
/// Single-coefficient reim4 convolution kernel; delegates to the reference
/// implementation. `#[inline(always)]` added for consistency with every other
/// delegation impl in this file (all siblings carry it).
impl Reim4Convolution1Coeff for FFT64Ref {
    #[inline(always)]
    fn reim4_convolution_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
        reim4_convolution_1coeff_ref(k, dst, a, a_size, b, b_size);
    }
}
/// Two-coefficient reim4 convolution kernel; delegates to the reference
/// implementation. `#[inline(always)]` added for consistency with every other
/// delegation impl in this file (all siblings carry it).
impl Reim4Convolution2Coeffs for FFT64Ref {
    #[inline(always)]
    fn reim4_convolution_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64], b_size: usize) {
        reim4_convolution_2coeffs_ref(k, dst, a, a_size, b, b_size);
    }
}
impl Reim4Extract1BlkContiguous for FFT64Ref {
#[inline(always)]
fn reim4_extract_1blk(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
reim4_extract_1blk_from_reim_ref(m, rows, blk, dst, src);
fn reim4_extract_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
reim4_extract_1blk_from_reim_contiguous_ref(m, rows, blk, dst, src);
}
}
/// Saves one reim4 block into a contiguous reim layout; delegates to the
/// reference kernel.
impl Reim4Save1BlkContiguous for FFT64Ref {
    #[inline(always)]
    fn reim4_save_1blk_contiguous(m: usize, rows: usize, blk: usize, dst: &mut [f64], src: &[f64]) {
        reim4_save_1blk_to_reim_contiguous_ref(m, rows, blk, dst, src);
    }
}
@@ -174,3 +202,45 @@ impl Reim4Mat2Cols2ndColProd for FFT64Ref {
reim4_vec_mat2cols_2ndcol_product_ref(nrows, dst, u, v);
}
}
/// Single-coefficient convolution by a real constant; delegates to the
/// reference kernel.
impl Reim4ConvolutionByRealConst1Coeff for FFT64Ref {
    #[inline(always)]
    fn reim4_convolution_by_real_const_1coeff(k: usize, dst: &mut [f64; 8], a: &[f64], a_size: usize, b: &[f64]) {
        reim4_convolution_by_real_const_1coeff_ref(k, dst, a, a_size, b);
    }
}

/// Two-coefficient convolution by a real constant; delegates to the
/// reference kernel.
impl Reim4ConvolutionByRealConst2Coeffs for FFT64Ref {
    #[inline(always)]
    fn reim4_convolution_by_real_const_2coeffs(k: usize, dst: &mut [f64; 16], a: &[f64], a_size: usize, b: &[f64]) {
        reim4_convolution_by_real_const_2coeffs_ref(k, dst, a, a_size, b);
    }
}

/// Integer (i64) single-coefficient convolution by a constant; delegates to
/// the reference kernel.
impl I64ConvolutionByConst1Coeff for FFT64Ref {
    #[inline(always)]
    fn i64_convolution_by_const_1coeff(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
        i64_convolution_by_const_1coeff_ref(k, dst, a, a_size, b);
    }
}

/// Integer (i64) two-coefficient convolution by a constant; delegates to the
/// reference kernel.
impl I64ConvolutionByConst2Coeffs for FFT64Ref {
    #[inline(always)]
    fn i64_convolution_by_const_2coeffs(k: usize, dst: &mut [i64; 16], a: &[i64], a_size: usize, b: &[i64]) {
        i64_convolution_by_const_2coeffs_ref(k, dst, a, a_size, b);
    }
}

/// Saves one i64 block into a contiguous layout at the given offset;
/// delegates to the reference kernel.
impl I64Save1BlkContiguous for FFT64Ref {
    #[inline(always)]
    fn i64_save_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
        i64_save_1blk_contiguous_ref(n, offset, rows, blk, dst, src);
    }
}

/// Extracts one i64 block from a contiguous layout at the given offset;
/// delegates to the reference kernel.
impl I64Extract1BlkContiguous for FFT64Ref {
    #[inline(always)]
    fn i64_extract_1blk_contiguous(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
        i64_extract_1blk_contiguous_ref(n, offset, rows, blk, dst, src);
    }
}

View File

@@ -1,9 +1,25 @@
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
use poulpy_hal::{
api::ModuleNew,
layouts::Module,
test_suite::convolution::{test_convolution, test_convolution_by_const, test_convolution_pairwise},
};
use crate::FFT64Ref;
#[test]
fn test_convolution_by_const_fft64_ref() {
    // Run the shared convolution-by-constant test suite against the
    // reference FFT64 backend with ring degree 8.
    test_convolution_by_const(&Module::<FFT64Ref>::new(8));
}
#[test]
fn test_convolution_fft64_ref() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
test_bivariate_tensoring(&module);
test_convolution(&module);
}
#[test]
fn test_convolution_pairwise_fft64_ref() {
    // Run the shared pair-wise convolution test suite against the
    // reference FFT64 backend with ring degree 8.
    test_convolution_pairwise(&Module::<FFT64Ref>::new(8));
}

View File

@@ -53,11 +53,12 @@ where
{
fn vec_znx_normalize_impl<R, A>(
module: &Module<Self>,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
@@ -65,7 +66,7 @@ where
A: VecZnxToRef,
{
let (carry, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_normalize::<R, A, Self>(res_basek, res, res_col, a_basek, a, a_col, carry);
vec_znx_normalize::<R, A, Self>(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
}
}

View File

@@ -26,7 +26,7 @@ use poulpy_hal::{
source::Source,
};
unsafe impl VecZnxBigAllocBytesImpl<Self> for FFT64Ref {
unsafe impl VecZnxBigAllocBytesImpl for FFT64Ref {
fn vec_znx_big_bytes_of_impl(n: usize, cols: usize, size: usize) -> usize {
Self::layout_big_word_count() * n * cols * size * size_of::<f64>()
}
@@ -280,11 +280,12 @@ where
{
fn vec_znx_big_normalize_impl<R, A>(
module: &Module<Self>,
res_basek: usize,
res: &mut R,
res_base2k: usize,
res_offset: i64,
res_col: usize,
a_basek: usize,
a: &A,
a_base2k: usize,
a_col: usize,
scratch: &mut Scratch<Self>,
) where
@@ -292,7 +293,7 @@ where
A: VecZnxBigToRef<Self>,
{
let (carry, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes() / size_of::<i64>());
vec_znx_big_normalize(res_basek, res, res_col, a_basek, a, a_col, carry);
vec_znx_big_normalize(res, res_base2k, res_offset, res_col, a, a_base2k, a_col, carry);
}
}