Add GLWE tensoring

Pro7ech
2025-10-26 19:03:15 +01:00
parent 6e9cef5ecd
commit 41ca5aafcc
9 changed files with 199 additions and 138 deletions

View File

@@ -93,7 +93,7 @@ impl GLWETensor<Vec<u8>> {
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
GLWETensor {
data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
base2k,
@@ -110,7 +110,7 @@ impl GLWETensor<Vec<u8>> {
}
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize;
let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1);
VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
}
}
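
The column count in `alloc` and `bytes_of` above is a triangular-number formula: for rank r it reserves `((r + 1) * r) / 2` columns (clamped to at least 1), presumably for the cross terms of the tensor product, plus one extra column. A minimal sketch of the arithmetic with plain integers (a hypothetical `tensor_columns` helper, not the crate's API):

```rust
/// Number of VecZnx columns reserved for a given rank:
/// the r-th triangular number r(r+1)/2, clamped to at least 1, plus one column.
fn tensor_columns(rank: usize) -> usize {
    let pairs = ((rank + 1) * rank / 2).max(1);
    pairs + 1
}

fn main() {
    for rank in 0..=4 {
        // rank 0 -> 2, rank 1 -> 2, rank 2 -> 4, rank 3 -> 7, rank 4 -> 11
        println!("rank {rank}: {} columns", tensor_columns(rank));
    }
}
```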

View File

@@ -1,29 +1,77 @@
use poulpy_hal::{
api::{
Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero
BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace,
VecZnxSubNegateInplace, VecZnxZero,
},
layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
};
use crate::{
layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore
ScratchTakeCore,
layouts::{
GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
TorusPrecision,
},
};
pub trait GLWETensoring<BE: Backend> where Self: Convolution<BE>, Scratch<BE>: ScratchTakeCore<BE> {
fn glwe_tensor<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{
pub trait GLWETensoring<BE: Backend>
where
Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
/// res = (a (x) b) * 2^{k * a_base2k}
///
/// # Requires
/// * a.base2k() == b.base2k()
/// * res.cols() >= a.cols() + b.cols() - 1
///
/// # Behavior
/// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
R: GLWETensorToMut,
A: GLWEToRef,
B: GLWEPreparedToRef<BE>,
{
let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
let a: &mut GLWE<&[u8]> = &mut a.to_ref();
let b: &GLWE<&[u8]> = &b.to_ref();
let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch);
assert_eq!(a.base2k(), b.base2k());
assert_eq!(a.rank(), res.rank());
let res_cols: usize = res.data.cols();
// Get tmp buffer of min precision between a_prec * b_prec and res_prec
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
// DFT(res) = DFT(a) (x) DFT(b)
self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
// res = IDFT(res)
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
// Normalize and switch basis if required
for res_col in 0..res_cols {
self.vec_znx_big_normalize(
res.base2k().into(),
&mut res.data,
res_col,
a.base2k().into(),
&res_big,
res_col,
scratch_1,
);
}
}
}
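
For intuition on what `glwe_tensor` computes: under the usual GLWE tensoring identity, if c = (b, a) satisfies b + a·s = m (ignoring noise, scaling, and limb decomposition) and c' = (b', a') satisfies b' + a'·s = m', then the component-wise products (b·b', b·a' + a·b', a·a') decrypt to m·m' against the extended key (1, s, s²). The sketch below checks that identity for toy rank-1 "ciphertexts" over Z[X]/(X^N + 1) with small integer coefficients; it uses no poulpy types, does not mirror the crate's internal column layout, and ignores base-2^k decomposition, DFT, and noise entirely.

```rust
const N: usize = 8;

/// Negacyclic product in Z[X]/(X^N + 1).
fn negacyclic_mul(a: &[i64; N], b: &[i64; N]) -> [i64; N] {
    let mut r = [0i64; N];
    for i in 0..N {
        for j in 0..N {
            let k = i + j;
            if k < N {
                r[k] += a[i] * b[j];
            } else {
                r[k - N] -= a[i] * b[j]; // X^N = -1
            }
        }
    }
    r
}

fn add(a: &[i64; N], b: &[i64; N]) -> [i64; N] {
    core::array::from_fn(|i| a[i] + b[i])
}

fn sub(a: &[i64; N], b: &[i64; N]) -> [i64; N] {
    core::array::from_fn(|i| a[i] - b[i])
}

fn main() {
    // Toy secret, masks, and messages (noise-free, unscaled, for the algebra only).
    let s: [i64; N] = [1, 0, -1, 0, 1, 1, 0, -1];
    let a = [3, 1, 4, 1, 5, 9, 2, 6];
    let a2 = [2, 7, 1, 8, 2, 8, 1, 8];
    let m = [1, 2, 0, 0, 1, 0, 0, 3];
    let m2 = [0, 1, 1, 0, 2, 0, 1, 0];

    // Rank-1 "ciphertexts": b is chosen so that b + a*s = m exactly.
    let b = sub(&m, &negacyclic_mul(&a, &s));
    let b2 = sub(&m2, &negacyclic_mul(&a2, &s));

    // The three tensor components for rank 1.
    let t0 = negacyclic_mul(&b, &b2);
    let t1 = add(&negacyclic_mul(&b, &a2), &negacyclic_mul(&a, &b2));
    let t2 = negacyclic_mul(&a, &a2);

    // Decrypt against (1, s, s^2): t0 + t1*s + t2*s^2 == m*m'.
    let s2 = negacyclic_mul(&s, &s);
    let dec = add(&t0, &add(&negacyclic_mul(&t1, &s), &negacyclic_mul(&t2, &s2)));
    assert_eq!(dec, negacyclic_mul(&m, &m2));
    println!("tensor of (b, a) and (b', a') decrypts to m*m' under (1, s, s^2)");
}
```

The crate's implementation performs these products column-wise in the DFT domain via `bivariate_tensoring`, converts back with `vec_znx_idft_apply_consume`, and renormalizes each column to the result's base-2^k in the loop above.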
pub trait GLWEAdd
where
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
{
fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
where
@@ -39,35 +87,38 @@ where
assert_eq!(b.n(), self.n() as u32);
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.base2k(), b.base2k());
assert!(res.rank() >= a.rank().max(b.rank()));
assert_eq!(res.base2k(), b.base2k());
if a.rank() == 0 {
assert_eq!(res.rank(), b.rank());
} else if b.rank() == 0 {
assert_eq!(res.rank(), a.rank());
} else {
assert_eq!(res.rank(), a.rank());
assert_eq!(res.rank(), b.rank());
}
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (res.rank() + 1).into();
(0..min_col).for_each(|i| {
for i in 0..min_col {
self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
});
} else {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
});
}
let size: usize = res.size();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
res.data.zero_at(i, j);
});
});
if a.rank() > b.rank() {
for i in min_col..max_col {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
}
} else {
for i in min_col..max_col {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
}
}
res.set_base2k(a.base2k());
res.set_k(set_k_binary(res, a, b));
for i in max_col..self_col {
self.vec_znx_zero(res.data_mut(), i);
}
}
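
The rank handling in `glwe_add` follows a three-part column partition: shared columns `0..min_col` are added element-wise, columns up to `max_col` are copied from the higher-rank operand, and any remaining result columns up to `self_col` are zeroed. A self-contained sketch of that partition on plain coefficient vectors (a hypothetical `add_columns` helper; the crate's routine additionally asserts the ranks match unless one operand has rank 0):

```rust
/// Column-wise combination pattern used by the GLWE add routine:
/// combine shared columns, carry over the tail of the longer operand,
/// and zero whatever the result has beyond that.
fn add_columns(res: &mut Vec<Vec<i64>>, a: &[Vec<i64>], b: &[Vec<i64>]) {
    let min_col = a.len().min(b.len());
    let max_col = a.len().max(b.len());
    assert!(res.len() >= max_col);

    for i in 0..min_col {
        // shared columns: element-wise add (vec_znx_add in the crate)
        res[i] = a[i].iter().zip(&b[i]).map(|(x, y)| x + y).collect();
    }
    let longer = if a.len() > b.len() { a } else { b };
    for i in min_col..max_col {
        // tail of the higher-rank operand: plain copy (vec_znx_copy)
        res[i] = longer[i].clone();
    }
    for i in max_col..res.len() {
        // unused result columns: zeroed (vec_znx_zero)
        res[i].iter_mut().for_each(|x| *x = 0);
    }
}

fn main() {
    let a = vec![vec![1i64, 2]];                        // rank 0: just the body
    let b = vec![vec![5, 6], vec![7, 8], vec![9, 1]];   // rank 2: body + 2 masks
    let mut res = vec![vec![0i64; 2]; 4];
    add_columns(&mut res, &a, &b);
    assert_eq!(res, vec![vec![6, 8], vec![7, 8], vec![9, 1], vec![0, 0]]);
}
```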
fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -83,24 +134,22 @@ where
assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank());
(0..(a.rank() + 1).into()).for_each(|i| {
for i in 0..(a.rank() + 1).into() {
self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
});
res.set_k(set_k_unary(res, a))
}
}
}
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
impl<BE: Backend> GLWESub for Module<BE> where
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
{
}
pub trait GLWESub
where
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
{
fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
where
@@ -114,37 +163,40 @@ where
assert_eq!(a.n(), self.n() as u32);
assert_eq!(b.n(), self.n() as u32);
assert_eq!(a.base2k(), b.base2k());
assert!(res.rank() >= a.rank().max(b.rank()));
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.base2k(), res.base2k());
assert_eq!(b.base2k(), res.base2k());
if a.rank() == 0 {
assert_eq!(res.rank(), b.rank());
} else if b.rank() == 0 {
assert_eq!(res.rank(), a.rank());
} else {
assert_eq!(res.rank(), a.rank());
assert_eq!(res.rank(), b.rank());
}
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (res.rank() + 1).into();
(0..min_col).for_each(|i| {
for i in 0..min_col {
self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
});
} else {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
self.vec_znx_negate_inplace(res.data_mut(), i);
});
}
let size: usize = res.size();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
res.data.zero_at(i, j);
});
});
if a.rank() > b.rank() {
for i in min_col..max_col {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
}
} else {
for i in min_col..max_col {
self.vec_znx_negate(res.data_mut(), i, b.data(), i);
}
}
res.set_base2k(a.base2k());
res.set_k(set_k_binary(res, a, b));
for i in max_col..self_col {
self.vec_znx_zero(res.data_mut(), i);
}
}
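
`glwe_sub` uses the same partition with one sign difference: when the extra columns come from `b` they enter the result negated (`vec_znx_negate`), so that the result still represents a - b on every component. A sketch of the subtraction flavor (hypothetical `sub_columns`, mirroring the add sketch above):

```rust
/// Subtraction flavor of the column partition: shared columns are a[i] - b[i];
/// when b is the longer operand its tail enters the result negated.
fn sub_columns(res: &mut Vec<Vec<i64>>, a: &[Vec<i64>], b: &[Vec<i64>]) {
    let (min_col, max_col) = (a.len().min(b.len()), a.len().max(b.len()));
    for i in 0..min_col {
        // shared columns: element-wise subtract (vec_znx_sub in the crate)
        res[i] = a[i].iter().zip(&b[i]).map(|(x, y)| x - y).collect();
    }
    for i in min_col..max_col {
        res[i] = if a.len() > b.len() {
            a[i].clone()                          // tail of a: plain copy
        } else {
            b[i].iter().map(|x| -x).collect()     // tail of b: negated (vec_znx_negate)
        };
    }
    for i in max_col..res.len() {
        res[i].iter_mut().for_each(|x| *x = 0);   // leftover columns: zeroed
    }
}

fn main() {
    let a = vec![vec![10i64, 10]];                      // rank 0
    let b = vec![vec![1, 2], vec![3, 4], vec![5, 6]];   // rank 2
    let mut res = vec![vec![0i64; 2]; 3];
    sub_columns(&mut res, &a, &b);
    assert_eq!(res, vec![vec![9, 8], vec![-3, -4], vec![-5, -6]]);
}
```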
fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -158,13 +210,11 @@ where
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank());
assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| {
for i in 0..(a.rank() + 1).into() {
self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
});
res.set_k(set_k_unary(res, a))
}
}
fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -178,21 +228,19 @@ where
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank());
assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| {
for i in 0..(a.rank() + 1).into() {
self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
});
res.set_k(set_k_unary(res, a))
}
}
}
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
pub trait GLWERotate<BE: Backend>
where
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
{
fn glwe_rotate_tmp_bytes(&self) -> usize {
vec_znx_rotate_inplace_tmp_bytes(self.n())
@@ -207,14 +255,18 @@ where
let a: &GLWE<&[u8]> = &a.to_ref();
assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.rank(), a.rank());
assert_eq!(res.n(), self.n() as u32);
assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| {
let res_cols = (res.rank() + 1).into();
let a_cols = (a.rank() + 1).into();
for i in 0..a_cols {
self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
});
res.set_base2k(a.base2k());
res.set_k(set_k_unary(res, a))
}
for i in a_cols..res_cols {
self.vec_znx_zero(res.data_mut(), i);
}
}
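
`glwe_rotate` applies `vec_znx_rotate` to each of `a`'s columns and zeroes any extra result columns, mirroring the other routines. Assuming rotation here means multiplication by X^k in Z[X]/(X^N + 1), as is the usual convention, coefficients that wrap past degree N pick up a sign flip. A self-contained sketch of that negacyclic rotation on raw coefficients (not the crate's implementation):

```rust
const N: usize = 8;

/// Multiply a polynomial by X^k in Z[X]/(X^N + 1): a negacyclic rotation
/// of the coefficient vector, with a sign flip on every wrap-around.
fn rotate(p: &[i64; N], k: i64) -> [i64; N] {
    let mut r = [0i64; N];
    // reduce k modulo 2N, since X^{2N} = 1 in this ring
    let k = k.rem_euclid(2 * N as i64) as usize;
    for (i, &c) in p.iter().enumerate() {
        let j = i + k;
        // each full pass over N flips the sign once (X^N = -1)
        let sign = if (j / N) % 2 == 0 { 1 } else { -1 };
        r[j % N] += sign * c;
    }
    r
}

fn main() {
    let p = [1, 2, 3, 0, 0, 0, 0, 0]; // 1 + 2X + 3X^2
    // X^7 * (1 + 2X + 3X^2) = X^7 + 2X^8 + 3X^9 = -2 - 3X + X^7
    assert_eq!(rotate(&p, 7), [-2, -3, 0, 0, 0, 0, 0, 1]);
    // a negative shift is the inverse rotation: X^-5 * X^5 = 1
    assert_eq!(rotate(&rotate(&p, 5), -5), p);
    println!("ok");
}
```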
fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -224,9 +276,9 @@ where
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
(0..(res.rank() + 1).into()).for_each(|i| {
for i in 0..(res.rank() + 1).into() {
self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
});
}
}
}
@@ -251,9 +303,6 @@ where
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
}
res.set_base2k(a.base2k());
res.set_k(set_k_unary(res, a))
}
fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -286,6 +335,7 @@ where
assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32);
assert!(res.rank() == a.rank() || a.rank() == 0);
let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
@@ -296,9 +346,6 @@ where
for i in min_rank..(res.rank() + 1).into() {
self.vec_znx_zero(res.data_mut(), i);
}
res.set_k(a.k().min(res.max_k()));
res.set_base2k(a.base2k());
}
}
@@ -364,8 +411,6 @@ where
scratch,
);
}
res.set_k(a.k().min(res.k()));
}
fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
@@ -380,6 +425,7 @@ where
}
}
#[allow(dead_code)]
// c = op(a, b)
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
// If either operands is a ciphertext
@@ -401,6 +447,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T
}
}
#[allow(dead_code)]
// a = op(a, b)
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
if a.rank() != 0 || b.rank() != 0 {