Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -2,40 +2,42 @@ use poulpy_hal::{
api::{
VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
VecZnxSubABInplace, VecZnxSubBAInplace,
VecZnxSubInplace, VecZnxSubNegateInplace,
},
layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
};
use crate::layouts::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, SetMetaData};
use crate::layouts::{
GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision,
};
impl<D> GLWEOperations for GLWEPlaintext<D>
where
D: DataMut,
GLWEPlaintext<D>: GLWECiphertextToMut + Infos + SetMetaData,
GLWEPlaintext<D>: GLWECiphertextToMut + GLWEInfos,
{
}
impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + Infos + SetMetaData {}
impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + GLWEInfos {}
pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized {
fn add<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
where
A: GLWECiphertextToRef,
B: GLWECiphertextToRef,
A: GLWECiphertextToRef + GLWEInfos,
B: GLWECiphertextToRef + GLWEInfos,
Module<BACKEND>: VecZnxAdd + VecZnxCopy,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.basek(), b.basek());
assert_eq!(a.base2k(), b.base2k());
assert!(self.rank() >= a.rank().max(b.rank()));
}
let min_col: usize = a.rank().min(b.rank()) + 1;
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (self.rank() + 1).into();
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
@@ -62,26 +64,26 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
});
});
self.set_basek(a.basek());
self.set_basek(a.base2k());
self.set_k(set_k_binary(self, a, b));
}
fn add_inplace<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef + GLWEInfos,
Module<BACKEND>: VecZnxAddInplace,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(self.basek(), a.basek());
assert_eq!(self.base2k(), a.base2k());
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
(0..(a.rank() + 1).into()).for_each(|i| {
module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
@@ -90,21 +92,21 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
fn sub<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
where
A: GLWECiphertextToRef,
B: GLWECiphertextToRef,
A: GLWECiphertextToRef + GLWEInfos,
B: GLWECiphertextToRef + GLWEInfos,
Module<BACKEND>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.basek(), b.basek());
assert_eq!(a.base2k(), b.base2k());
assert!(self.rank() >= a.rank().max(b.rank()));
}
let min_col: usize = a.rank().min(b.rank()) + 1;
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (self.rank() + 1).into();
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
@@ -132,27 +134,27 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
});
});
self.set_basek(a.basek());
self.set_basek(a.base2k());
self.set_k(set_k_binary(self, a, b));
}
fn sub_inplace_ab<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: GLWECiphertextToRef + Infos,
Module<BACKEND>: VecZnxSubABInplace,
A: GLWECiphertextToRef + GLWEInfos,
Module<BACKEND>: VecZnxSubInplace,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(self.basek(), a.basek());
assert_eq!(self.base2k(), a.base2k());
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
(0..(a.rank() + 1).into()).for_each(|i| {
module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(set_k_unary(self, a))
@@ -160,21 +162,21 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
fn sub_inplace_ba<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: GLWECiphertextToRef + Infos,
Module<BACKEND>: VecZnxSubBAInplace,
A: GLWECiphertextToRef + GLWEInfos,
Module<BACKEND>: VecZnxSubNegateInplace,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(self.basek(), a.basek());
assert_eq!(self.base2k(), a.base2k());
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
(0..(a.rank() + 1).into()).for_each(|i| {
module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(set_k_unary(self, a))
@@ -182,7 +184,7 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
fn rotate<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
where
A: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef + GLWEInfos,
Module<B>: VecZnxRotate,
{
#[cfg(debug_assertions)]
@@ -194,11 +196,11 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
(0..(a.rank() + 1).into()).for_each(|i| {
module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
});
self.set_basek(a.basek());
self.set_basek(a.base2k());
self.set_k(set_k_unary(self, a))
}
@@ -208,14 +210,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
{
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
(0..self_mut.rank() + 1).for_each(|i| {
(0..(self_mut.rank() + 1).into()).for_each(|i| {
module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
});
}
fn mul_xp_minus_one<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
where
A: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef + GLWEInfos,
Module<B>: VecZnxMulXpMinusOne,
{
#[cfg(debug_assertions)]
@@ -227,11 +229,11 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
(0..(a.rank() + 1).into()).for_each(|i| {
module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
});
self.set_basek(a.basek());
self.set_basek(a.base2k());
self.set_k(set_k_unary(self, a))
}
@@ -241,14 +243,14 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
{
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
(0..self_mut.rank() + 1).for_each(|i| {
(0..(self_mut.rank() + 1).into()).for_each(|i| {
module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
});
}
fn copy<A, B: Backend>(&mut self, module: &Module<B>, a: &A)
where
A: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef + GLWEInfos,
Module<B>: VecZnxCopy,
{
#[cfg(debug_assertions)]
@@ -260,27 +262,27 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
(0..(self_mut.rank() + 1).into()).for_each(|i| {
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().min(self.size() * self.basek()));
self.set_basek(a.basek());
self.set_k(a.k().min(self.max_k()));
self.set_basek(a.base2k());
}
fn rsh<B: Backend>(&mut self, module: &Module<B>, k: usize, scratch: &mut Scratch<B>)
where
Module<B>: VecZnxRshInplace<B>,
{
let basek: usize = self.basek();
(0..self.cols()).for_each(|i| {
module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data, i, scratch);
let base2k: usize = self.base2k().into();
(0..(self.rank() + 1).into()).for_each(|i| {
module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch);
})
}
fn normalize<A, B: Backend>(&mut self, module: &Module<B>, a: &A, scratch: &mut Scratch<B>)
where
A: GLWECiphertextToRef,
A: GLWECiphertextToRef + GLWEInfos,
Module<B>: VecZnxNormalize<B>,
{
#[cfg(debug_assertions)]
@@ -292,10 +294,18 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
(0..(self_mut.rank() + 1).into()).for_each(|i| {
module.vec_znx_normalize(
a.base2k().into(),
&mut self_mut.data,
i,
a.base2k().into(),
&a_ref.data,
i,
scratch,
);
});
self.set_basek(a.basek());
self.set_basek(a.base2k());
self.set_k(a.k().min(self.k()));
}
@@ -304,8 +314,8 @@ pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
Module<B>: VecZnxNormalizeInplace<B>,
{
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
(0..(self_mut.rank() + 1).into()).for_each(|i| {
module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch);
});
}
}
@@ -317,7 +327,7 @@ impl GLWECiphertext<Vec<u8>> {
}
// c = op(a, b)
fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
// If either operands is a ciphertext
if a.rank() != 0 || b.rank() != 0 {
// If a is a plaintext (but b ciphertext)
@@ -338,7 +348,7 @@ fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
}
// a = op(a, b)
fn set_k_unary(a: &impl Infos, b: &impl Infos) -> usize {
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
if a.rank() != 0 || b.rank() != 0 {
a.k().min(b.k())
} else {