mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add glwe tensoiring
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use poulpy_hal::{
|
||||
api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution,
|
||||
api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module,
|
||||
test_suite::convolution::test_bivariate_tensoring,
|
||||
};
|
||||
|
||||
use crate::FFT64Avx;
|
||||
@@ -123,5 +124,5 @@ backend_test_suite! {
|
||||
#[test]
|
||||
fn test_convolution_fft64_avx() {
|
||||
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
|
||||
test_convolution(&module);
|
||||
test_bivariate_tensoring(&module);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution};
|
||||
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
|
||||
|
||||
use crate::FFT64Ref;
|
||||
|
||||
#[test]
|
||||
fn test_convolution_fft64_ref() {
|
||||
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
|
||||
test_convolution(&module);
|
||||
test_bivariate_tensoring(&module);
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ impl GLWETensor<Vec<u8>> {
|
||||
}
|
||||
|
||||
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
|
||||
GLWETensor {
|
||||
data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
|
||||
base2k,
|
||||
@@ -110,7 +110,7 @@ impl GLWETensor<Vec<u8>> {
|
||||
}
|
||||
|
||||
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
|
||||
VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,77 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero
|
||||
BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
|
||||
VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace,
|
||||
VecZnxSubNegateInplace, VecZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
|
||||
layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
|
||||
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore
|
||||
ScratchTakeCore,
|
||||
layouts::{
|
||||
GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
|
||||
TorusPrecision,
|
||||
},
|
||||
};
|
||||
|
||||
pub trait GLWETensoring<BE: Backend> where Self: Convolution<BE>, Scratch<BE>: ScratchTakeCore<BE> {
|
||||
fn glwe_tensor<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{
|
||||
|
||||
pub trait GLWETensoring<BE: Backend>
|
||||
where
|
||||
Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
/// res = (a (x) b) * 2^{k * a_base2k}
|
||||
///
|
||||
/// # Requires
|
||||
/// * a.base2k() == b.base2k()
|
||||
/// * res.cols() >= a.cols() + b.cols() - 1
|
||||
///
|
||||
/// # Behavior
|
||||
/// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
|
||||
fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWETensorToMut,
|
||||
A: GLWEToRef,
|
||||
B: GLWEPreparedToRef<BE>,
|
||||
{
|
||||
let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &mut GLWE<&[u8]> = &mut a.to_ref();
|
||||
let b: &GLWE<&[u8]> = &b.to_ref();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
|
||||
|
||||
self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert_eq!(a.rank(), res.rank());
|
||||
|
||||
let res_cols: usize = res.data.cols();
|
||||
|
||||
// Get tmp buffer of min precision between a_prec * b_prec and res_prec
|
||||
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
|
||||
|
||||
// DFT(res) = DFT(a) (x) DFT(b)
|
||||
self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
|
||||
|
||||
// res = IDFT(res)
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
|
||||
|
||||
// Normalize and switches basis if required
|
||||
for res_col in 0..res_cols {
|
||||
self.vec_znx_big_normalize(
|
||||
res.base2k().into(),
|
||||
&mut res.data,
|
||||
res_col,
|
||||
a.base2k().into(),
|
||||
&res_big,
|
||||
res_col,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GLWEAdd
|
||||
where
|
||||
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
|
||||
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
|
||||
{
|
||||
fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
|
||||
where
|
||||
@@ -39,35 +87,38 @@ where
|
||||
assert_eq!(b.n(), self.n() as u32);
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert!(res.rank() >= a.rank().max(b.rank()));
|
||||
assert_eq!(res.base2k(), b.base2k());
|
||||
|
||||
if a.rank() == 0 {
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
} else if b.rank() == 0 {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
} else {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
}
|
||||
|
||||
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
|
||||
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
|
||||
let self_col: usize = (res.rank() + 1).into();
|
||||
|
||||
(0..min_col).for_each(|i| {
|
||||
for i in 0..min_col {
|
||||
self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
|
||||
});
|
||||
|
||||
if a.rank() > b.rank() {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
} else {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
});
|
||||
}
|
||||
|
||||
let size: usize = res.size();
|
||||
(max_col..self_col).for_each(|i| {
|
||||
(0..size).for_each(|j| {
|
||||
res.data.zero_at(i, j);
|
||||
});
|
||||
});
|
||||
if a.rank() > b.rank() {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
} else {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_binary(res, a, b));
|
||||
for i in max_col..self_col {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -83,24 +134,22 @@ where
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
|
||||
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
|
||||
|
||||
impl<BE: Backend> GLWESub for Module<BE> where
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
|
||||
{
|
||||
}
|
||||
|
||||
pub trait GLWESub
|
||||
where
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
|
||||
{
|
||||
fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
|
||||
where
|
||||
@@ -114,37 +163,40 @@ where
|
||||
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(b.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert!(res.rank() >= a.rank().max(b.rank()));
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), res.base2k());
|
||||
assert_eq!(b.base2k(), res.base2k());
|
||||
|
||||
if a.rank() == 0 {
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
} else if b.rank() == 0 {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
} else {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
}
|
||||
|
||||
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
|
||||
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
|
||||
let self_col: usize = (res.rank() + 1).into();
|
||||
|
||||
(0..min_col).for_each(|i| {
|
||||
for i in 0..min_col {
|
||||
self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
|
||||
});
|
||||
|
||||
if a.rank() > b.rank() {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
} else {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
self.vec_znx_negate_inplace(res.data_mut(), i);
|
||||
});
|
||||
}
|
||||
|
||||
let size: usize = res.size();
|
||||
(max_col..self_col).for_each(|i| {
|
||||
(0..size).for_each(|j| {
|
||||
res.data.zero_at(i, j);
|
||||
});
|
||||
});
|
||||
if a.rank() > b.rank() {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
} else {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_negate(res.data_mut(), i, b.data(), i);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_binary(res, a, b));
|
||||
for i in max_col..self_col {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -158,13 +210,11 @@ where
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -178,21 +228,19 @@ where
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
|
||||
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
|
||||
|
||||
pub trait GLWERotate<BE: Backend>
|
||||
where
|
||||
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
|
||||
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
|
||||
{
|
||||
fn glwe_rotate_tmp_bytes(&self) -> usize {
|
||||
vec_znx_rotate_inplace_tmp_bytes(self.n())
|
||||
@@ -207,14 +255,18 @@ where
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
let res_cols = (res.rank() + 1).into();
|
||||
let a_cols = (a.rank() + 1).into();
|
||||
|
||||
for i in 0..a_cols {
|
||||
self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
for i in a_cols..res_cols {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -224,9 +276,9 @@ where
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
|
||||
(0..(res.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,9 +303,6 @@ where
|
||||
for i in 0..res.rank().as_usize() + 1 {
|
||||
self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
|
||||
fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -286,6 +335,7 @@ where
|
||||
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
|
||||
|
||||
@@ -296,9 +346,6 @@ where
|
||||
for i in min_rank..(res.rank() + 1).into() {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
|
||||
res.set_k(a.k().min(res.max_k()));
|
||||
res.set_base2k(a.base2k());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,8 +411,6 @@ where
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
res.set_k(a.k().min(res.k()));
|
||||
}
|
||||
|
||||
fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -380,6 +425,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
// c = op(a, b)
|
||||
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
|
||||
// If either operands is a ciphertext
|
||||
@@ -401,6 +447,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
// a = op(a, b)
|
||||
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
|
||||
if a.rank() != 0 || b.rank() != 0 {
|
||||
|
||||
@@ -6,39 +6,19 @@ use crate::{
|
||||
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
|
||||
};
|
||||
|
||||
impl<BE: Backend> Convolution<BE> for Module<BE>
|
||||
impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
}
|
||||
|
||||
pub trait Convolution<BE: Backend>
|
||||
pub trait BivariateTensoring<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Self: BivariateConvolution<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
|
||||
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
|
||||
}
|
||||
|
||||
fn bivariate_convolution_full<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: VecZnxDftToMut<BE>,
|
||||
A: VecZnxToRef,
|
||||
@@ -55,14 +35,48 @@ where
|
||||
assert!(res_cols >= a_cols + b_cols - 1);
|
||||
|
||||
for res_col in 0..res_cols {
|
||||
let a_min: usize = res_col.saturating_sub(b_cols - 1);
|
||||
let a_max: usize = res_col.min(a_cols - 1);
|
||||
self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch);
|
||||
for a_col in a_min + 1..a_max + 1 {
|
||||
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch);
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
}
|
||||
|
||||
for a_col in 0..a_cols {
|
||||
for b_col in 0..b_cols {
|
||||
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
}
|
||||
|
||||
pub trait BivariateConvolution<BE: Backend>
|
||||
where
|
||||
Self: Sized
|
||||
+ ModuleN
|
||||
+ SvpPPolAlloc<BE>
|
||||
+ SvpApplyDftToDft<BE>
|
||||
+ SvpPrepare<BE>
|
||||
+ SvpPPolBytesOf
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxDftAddScaledInplace<BE>
|
||||
+ VecZnxDftZero<BE>,
|
||||
Scratch<BE>: ScratchTakeBasic,
|
||||
{
|
||||
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
|
||||
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
|
||||
}
|
||||
|
||||
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
|
||||
/// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
|
||||
@@ -96,7 +110,7 @@ where
|
||||
/// [r03, r13, r23, r33]
|
||||
///
|
||||
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
|
||||
fn bivariate_convolution_single_add<R, A, B>(
|
||||
fn bivariate_convolution_add<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
@@ -123,10 +137,9 @@ where
|
||||
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
|
||||
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn bivariate_convolution_single<R, A, B>(
|
||||
fn bivariate_convolution<R, A, B>(
|
||||
&self,
|
||||
k: i64,
|
||||
res: &mut R,
|
||||
@@ -142,6 +155,6 @@ where
|
||||
B: VecZnxDftToRef<BE>,
|
||||
{
|
||||
self.vec_znx_dft_zero(res, res_col);
|
||||
self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch);
|
||||
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
api::{
|
||||
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
|
||||
BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
|
||||
VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
|
||||
},
|
||||
layouts::{
|
||||
@@ -10,10 +10,10 @@ use crate::{
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub fn test_convolution<M, BE: Backend>(module: &M)
|
||||
pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
|
||||
where
|
||||
M: ModuleN
|
||||
+ Convolution<BE>
|
||||
+ BivariateTensoring<BE>
|
||||
+ VecZnxDftAlloc<BE>
|
||||
+ VecZnxDftApply<BE>
|
||||
+ VecZnxIdftApplyTmpA<BE>
|
||||
@@ -55,7 +55,7 @@ where
|
||||
for mut k in 0..(2 * c_size + 1) as i64 {
|
||||
k -= c_size as i64;
|
||||
|
||||
module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
|
||||
module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
|
||||
|
||||
for i in 0..c_cols {
|
||||
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
|
||||
@@ -73,13 +73,13 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
|
||||
bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
|
||||
|
||||
assert_eq!(c_want, c_have);
|
||||
}
|
||||
}
|
||||
|
||||
fn convolution_naive<R, A, B, M, BE: Backend>(
|
||||
fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
|
||||
module: &M,
|
||||
base2k: usize,
|
||||
k: i64,
|
||||
|
||||
@@ -29,6 +29,7 @@ where
|
||||
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
|
||||
fn ggsw_to_ggsw_blind_rotation<R, A, K>(
|
||||
&self,
|
||||
@@ -74,6 +75,7 @@ where
|
||||
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn scalar_to_ggsw_blind_rotation<R, S, K>(
|
||||
&self,
|
||||
res: &mut R,
|
||||
@@ -143,6 +145,7 @@ where
|
||||
self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
|
||||
fn glwe_to_glwe_blind_rotation<R, A, K>(
|
||||
&self,
|
||||
@@ -162,6 +165,7 @@ where
|
||||
assert!(bit_rsh + bit_mask <= T::WORD_SIZE);
|
||||
|
||||
let mut res: GLWE<&mut [u8]> = res.to_mut();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
|
||||
let (mut tmp_res, scratch_1) = scratch.take_glwe(&res);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref;
|
||||
use crate::tfhe::{
|
||||
bdd_arithmetic::tests::test_suite::{
|
||||
test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra,
|
||||
test_bdd_srl, test_bdd_sub, test_bdd_xor, test_scalar_to_ggsw_blind_rotation, test_glwe_to_glwe_blind_rotation,
|
||||
test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation,
|
||||
},
|
||||
blind_rotation::CGGI,
|
||||
};
|
||||
|
||||
@@ -70,12 +70,8 @@ where
|
||||
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
test_glwe.encode_vec_i64(&data, base2k.as_usize().into());
|
||||
|
||||
println!("pt: {}", test_glwe);
|
||||
|
||||
let k: u32 = source.next_u32();
|
||||
|
||||
println!("k: {k}");
|
||||
|
||||
let mut k_enc_prep: FheUintBlocksPrepared<Vec<u8>, u32, BE> =
|
||||
FheUintBlocksPrepared::<Vec<u8>, u32, BE>::alloc(&module, &ggsw_infos);
|
||||
k_enc_prep.encrypt_sk(
|
||||
|
||||
Reference in New Issue
Block a user