Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
This commit is contained in:
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions

View File

@@ -2,7 +2,7 @@ use poulpy_hal::{
api::{
ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
VmpPMatAlloc, VmpPrepare,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut},
@@ -14,21 +14,27 @@ use std::marker::PhantomData;
use poulpy_core::{
Distribution,
layouts::{
GGSWCiphertext, LWESecret,
GGSWCiphertext, GGSWInfos, LWESecret,
compressed::GGSWCiphertextCompressed,
prepared::{GGSWCiphertextPrepared, GLWESecretPrepared},
},
};
use crate::tfhe::blind_rotation::{
BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyPrepared,
BlindRotationKeyPreparedAlloc, CGGI,
BlindRotationKey, BlindRotationKeyAlloc, BlindRotationKeyCompressed, BlindRotationKeyEncryptSk, BlindRotationKeyInfos,
BlindRotationKeyPrepared, BlindRotationKeyPreparedAlloc, CGGI,
};
impl BlindRotationKeyAlloc for BlindRotationKey<Vec<u8>, CGGI> {
fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut data: Vec<GGSWCiphertext<Vec<u8>>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(n_gglwe, basek, k, rows, 1, rank)));
fn alloc<A>(infos: &A) -> Self
where
A: BlindRotationKeyInfos,
{
let mut data: Vec<GGSWCiphertext<Vec<u8>>> = Vec::with_capacity(infos.n_lwe().into());
for _ in 0..infos.n_lwe().as_usize() {
data.push(GGSWCiphertext::alloc(infos));
}
Self {
keys: data,
dist: Distribution::NONE,
@@ -38,11 +44,12 @@ impl BlindRotationKeyAlloc for BlindRotationKey<Vec<u8>, CGGI> {
}
impl BlindRotationKey<Vec<u8>, CGGI> {
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn generate_from_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
where
A: GGSWInfos,
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
{
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
GGSWCiphertext::encrypt_sk_scratch_space(module, infos)
}
}
@@ -56,7 +63,7 @@ where
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -78,9 +85,11 @@ where
{
#[cfg(debug_assertions)]
{
assert_eq!(self.keys.len(), sk_lwe.n());
assert!(sk_glwe.n() <= module.n());
assert_eq!(sk_glwe.rank(), self.keys[0].rank());
use poulpy_core::layouts::{GLWEInfos, LWEInfos};
assert_eq!(self.keys.len() as u32, sk_lwe.n());
assert!(sk_glwe.n() <= module.n() as u32);
assert_eq!(sk_glwe.rank(), self.rank());
match sk_lwe.dist() {
Distribution::BinaryBlock(_)
| Distribution::BinaryFixed(_)
@@ -94,7 +103,7 @@ where
self.dist = sk_lwe.dist();
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n(), 1);
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();
self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| {
@@ -108,13 +117,12 @@ impl<B: Backend> BlindRotationKeyPreparedAlloc<B> for BlindRotationKeyPrepared<V
where
Module<B>: VmpPMatAlloc<B> + VmpPrepare<B>,
{
fn alloc(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut data: Vec<GGSWCiphertextPrepared<Vec<u8>, B>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| {
data.push(GGSWCiphertextPrepared::alloc(
module, basek, k, rows, 1, rank,
))
});
fn alloc<A>(module: &Module<B>, infos: &A) -> Self
where
A: BlindRotationKeyInfos,
{
let mut data: Vec<GGSWCiphertextPrepared<Vec<u8>, B>> = Vec::with_capacity(infos.n_lwe().into());
(0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextPrepared::alloc(module, infos)));
Self {
data,
dist: Distribution::NONE,
@@ -125,13 +133,12 @@ where
}
impl BlindRotationKeyCompressed<Vec<u8>, CGGI> {
pub fn alloc(n_gglwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut data: Vec<GGSWCiphertextCompressed<Vec<u8>>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| {
data.push(GGSWCiphertextCompressed::alloc(
n_gglwe, basek, k, rows, 1, rank,
))
});
pub fn alloc<A>(infos: &A) -> Self
where
A: BlindRotationKeyInfos,
{
let mut data: Vec<GGSWCiphertextCompressed<Vec<u8>>> = Vec::with_capacity(infos.n_lwe().into());
(0..infos.n_lwe().as_usize()).for_each(|_| data.push(GGSWCiphertextCompressed::alloc(infos)));
Self {
keys: data,
dist: Distribution::NONE,
@@ -139,11 +146,12 @@ impl BlindRotationKeyCompressed<Vec<u8>, CGGI> {
}
}
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn generate_from_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
where
A: GGSWInfos,
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
{
GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, basek, k, rank)
GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, infos)
}
}
@@ -168,7 +176,7 @@ impl<D: DataMut> BlindRotationKeyCompressed<D, CGGI> {
+ VecZnxIdftApplyConsume<B>
+ VecZnxNormalizeTmpBytes
+ VecZnxFillUniform
+ VecZnxSubABInplace
+ VecZnxSubInplace
+ VecZnxAddInplace
+ VecZnxNormalizeInplace<B>
+ VecZnxAddNormal
@@ -178,9 +186,11 @@ impl<D: DataMut> BlindRotationKeyCompressed<D, CGGI> {
{
#[cfg(debug_assertions)]
{
assert_eq!(self.keys.len(), sk_lwe.n());
assert!(sk_glwe.n() <= module.n());
assert_eq!(sk_glwe.rank(), self.keys[0].rank());
use poulpy_core::layouts::{GLWEInfos, LWEInfos};
assert_eq!(self.n_lwe(), sk_lwe.n());
assert!(sk_glwe.n() <= module.n() as u32);
assert_eq!(sk_glwe.rank(), self.rank());
match sk_lwe.dist() {
Distribution::BinaryBlock(_)
| Distribution::BinaryFixed(_)
@@ -194,7 +204,7 @@ impl<D: DataMut> BlindRotationKeyCompressed<D, CGGI> {
self.dist = sk_lwe.dist();
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n(), 1);
let mut pt: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(sk_glwe.n().into(), 1);
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data().to_ref();
let mut source_xa: Source = Source::new(seed_xa);