glwe operations

Pro7ech
2025-10-16 16:57:30 +02:00
parent 1925571492
commit d27d43759a
9 changed files with 329 additions and 358 deletions


@@ -1,320 +1,292 @@
use poulpy_hal::{
    api::{
        ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
        VecZnxSubInplace, VecZnxSubNegateInplace,
    },
    layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
};

use crate::{
    ScratchTakeCore,
    layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision},
};
pub trait GLWEAdd
where
    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
{
    fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.base2k(), b.base2k());
        assert!(res.rank() >= a.rank().max(b.rank()));

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();

        // Columns present in both operands: coefficient-wise addition.
        (0..min_col).for_each(|i| {
            self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
        });

        // Columns present only in the higher-rank operand: copy through.
        if a.rank() > b.rank() {
            (min_col..max_col).for_each(|i| {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            });
        } else {
            (min_col..max_col).for_each(|i| {
                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
            });
        }

        // Any remaining columns of a larger result: zero them out.
        let size: usize = res.size();
        (max_col..self_col).for_each(|i| {
            (0..size).for_each(|j| {
                res.data.zero_at(i, j);
            });
        });

        res.set_base2k(a.base2k());
        res.set_k(set_k_binary(res, a, b));
    }

    fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() >= a.rank());

        (0..(a.rank() + 1).into()).for_each(|i| {
            self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
        });

        res.set_k(set_k_unary(res, a))
    }
}

impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
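
A rough sketch of how the receiver-flipped API reads at a call site. Only glwe_add and glwe_add_inplace come from this commit; the module and ciphertext construction below (GLWE::alloc, the infos value) are assumptions for illustration:

// Illustration only -- not part of this commit.
// `module: Module<BE>` gets glwe_add through the blanket impl above.
let mut res: GLWE<Vec<u8>> = GLWE::alloc(&infos); // hypothetical constructor
module.glwe_add(&mut res, &a, &b);     // res = a + b, column by column over rank + 1 polynomials
module.glwe_add_inplace(&mut res, &c); // res += c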
pub trait GLWESub
where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
{
    fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(a.base2k(), b.base2k());
        assert!(res.rank() >= a.rank().max(b.rank()));

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();

        // Columns present in both operands: coefficient-wise subtraction.
        (0..min_col).for_each(|i| {
            self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
        });

        if a.rank() > b.rank() {
            (min_col..max_col).for_each(|i| {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            });
        } else {
            // Columns only in b enter with a negative sign: copy, then negate.
            (min_col..max_col).for_each(|i| {
                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
                self.vec_znx_negate_inplace(res.data_mut(), i);
            });
        }

        let size: usize = res.size();
        (max_col..self_col).for_each(|i| {
            (0..size).for_each(|j| {
                res.data.zero_at(i, j);
            });
        });

        res.set_base2k(a.base2k());
        res.set_k(set_k_binary(res, a, b));
    }

    fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() >= a.rank());

        // res -= a
        (0..(a.rank() + 1).into()).for_each(|i| {
            self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
        });

        res.set_k(set_k_unary(res, a))
    }

    fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() >= a.rank());

        // res = a - res
        (0..(a.rank() + 1).into()).for_each(|i| {
            self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
        });

        res.set_k(set_k_unary(res, a))
    }
}
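
The three subtraction entry points differ only in which operand is overwritten; a quick reference (my summary, inferred from the kernel names):

// Illustration only -- not part of this commit.
module.glwe_sub(&mut res, &a, &b);            // res = a - b
module.glwe_sub_inplace(&mut res, &a);        // res = res - a
module.glwe_sub_negate_inplace(&mut res, &a); // res = a - res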
pub trait GLWERotate<BE: Backend>
where
    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
{
    fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        // Multiply every column by the monomial X^k.
        (0..(a.rank() + 1).into()).for_each(|i| {
            self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
        });

        res.set_base2k(a.base2k());
        res.set_k(set_k_unary(res, a))
    }

    fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        (0..(res.rank() + 1).into()).for_each(|i| {
            self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
        });
    }
}
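
Rotation by k amounts to multiplying each polynomial by X^k. In the usual negacyclic ring Z[X]/(X^N + 1) (assumed here; the convention is not spelled out in this diff), coefficients that wrap past degree N - 1 come back negated:

// Illustration only -- not part of this commit.
// For c(X) = c_0 + c_1*X + ... + c_{N-1}*X^{N-1}, rotating by k sends
// coefficient c_i to slot (i + k) mod N, with a sign flip on wrap-around.
module.glwe_rotate(k, &mut res, &a);              // res = X^k * a
module.glwe_rotate_inplace(k, &mut res, scratch); // res = X^k * res, backed by scratch space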
pub trait GLWEMulXpMinusOne<BE: Backend>
where
    Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
{
    fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        // Multiply every column by (X^k - 1).
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
        }

        res.set_base2k(a.base2k());
        res.set_k(set_k_unary(res, a))
    }

    fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        assert_eq!(res.n(), self.n() as u32);
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch);
        }
    }
}
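
Semantically, multiplying by (X^k - 1) is a rotation followed by subtracting the original ciphertext; reading the fused kernel as an optimization that skips the intermediate result is my interpretation, not something this diff states:

// Illustration only -- not part of this commit.
// Equivalent, up to scratch usage, to: glwe_rotate(k, res, a); glwe_sub_inplace(res, a);
module.glwe_mul_xp_minus_one(k, &mut res, &a); // res = (X^k - 1) * a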
pub trait GLWECopy
where
    Self: ModuleN + VecZnxCopy,
{
    fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_copy(res.data_mut(), i, a.data(), i);
        }

        // The tracked precision cannot exceed what the destination can hold.
        res.set_k(a.k().min(res.max_k()));
        res.set_base2k(a.base2k());
    }
}
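
A small numeric illustration of the precision clamp above; the concrete values are made up:

// Illustration only -- not part of this commit.
// If a tracks k = 54 bits of torus precision but res only has room for
// max_k() = 40 bits, the copy records min(54, 40) = 40 bits in res.
module.glwe_copy(&mut res, &a);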
pub trait GLWEShift<BE: Backend>
where
    Self: ModuleN + VecZnxRshInplace<BE>,
{
    fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let base2k: usize = res.base2k().into();
        // Right-shift every column by k bits within the base-2^base2k decomposition.
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch);
        }
    }
}
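
Right-shifting by k bits divides the underlying torus value by 2^k; the shift has to be threaded through the base-2^base2k limbs, which is why a scratch buffer is taken (my reading of the kernel signature):

// Illustration only -- not part of this commit.
module.glwe_rsh(4, &mut res, scratch); // res ~= res / 2^4, limb-wise within base2k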
@@ -324,6 +296,50 @@ impl GLWE<Vec<u8>> {
}
}
pub trait GLWENormalize<BE: Backend>
where
    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE>,
{
    fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        // Renormalize each column from a's base2k into res's base2k.
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize(
                res.base2k().into(),
                res.data_mut(),
                i,
                a.base2k().into(),
                a.data(),
                i,
                scratch,
            );
        }
        res.set_k(a.k().min(res.k()));
    }

    fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
        }
    }
}
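
Normalization recenters every limb into the canonical digit range for the target base2k; the two-operand form can also convert between different base2k layouts, since it reads a's base and writes res's base:

// Illustration only -- not part of this commit.
module.glwe_normalize(&mut res, &a, scratch);     // res = normalize(a), expressed in res's base2k
module.glwe_normalize_inplace(&mut res, scratch); // limb-wise renormalization of res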
// c = op(a, b)
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    // If either operand is a ciphertext