mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add glwe tensoiring
This commit is contained in:
@@ -93,7 +93,7 @@ impl GLWETensor<Vec<u8>> {
|
||||
}
|
||||
|
||||
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
|
||||
GLWETensor {
|
||||
data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
|
||||
base2k,
|
||||
@@ -110,7 +110,7 @@ impl GLWETensor<Vec<u8>> {
|
||||
}
|
||||
|
||||
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
|
||||
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
|
||||
VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,77 @@
|
||||
use poulpy_hal::{
|
||||
api::{
|
||||
Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero
|
||||
BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
|
||||
VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace,
|
||||
VecZnxSubNegateInplace, VecZnxZero,
|
||||
},
|
||||
layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
|
||||
layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
|
||||
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore
|
||||
ScratchTakeCore,
|
||||
layouts::{
|
||||
GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
|
||||
TorusPrecision,
|
||||
},
|
||||
};
|
||||
|
||||
pub trait GLWETensoring<BE: Backend> where Self: Convolution<BE>, Scratch<BE>: ScratchTakeCore<BE> {
|
||||
fn glwe_tensor<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{
|
||||
|
||||
pub trait GLWETensoring<BE: Backend>
|
||||
where
|
||||
Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
/// res = (a (x) b) * 2^{k * a_base2k}
|
||||
///
|
||||
/// # Requires
|
||||
/// * a.base2k() == b.base2k()
|
||||
/// * res.cols() >= a.cols() + b.cols() - 1
|
||||
///
|
||||
/// # Behavior
|
||||
/// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
|
||||
fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWETensorToMut,
|
||||
A: GLWEToRef,
|
||||
B: GLWEPreparedToRef<BE>,
|
||||
{
|
||||
let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
|
||||
let a: &mut GLWE<&[u8]> = &mut a.to_ref();
|
||||
let b: &GLWE<&[u8]> = &b.to_ref();
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
|
||||
|
||||
self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert_eq!(a.rank(), res.rank());
|
||||
|
||||
let res_cols: usize = res.data.cols();
|
||||
|
||||
// Get tmp buffer of min precision between a_prec * b_prec and res_prec
|
||||
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
|
||||
|
||||
// DFT(res) = DFT(a) (x) DFT(b)
|
||||
self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
|
||||
|
||||
// res = IDFT(res)
|
||||
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
|
||||
|
||||
// Normalize and switches basis if required
|
||||
for res_col in 0..res_cols {
|
||||
self.vec_znx_big_normalize(
|
||||
res.base2k().into(),
|
||||
&mut res.data,
|
||||
res_col,
|
||||
a.base2k().into(),
|
||||
&res_big,
|
||||
res_col,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GLWEAdd
|
||||
where
|
||||
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
|
||||
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
|
||||
{
|
||||
fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
|
||||
where
|
||||
@@ -39,35 +87,38 @@ where
|
||||
assert_eq!(b.n(), self.n() as u32);
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert!(res.rank() >= a.rank().max(b.rank()));
|
||||
assert_eq!(res.base2k(), b.base2k());
|
||||
|
||||
if a.rank() == 0 {
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
} else if b.rank() == 0 {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
} else {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
}
|
||||
|
||||
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
|
||||
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
|
||||
let self_col: usize = (res.rank() + 1).into();
|
||||
|
||||
(0..min_col).for_each(|i| {
|
||||
for i in 0..min_col {
|
||||
self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
|
||||
});
|
||||
|
||||
if a.rank() > b.rank() {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
} else {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
});
|
||||
}
|
||||
|
||||
let size: usize = res.size();
|
||||
(max_col..self_col).for_each(|i| {
|
||||
(0..size).for_each(|j| {
|
||||
res.data.zero_at(i, j);
|
||||
});
|
||||
});
|
||||
if a.rank() > b.rank() {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
} else {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_binary(res, a, b));
|
||||
for i in max_col..self_col {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -83,24 +134,22 @@ where
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
|
||||
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
|
||||
|
||||
impl<BE: Backend> GLWESub for Module<BE> where
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
|
||||
{
|
||||
}
|
||||
|
||||
pub trait GLWESub
|
||||
where
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
|
||||
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
|
||||
{
|
||||
fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
|
||||
where
|
||||
@@ -114,37 +163,40 @@ where
|
||||
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(b.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), b.base2k());
|
||||
assert!(res.rank() >= a.rank().max(b.rank()));
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.base2k(), res.base2k());
|
||||
assert_eq!(b.base2k(), res.base2k());
|
||||
|
||||
if a.rank() == 0 {
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
} else if b.rank() == 0 {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
} else {
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.rank(), b.rank());
|
||||
}
|
||||
|
||||
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
|
||||
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
|
||||
let self_col: usize = (res.rank() + 1).into();
|
||||
|
||||
(0..min_col).for_each(|i| {
|
||||
for i in 0..min_col {
|
||||
self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
|
||||
});
|
||||
|
||||
if a.rank() > b.rank() {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
} else {
|
||||
(min_col..max_col).for_each(|i| {
|
||||
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
|
||||
self.vec_znx_negate_inplace(res.data_mut(), i);
|
||||
});
|
||||
}
|
||||
|
||||
let size: usize = res.size();
|
||||
(max_col..self_col).for_each(|i| {
|
||||
(0..size).for_each(|j| {
|
||||
res.data.zero_at(i, j);
|
||||
});
|
||||
});
|
||||
if a.rank() > b.rank() {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
} else {
|
||||
for i in min_col..max_col {
|
||||
self.vec_znx_negate(res.data_mut(), i, b.data(), i);
|
||||
}
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_binary(res, a, b));
|
||||
for i in max_col..self_col {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -158,13 +210,11 @@ where
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
|
||||
@@ -178,21 +228,19 @@ where
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.base2k(), a.base2k());
|
||||
assert!(res.rank() >= a.rank());
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(a.rank() + 1).into() {
|
||||
self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
|
||||
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
|
||||
|
||||
pub trait GLWERotate<BE: Backend>
|
||||
where
|
||||
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
|
||||
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
|
||||
{
|
||||
fn glwe_rotate_tmp_bytes(&self) -> usize {
|
||||
vec_znx_rotate_inplace_tmp_bytes(self.n())
|
||||
@@ -207,14 +255,18 @@ where
|
||||
let a: &GLWE<&[u8]> = &a.to_ref();
|
||||
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert_eq!(res.rank(), a.rank());
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
(0..(a.rank() + 1).into()).for_each(|i| {
|
||||
let res_cols = (res.rank() + 1).into();
|
||||
let a_cols = (a.rank() + 1).into();
|
||||
|
||||
for i in 0..a_cols {
|
||||
self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
|
||||
});
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
for i in a_cols..res_cols {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
}
|
||||
|
||||
fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -224,9 +276,9 @@ where
|
||||
{
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
|
||||
(0..(res.rank() + 1).into()).for_each(|i| {
|
||||
for i in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,9 +303,6 @@ where
|
||||
for i in 0..res.rank().as_usize() + 1 {
|
||||
self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
|
||||
}
|
||||
|
||||
res.set_base2k(a.base2k());
|
||||
res.set_k(set_k_unary(res, a))
|
||||
}
|
||||
|
||||
fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -286,6 +335,7 @@ where
|
||||
|
||||
assert_eq!(res.n(), self.n() as u32);
|
||||
assert_eq!(a.n(), self.n() as u32);
|
||||
assert!(res.rank() == a.rank() || a.rank() == 0);
|
||||
|
||||
let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
|
||||
|
||||
@@ -296,9 +346,6 @@ where
|
||||
for i in min_rank..(res.rank() + 1).into() {
|
||||
self.vec_znx_zero(res.data_mut(), i);
|
||||
}
|
||||
|
||||
res.set_k(a.k().min(res.max_k()));
|
||||
res.set_base2k(a.base2k());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,8 +411,6 @@ where
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
res.set_k(a.k().min(res.k()));
|
||||
}
|
||||
|
||||
fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
|
||||
@@ -380,6 +425,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
// c = op(a, b)
|
||||
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
|
||||
// If either operands is a ciphertext
|
||||
@@ -401,6 +447,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
// a = op(a, b)
|
||||
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
|
||||
if a.rank() != 0 || b.rank() != 0 {
|
||||
|
||||
Reference in New Issue
Block a user