Add glwe tensoring

Pro7ech
2025-10-26 19:03:15 +01:00
parent 6e9cef5ecd
commit 41ca5aafcc
9 changed files with 199 additions and 138 deletions

View File

@@ -1,5 +1,6 @@
 use poulpy_hal::{
-    api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution,
+    api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module,
+    test_suite::convolution::test_bivariate_tensoring,
 };

 use crate::FFT64Avx;
@@ -123,5 +124,5 @@ backend_test_suite! {
 #[test]
 fn test_convolution_fft64_avx() {
     let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
-    test_convolution(&module);
+    test_bivariate_tensoring(&module);
 }

View File

@@ -1,9 +1,9 @@
-use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution};
+use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};

 use crate::FFT64Ref;

 #[test]
 fn test_convolution_fft64_ref() {
     let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
-    test_convolution(&module);
+    test_bivariate_tensoring(&module);
 }

View File

@@ -93,7 +93,7 @@ impl GLWETensor<Vec<u8>> {
     }

     pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
-        let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
+        let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
         GLWETensor {
             data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
             base2k,
@@ -110,7 +110,7 @@ impl GLWETensor<Vec<u8>> {
     }

     pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
-        let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
+        let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
         VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
     }
 }
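
For reference, the sizing arithmetic used by GLWETensor::alloc and GLWETensor::bytes_of above, as a minimal standalone sketch. It uses plain integers instead of the crate's Degree/Base2K/TorusPrecision/Rank newtypes, and the helper name is illustrative only:

    // Mirrors the formulas in the diff: `pairs` columns derived from the rank,
    // plus one extra column, and ceil(k / base2k) limbs per column.
    fn glwe_tensor_dims(rank: usize, base2k: usize, k: usize) -> (usize, usize) {
        let pairs = ((rank + 1) * rank / 2).max(1);
        (pairs + 1, k.div_ceil(base2k))
    }

    fn main() {
        assert_eq!(glwe_tensor_dims(1, 17, 34), (2, 2)); // rank 1: 1 pair  -> 2 columns, 2 limbs
        assert_eq!(glwe_tensor_dims(2, 17, 54), (4, 4)); // rank 2: 3 pairs -> 4 columns, ceil(54/17) = 4 limbs
    }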

View File

@@ -1,29 +1,77 @@
 use poulpy_hal::{
     api::{
-        Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero
+        BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
+        VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
+        VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace,
+        VecZnxSubNegateInplace, VecZnxZero,
     },
-    layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
+    layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
     reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
 };

 use crate::{
-    layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore
+    ScratchTakeCore,
+    layouts::{
+        GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
+        TorusPrecision,
+    },
 };

-pub trait GLWETensoring<BE: Backend> where Self: Convolution<BE>, Scratch<BE>: ScratchTakeCore<BE> {
-    fn glwe_tensor<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{
+pub trait GLWETensoring<BE: Backend>
+where
+    Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
+    Scratch<BE>: ScratchTakeCore<BE>,
+{
+    /// res = (a (x) b) * 2^{k * a_base2k}
+    ///
+    /// # Requires
+    /// * a.base2k() == b.base2k()
+    /// * res.cols() >= a.cols() + b.cols() - 1
+    ///
+    /// # Behavior
+    /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
+    fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
+    where
+        R: GLWETensorToMut,
+        A: GLWEToRef,
+        B: GLWEPreparedToRef<BE>,
+    {
         let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
-        let a: &mut GLWE<&[u8]> = &mut a.to_ref();
-        let b: &GLWE<&[u8]> = &b.to_ref();
-        self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch);
+        let a: &GLWE<&[u8]> = &a.to_ref();
+        let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
+
+        assert_eq!(a.base2k(), b.base2k());
+        assert_eq!(a.rank(), res.rank());
+
+        let res_cols: usize = res.data.cols();
+
+        // Get tmp buffer of min precision between a_prec * b_prec and res_prec
+        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
+
+        // DFT(res) = DFT(a) (x) DFT(b)
+        self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
+
+        // res = IDFT(res)
+        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
+
+        // Normalize and switches basis if required
+        for res_col in 0..res_cols {
+            self.vec_znx_big_normalize(
+                res.base2k().into(),
+                &mut res.data,
+                res_col,
+                a.base2k().into(),
+                &res_big,
+                res_col,
+                scratch_1,
+            );
+        }
     }
 }

 pub trait GLWEAdd
 where
-    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
+    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
 {
     fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
     where
@@ -39,35 +87,38 @@ where
         assert_eq!(b.n(), self.n() as u32);
         assert_eq!(res.n(), self.n() as u32);
         assert_eq!(a.base2k(), b.base2k());
-        assert!(res.rank() >= a.rank().max(b.rank()));
+        assert_eq!(res.base2k(), b.base2k());
+
+        if a.rank() == 0 {
+            assert_eq!(res.rank(), b.rank());
+        } else if b.rank() == 0 {
+            assert_eq!(res.rank(), a.rank());
+        } else {
+            assert_eq!(res.rank(), a.rank());
+            assert_eq!(res.rank(), b.rank());
+        }

         let min_col: usize = (a.rank().min(b.rank()) + 1).into();
         let max_col: usize = (a.rank().max(b.rank() + 1)).into();
         let self_col: usize = (res.rank() + 1).into();

-        (0..min_col).for_each(|i| {
+        for i in 0..min_col {
             self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
-        });
+        }

         if a.rank() > b.rank() {
-            (min_col..max_col).for_each(|i| {
+            for i in min_col..max_col {
                 self.vec_znx_copy(res.data_mut(), i, a.data(), i);
-            });
+            }
         } else {
-            (min_col..max_col).for_each(|i| {
+            for i in min_col..max_col {
                 self.vec_znx_copy(res.data_mut(), i, b.data(), i);
-            });
+            }
         }

-        let size: usize = res.size();
-        (max_col..self_col).for_each(|i| {
-            (0..size).for_each(|j| {
-                res.data.zero_at(i, j);
-            });
-        });
-
-        res.set_base2k(a.base2k());
-        res.set_k(set_k_binary(res, a, b));
+        for i in max_col..self_col {
+            self.vec_znx_zero(res.data_mut(), i);
+        }
     }

     fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -83,24 +134,22 @@ where
         assert_eq!(res.base2k(), a.base2k());
         assert!(res.rank() >= a.rank());

-        (0..(a.rank() + 1).into()).for_each(|i| {
+        for i in 0..(a.rank() + 1).into() {
             self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
-        });
-
-        res.set_k(set_k_unary(res, a))
+        }
     }
 }

-impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
+impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}

 impl<BE: Backend> GLWESub for Module<BE> where
-    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
+    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
 {
 }

 pub trait GLWESub
 where
-    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
+    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
 {
     fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
     where
@@ -114,37 +163,40 @@ where
         assert_eq!(a.n(), self.n() as u32);
         assert_eq!(b.n(), self.n() as u32);
-        assert_eq!(a.base2k(), b.base2k());
-        assert!(res.rank() >= a.rank().max(b.rank()));
+        assert_eq!(res.n(), self.n() as u32);
+        assert_eq!(a.base2k(), res.base2k());
+        assert_eq!(b.base2k(), res.base2k());
+
+        if a.rank() == 0 {
+            assert_eq!(res.rank(), b.rank());
+        } else if b.rank() == 0 {
+            assert_eq!(res.rank(), a.rank());
+        } else {
+            assert_eq!(res.rank(), a.rank());
+            assert_eq!(res.rank(), b.rank());
+        }

         let min_col: usize = (a.rank().min(b.rank()) + 1).into();
         let max_col: usize = (a.rank().max(b.rank() + 1)).into();
         let self_col: usize = (res.rank() + 1).into();

-        (0..min_col).for_each(|i| {
+        for i in 0..min_col {
             self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
-        });
+        }

         if a.rank() > b.rank() {
-            (min_col..max_col).for_each(|i| {
+            for i in min_col..max_col {
                 self.vec_znx_copy(res.data_mut(), i, a.data(), i);
-            });
+            }
         } else {
-            (min_col..max_col).for_each(|i| {
-                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
-                self.vec_znx_negate_inplace(res.data_mut(), i);
-            });
+            for i in min_col..max_col {
+                self.vec_znx_negate(res.data_mut(), i, b.data(), i);
+            }
         }

-        let size: usize = res.size();
-        (max_col..self_col).for_each(|i| {
-            (0..size).for_each(|j| {
-                res.data.zero_at(i, j);
-            });
-        });
-
-        res.set_base2k(a.base2k());
-        res.set_k(set_k_binary(res, a, b));
+        for i in max_col..self_col {
+            self.vec_znx_zero(res.data_mut(), i);
+        }
     }

     fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -158,13 +210,11 @@ where
         assert_eq!(res.n(), self.n() as u32);
         assert_eq!(a.n(), self.n() as u32);
         assert_eq!(res.base2k(), a.base2k());
-        assert!(res.rank() >= a.rank());
+        assert!(res.rank() == a.rank() || a.rank() == 0);

-        (0..(a.rank() + 1).into()).for_each(|i| {
+        for i in 0..(a.rank() + 1).into() {
             self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
-        });
-
-        res.set_k(set_k_unary(res, a))
+        }
     }

     fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -178,21 +228,19 @@ where
         assert_eq!(res.n(), self.n() as u32);
         assert_eq!(a.n(), self.n() as u32);
         assert_eq!(res.base2k(), a.base2k());
-        assert!(res.rank() >= a.rank());
+        assert!(res.rank() == a.rank() || a.rank() == 0);

-        (0..(a.rank() + 1).into()).for_each(|i| {
+        for i in 0..(a.rank() + 1).into() {
             self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
-        });
-
-        res.set_k(set_k_unary(res, a))
+        }
     }
 }

-impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
+impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}

 pub trait GLWERotate<BE: Backend>
 where
-    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
+    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
 {
     fn glwe_rotate_tmp_bytes(&self) -> usize {
         vec_znx_rotate_inplace_tmp_bytes(self.n())
@@ -207,14 +255,18 @@ where
         let a: &GLWE<&[u8]> = &a.to_ref();

         assert_eq!(a.n(), self.n() as u32);
-        assert_eq!(res.rank(), a.rank());
+        assert_eq!(res.n(), self.n() as u32);
+        assert!(res.rank() == a.rank() || a.rank() == 0);

-        (0..(a.rank() + 1).into()).for_each(|i| {
+        let res_cols = (res.rank() + 1).into();
+        let a_cols = (a.rank() + 1).into();
+
+        for i in 0..a_cols {
             self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
-        });
+        }

-        res.set_base2k(a.base2k());
-        res.set_k(set_k_unary(res, a))
+        for i in a_cols..res_cols {
+            self.vec_znx_zero(res.data_mut(), i);
+        }
     }

     fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -224,9 +276,9 @@ where
     {
         let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

-        (0..(res.rank() + 1).into()).for_each(|i| {
+        for i in 0..(res.rank() + 1).into() {
             self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
-        });
+        }
     }
 }
@@ -251,9 +303,6 @@ where
         for i in 0..res.rank().as_usize() + 1 {
             self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
         }
-
-        res.set_base2k(a.base2k());
-        res.set_k(set_k_unary(res, a))
     }

     fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -286,6 +335,7 @@ where
         assert_eq!(res.n(), self.n() as u32);
         assert_eq!(a.n(), self.n() as u32);
+        assert!(res.rank() == a.rank() || a.rank() == 0);

         let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
@@ -296,9 +346,6 @@ where
         for i in min_rank..(res.rank() + 1).into() {
             self.vec_znx_zero(res.data_mut(), i);
         }
-
-        res.set_k(a.k().min(res.max_k()));
-        res.set_base2k(a.base2k());
     }
 }
@@ -364,8 +411,6 @@ where
                 scratch,
             );
         }
-
-        res.set_k(a.k().min(res.k()));
     }

     fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
@@ -380,6 +425,7 @@ where
     }
 }

+#[allow(dead_code)]
 // c = op(a, b)
 fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
     // If either operands is a ciphertext
@@ -401,6 +447,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T
     }
 }

+#[allow(dead_code)]
 // a = op(a, b)
 fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
     if a.rank() != 0 || b.rank() != 0 {
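
To make the contract documented on glwe_tensor concrete, here is a plain fixed-point sketch of the promised scaling, res = (a (x) b) * 2^{k * a_base2k}. The limb weight 2^{-(j+1)*base2k} is an assumption that matches the "Y = 2^-K" wording in the convolution module; none of this code is part of the crate:

    // Decode a base-2^base2k limb vector as a fixed-point value.
    fn decode(limbs: &[i64], base2k: u32) -> f64 {
        limbs
            .iter()
            .enumerate()
            .map(|(j, &d)| d as f64 * 2f64.powi(-((j as i32 + 1) * base2k as i32)))
            .sum()
    }

    fn main() {
        let base2k = 8;
        let a = [3i64, 100]; // two limbs of base2k bits each
        let b = [5i64, -7];
        for k in [-1i32, 0, 1] {
            // k = 0 is the plain product; each increment of k shifts the result
            // up by one full limb of base2k bits, per the doc comment above.
            let res = decode(&a, base2k) * decode(&b, base2k) * 2f64.powi(k * base2k as i32);
            println!("k = {k:>2}: res = {res:e}");
        }
    }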

View File

@@ -6,39 +6,19 @@ use crate::{
     layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
 };

-impl<BE: Backend> Convolution<BE> for Module<BE>
+impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
 where
-    Self: Sized
-        + ModuleN
-        + SvpPPolAlloc<BE>
-        + SvpApplyDftToDft<BE>
-        + SvpPrepare<BE>
-        + SvpPPolBytesOf
-        + VecZnxDftBytesOf
-        + VecZnxDftAddScaledInplace<BE>
-        + VecZnxDftZero<BE>,
+    Self: BivariateConvolution<BE>,
     Scratch<BE>: ScratchTakeBasic,
 {
 }

-pub trait Convolution<BE: Backend>
+pub trait BivariateTensoring<BE: Backend>
 where
-    Self: Sized
-        + ModuleN
-        + SvpPPolAlloc<BE>
-        + SvpApplyDftToDft<BE>
-        + SvpPrepare<BE>
-        + SvpPPolBytesOf
-        + VecZnxDftBytesOf
-        + VecZnxDftAddScaledInplace<BE>
-        + VecZnxDftZero<BE>,
+    Self: BivariateConvolution<BE>,
     Scratch<BE>: ScratchTakeBasic,
 {
-    fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
-        self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
-    }
-
-    fn bivariate_convolution_full<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
+    fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
     where
         R: VecZnxDftToMut<BE>,
         A: VecZnxToRef,
@@ -55,14 +35,48 @@ where
         assert!(res_cols >= a_cols + b_cols - 1);

         for res_col in 0..res_cols {
-            let a_min: usize = res_col.saturating_sub(b_cols - 1);
-            let a_max: usize = res_col.min(a_cols - 1);
-
-            self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch);
-
-            for a_col in a_min + 1..a_max + 1 {
-                self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch);
+            self.vec_znx_dft_zero(res, res_col);
+        }
+
+        for a_col in 0..a_cols {
+            for b_col in 0..b_cols {
+                self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
             }
         }
     }
+}
+
+impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
+where
+    Self: Sized
+        + ModuleN
+        + SvpPPolAlloc<BE>
+        + SvpApplyDftToDft<BE>
+        + SvpPrepare<BE>
+        + SvpPPolBytesOf
+        + VecZnxDftBytesOf
+        + VecZnxDftAddScaledInplace<BE>
+        + VecZnxDftZero<BE>,
+    Scratch<BE>: ScratchTakeBasic,
+{
+}
+
+pub trait BivariateConvolution<BE: Backend>
+where
+    Self: Sized
+        + ModuleN
+        + SvpPPolAlloc<BE>
+        + SvpApplyDftToDft<BE>
+        + SvpPrepare<BE>
+        + SvpPPolBytesOf
+        + VecZnxDftBytesOf
+        + VecZnxDftAddScaledInplace<BE>
+        + VecZnxDftZero<BE>,
+    Scratch<BE>: ScratchTakeBasic,
+{
+    fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
+        self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
+    }

     /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
     /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
@@ -96,7 +110,7 @@ where
     /// [r03, r13, r23, r33]
     ///
     /// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
-    fn bivariate_convolution_single_add<R, A, B>(
+    fn bivariate_convolution_add<R, A, B>(
         &self,
         k: i64,
         res: &mut R,
@@ -123,10 +137,9 @@ where
             self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
             self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
         }
     }

-    fn bivariate_convolution_single<R, A, B>(
+    fn bivariate_convolution<R, A, B>(
         &self,
         k: i64,
         res: &mut R,
@@ -142,6 +155,6 @@ where
         B: VecZnxDftToRef<BE>,
     {
         self.vec_znx_dft_zero(res, res_col);
-        self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch);
+        self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
     }
 }
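
As a reading aid for the restructured loop in bivariate_tensoring (res[a_col + b_col] += a[a_col] * b[b_col]), a toy single-limb version over plain i64 coefficients. It leaves out the DFT domain and the k / limb scaling handled by bivariate_convolution_add, and all names are illustrative rather than crate APIs:

    // Negacyclic product in Z[X] / (X^N + 1): X^N wraps around with a sign flip.
    fn negacyclic_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
        let n = a.len();
        let mut out = vec![0i64; n];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                let sign = if i + j >= n { -1 } else { 1 };
                out[(i + j) % n] += sign * ai * bj;
            }
        }
        out
    }

    // Column convolution as in bivariate_tensoring: zero res, then accumulate
    // every column pair (a_col, b_col) into res[a_col + b_col].
    fn column_convolution(a: &[Vec<i64>], b: &[Vec<i64>], res_cols: usize) -> Vec<Vec<i64>> {
        let n = a[0].len();
        assert!(res_cols >= a.len() + b.len() - 1);
        let mut res = vec![vec![0i64; n]; res_cols]; // zeroed, like vec_znx_dft_zero
        for (a_col, pa) in a.iter().enumerate() {
            for (b_col, pb) in b.iter().enumerate() {
                let prod = negacyclic_mul(pa, pb);
                for (r, p) in res[a_col + b_col].iter_mut().zip(prod) {
                    *r += p;
                }
            }
        }
        res
    }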

View File

@@ -1,6 +1,6 @@
 use crate::{
     api::{
-        Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
+        BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
         VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
     },
     layouts::{
@@ -10,10 +10,10 @@ use crate::{
     source::Source,
 };

-pub fn test_convolution<M, BE: Backend>(module: &M)
+pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
 where
     M: ModuleN
-        + Convolution<BE>
+        + BivariateTensoring<BE>
         + VecZnxDftAlloc<BE>
         + VecZnxDftApply<BE>
         + VecZnxIdftApplyTmpA<BE>
@@ -55,7 +55,7 @@ where
     for mut k in 0..(2 * c_size + 1) as i64 {
         k -= c_size as i64;

-        module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
+        module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());

         for i in 0..c_cols {
             module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
@@ -73,13 +73,13 @@ where
             );
         }

-        convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
+        bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());

         assert_eq!(c_want, c_have);
     }
 }

-fn convolution_naive<R, A, B, M, BE: Backend>(
+fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
     module: &M,
     base2k: usize,
     k: i64,

View File

@@ -29,6 +29,7 @@ where
         self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos)
     }

+    #[allow(clippy::too_many_arguments)]
     /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
     fn ggsw_to_ggsw_blind_rotation<R, A, K>(
         &self,
@@ -74,6 +75,7 @@ where
         self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
     }

+    #[allow(clippy::too_many_arguments)]
     fn scalar_to_ggsw_blind_rotation<R, S, K>(
         &self,
         res: &mut R,
@@ -143,6 +145,7 @@ where
         self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
     }

+    #[allow(clippy::too_many_arguments)]
     /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
     fn glwe_to_glwe_blind_rotation<R, A, K>(
         &self,
@@ -162,6 +165,7 @@ where
         assert!(bit_rsh + bit_mask <= T::WORD_SIZE);

         let mut res: GLWE<&mut [u8]> = res.to_mut();
+        let a: &GLWE<&[u8]> = &a.to_ref();

         let (mut tmp_res, scratch_1) = scratch.take_glwe(&res);
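
A worked example of the rotation exponent from the blind-rotation doc comments above, ((k >> bit_rsh) % 2^bit_mask) << bit_lsh; the helper name and the sample values are illustrative only:

    fn rotation_exponent(k: u32, bit_rsh: u32, bit_mask: u32, bit_lsh: u32) -> u32 {
        // Select bits [bit_rsh, bit_rsh + bit_mask) of k, then scale by 2^bit_lsh.
        ((k >> bit_rsh) % (1 << bit_mask)) << bit_lsh
    }

    fn main() {
        let k = 0b1011_0101u32;
        // Bits [4, 7) of k are 0b011 = 3, scaled by 2^2 -> rotate by X^12.
        assert_eq!(rotation_exponent(k, 4, 3, 2), 12);
    }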

View File

@@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref;
 use crate::tfhe::{
     bdd_arithmetic::tests::test_suite::{
         test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra,
-        test_bdd_srl, test_bdd_sub, test_bdd_xor, test_scalar_to_ggsw_blind_rotation, test_glwe_to_glwe_blind_rotation,
+        test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation,
     },
     blind_rotation::CGGI,
 };

View File

@@ -70,12 +70,8 @@ where
     data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
     test_glwe.encode_vec_i64(&data, base2k.as_usize().into());

-    println!("pt: {}", test_glwe);
-
     let k: u32 = source.next_u32();

-    println!("k: {k}");
-
     let mut k_enc_prep: FheUintBlocksPrepared<Vec<u8>, u32, BE> =
         FheUintBlocksPrepared::<Vec<u8>, u32, BE>::alloc(&module, &ggsw_infos);

     k_enc_prep.encrypt_sk(