Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -1,8 +1,8 @@
use backend::hal::{
api::{
ScalarZnxAllocBytes, ScratchAvailable, SvpApply, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft,
VecZnxAddScalarInplace, VecZnxAllocBytes, VecZnxAutomorphism, VecZnxBigAllocBytes, VecZnxDftToVecZnxBigTmpA,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSwithcDegree, ZnxZero,
ScratchAvailable, SvpApply, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddScalarInplace,
VecZnxAutomorphism, VecZnxBigAllocBytes, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
VecZnxSwithcDegree, ZnxZero,
},
layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
};
@@ -18,15 +18,15 @@ use crate::{
pub trait GGLWEEncryptSkFamily<B: Backend> = GLWEEncryptSkFamily<B> + GLWESecretFamily<B>;
impl GGLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
where
Module<B>: GGLWEEncryptSkFamily<B> + VecZnxAllocBytes,
Module<B>: GGLWEEncryptSkFamily<B>,
{
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ (GLWEPlaintext::byte_of(module, basek, k) | module.vec_znx_normalize_tmp_bytes(module.n()))
GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
+ (GLWEPlaintext::byte_of(n, basek, k) | module.vec_znx_normalize_tmp_bytes(n))
}
pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _basek: usize, _k: usize, _rank: usize) -> usize {
pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize {
unimplemented!()
}
}
@@ -42,8 +42,8 @@ impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GGLWEEncryptSkFamily<B> + VecZnxAllocBytes + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx<B>,
Module<B>: GGLWEEncryptSkFamily<B> + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
@@ -63,16 +63,15 @@ impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
self.rank_out(),
sk.rank()
);
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(self.n(), sk.n());
assert_eq!(pt.n(), sk.n());
assert!(
scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()),
"scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k())
);
assert!(
self.rows() * self.digits() * self.basek() <= self.k(),
@@ -91,7 +90,7 @@ impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
let k: usize = self.k();
let rank_in: usize = self.rank_in();
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(module, basek, k);
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
// For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
//
// Example for ksk rank 2 to rank 3:
@@ -125,11 +124,11 @@ impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
}
impl GGLWECiphertextCompressed<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
where
Module<B>: GLWESwitchingKeyEncryptSkFamily<B> + VecZnxAllocBytes,
Module<B>: GLWESwitchingKeyEncryptSkFamily<B>,
{
GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
}
}
@@ -144,8 +143,8 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GGLWEEncryptSkFamily<B> + VecZnxAllocBytes + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx<B>,
Module<B>: GGLWEEncryptSkFamily<B> + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
@@ -165,16 +164,16 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
self.rank_out(),
sk.rank()
);
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(self.n(), sk.n());
assert_eq!(pt.n(), sk.n());
assert!(
scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k()),
scratch.available()
>= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()),
"scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k())
GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k())
);
assert!(
self.rows() * self.digits() * self.basek() <= self.k(),
@@ -196,7 +195,7 @@ impl<D: DataMut> GGLWECiphertextCompressed<D> {
let mut source_xa = Source::new(seed);
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(module, basek, k);
let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
(0..rank_in).for_each(|col_i| {
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
@@ -237,27 +236,29 @@ pub trait GLWESwitchingKeyEncryptSkFamily<B: Backend> = GGLWEEncryptSkFamily<B>;
impl GLWESwitchingKey<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(
module: &Module<B>,
n: usize,
basek: usize,
k: usize,
rank_in: usize,
rank_out: usize,
) -> usize
where
Module<B>: GLWESwitchingKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: GLWESwitchingKeyEncryptSkFamily<B>,
{
(GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | module.scalar_znx_alloc_bytes(1))
+ module.scalar_znx_alloc_bytes(rank_in)
+ GLWESecretExec::bytes_of(module, rank_out)
(GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
+ ScalarZnx::alloc_bytes(n, rank_in)
+ GLWESecretExec::bytes_of(module, n, rank_out)
}
pub fn encrypt_pk_scratch_space<B: Backend>(
module: &Module<B>,
_n: usize,
_basek: usize,
_k: usize,
_rank_in: usize,
_rank_out: usize,
) -> usize {
GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out)
GGLWECiphertext::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank_out)
}
}
@@ -272,13 +273,8 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GLWESwitchingKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxSwithcDegree
+ VecZnxAllocBytes
+ VecZnxAddScalarInplace,
Scratch<B>:
ScratchAvailable + TakeScalarZnx<B> + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + ScratchAvailable + TakeVecZnx<B>,
Module<B>: GLWESwitchingKeyEncryptSkFamily<B> + VecZnxSwithcDegree + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
@@ -288,6 +284,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
scratch.available()
>= GLWESwitchingKey::encrypt_sk_scratch_space(
module,
sk_out.n(),
self.basek(),
self.k(),
self.rank_in(),
@@ -297,6 +294,7 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
scratch.available(),
GLWESwitchingKey::encrypt_sk_scratch_space(
module,
sk_out.n(),
self.basek(),
self.k(),
self.rank_in(),
@@ -305,7 +303,9 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
)
}
let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(module, sk_in.rank());
let n: usize = sk_in.n().max(sk_out.n());
let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
(0..sk_in.rank()).for_each(|i| {
module.vec_znx_switch_degree(
&mut sk_in_tmp.as_vec_znx_mut(),
@@ -315,9 +315,9 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
);
});
let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(module, sk_out.rank());
let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(n, sk_out.rank());
{
let (mut tmp, _) = scratch2.take_scalar_znx(module, 1);
let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
(0..sk_out.rank()).for_each(|i| {
module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
@@ -341,17 +341,18 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
impl GLWESwitchingKeyCompressed<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(
module: &Module<B>,
n: usize,
basek: usize,
k: usize,
rank_in: usize,
rank_out: usize,
) -> usize
where
Module<B>: GLWESwitchingKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: GLWESwitchingKeyEncryptSkFamily<B>,
{
(GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | module.scalar_znx_alloc_bytes(1))
+ module.scalar_znx_alloc_bytes(rank_in)
+ GLWESecretExec::bytes_of(module, rank_out)
(GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
+ ScalarZnx::alloc_bytes(n, rank_in)
+ GLWESecretExec::bytes_of(module, n, rank_out)
}
}
@@ -366,13 +367,8 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GLWESwitchingKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxSwithcDegree
+ VecZnxAllocBytes
+ VecZnxAddScalarInplace,
Scratch<B>:
ScratchAvailable + TakeScalarZnx<B> + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + ScratchAvailable + TakeVecZnx<B>,
Module<B>: GLWESwitchingKeyEncryptSkFamily<B> + VecZnxSwithcDegree + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + ScratchAvailable + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
@@ -382,6 +378,7 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
scratch.available()
>= GLWESwitchingKey::encrypt_sk_scratch_space(
module,
sk_out.n(),
self.basek(),
self.k(),
self.rank_in(),
@@ -391,6 +388,7 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
scratch.available(),
GLWESwitchingKey::encrypt_sk_scratch_space(
module,
sk_out.n(),
self.basek(),
self.k(),
self.rank_in(),
@@ -399,7 +397,9 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
)
}
let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(module, sk_in.rank());
let n: usize = sk_in.n().max(sk_out.n());
let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
(0..sk_in.rank()).for_each(|i| {
module.vec_znx_switch_degree(
&mut sk_in_tmp.as_vec_znx_mut(),
@@ -409,9 +409,9 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
);
});
let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(module, sk_out.rank());
let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_exec(n, sk_out.rank());
{
let (mut tmp, _) = scratch2.take_scalar_znx(module, 1);
let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
(0..sk_out.rank()).for_each(|i| {
module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
@@ -435,15 +435,15 @@ impl<DataSelf: DataMut> GLWESwitchingKeyCompressed<DataSelf> {
pub trait AutomorphismKeyEncryptSkFamily<B: Backend> = GGLWEEncryptSkFamily<B>;
impl AutomorphismKey<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: AutomorphismKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: AutomorphismKeyEncryptSkFamily<B>,
{
GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank)
GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank)
}
pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, _basek: usize, _k: usize, _rank: usize) -> usize {
GLWESwitchingKey::encrypt_pk_scratch_space(module, _basek, _k, _rank, _rank)
pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize {
GLWESwitchingKey::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank, _rank)
}
}
@@ -458,31 +458,26 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: AutomorphismKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxAllocBytes
+ VecZnxAutomorphism
+ VecZnxSwithcDegree
+ VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx<B> + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + TakeVecZnx<B>,
Module<B>: AutomorphismKeyEncryptSkFamily<B> + VecZnxAutomorphism + VecZnxSwithcDegree + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.n(), sk.n());
assert_eq!(self.rank_out(), self.rank_in());
assert_eq!(sk.rank(), self.rank());
assert!(
scratch.available() >= AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()),
scratch.available()
>= AutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()),
"scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
AutomorphismKey::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank())
AutomorphismKey::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank())
)
}
let (mut sk_out, scratch_1) = scratch.take_glwe_secret(module, sk.rank());
let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank());
{
(0..self.rank()).for_each(|i| {
@@ -504,11 +499,11 @@ impl<DataSelf: DataMut> AutomorphismKey<DataSelf> {
}
impl AutomorphismKeyCompressed<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: AutomorphismKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: AutomorphismKeyEncryptSkFamily<B>,
{
GLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, basek, k, rank, rank) + GLWESecret::bytes_of(module, rank)
GLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, n, basek, k, rank, rank) + GLWESecret::bytes_of(n, rank)
}
}
@@ -523,32 +518,26 @@ impl<DataSelf: DataMut> AutomorphismKeyCompressed<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: AutomorphismKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxAllocBytes
+ VecZnxSwithcDegree
+ VecZnxAutomorphism
+ VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx<B> + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + TakeVecZnx<B>,
Module<B>: AutomorphismKeyEncryptSkFamily<B> + VecZnxSwithcDegree + VecZnxAutomorphism + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeScalarZnx + TakeVecZnxDft<B> + TakeGLWESecretExec<B> + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.n(), sk.n());
assert_eq!(self.rank_out(), self.rank_in());
assert_eq!(sk.rank(), self.rank());
assert!(
scratch.available()
>= AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank()),
>= AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank()),
"scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self.basek(), self.k(), self.rank())
AutomorphismKeyCompressed::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k(), self.rank())
)
}
let (mut sk_out, scratch_1) = scratch.take_glwe_secret(module, sk.rank());
let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank());
{
(0..self.rank()).for_each(|i| {
@@ -573,16 +562,16 @@ pub trait GLWETensorKeyEncryptSkFamily<B: Backend> =
GGLWEEncryptSkFamily<B> + VecZnxBigAllocBytes + VecZnxDftToVecZnxBigTmpA<B> + SvpApply<B>;
impl GLWETensorKey<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: GLWETensorKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: GLWETensorKeyEncryptSkFamily<B>,
{
GLWESecretExec::bytes_of(module, rank)
+ module.vec_znx_dft_alloc_bytes(rank, 1)
+ module.vec_znx_big_alloc_bytes(1, 1)
+ module.vec_znx_dft_alloc_bytes(1, 1)
+ GLWESecret::bytes_of(module, 1)
+ GLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank)
GLWESecretExec::bytes_of(module, n, rank)
+ module.vec_znx_dft_alloc_bytes(n, rank, 1)
+ module.vec_znx_big_alloc_bytes(n, 1, 1)
+ module.vec_znx_dft_alloc_bytes(n, 1, 1)
+ GLWESecret::bytes_of(n, 1)
+ GLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank)
}
}
@@ -596,35 +585,31 @@ impl<DataSelf: DataMut> GLWETensorKey<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GLWETensorKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxSwithcDegree
+ VecZnxAllocBytes
+ VecZnxAddScalarInplace,
Scratch<B>:
ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeGLWESecretExec<B> + TakeScalarZnx<B> + TakeVecZnx<B>,
Module<B>: GLWETensorKeyEncryptSkFamily<B> + VecZnxSwithcDegree + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeGLWESecretExec<B> + TakeScalarZnx + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.n(), sk.n());
}
let n: usize = sk.n();
let rank: usize = self.rank();
let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(module, rank);
let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(n, rank);
sk_dft_prep.prepare(module, &sk);
let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(module, rank, 1);
let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
(0..rank).for_each(|i| {
module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
});
let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(module, 1, 1);
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(module, 1);
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(module, 1, 1);
let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
(0..rank).for_each(|i| {
(i..rank).for_each(|j| {
@@ -648,11 +633,11 @@ impl<DataSelf: DataMut> GLWETensorKey<DataSelf> {
}
impl GLWETensorKeyCompressed<Vec<u8>> {
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: GLWETensorKeyEncryptSkFamily<B> + ScalarZnxAllocBytes + VecZnxAllocBytes,
Module<B>: GLWETensorKeyEncryptSkFamily<B>,
{
GLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank)
GLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k, rank)
}
}
@@ -666,35 +651,30 @@ impl<DataSelf: DataMut> GLWETensorKeyCompressed<DataSelf> {
sigma: f64,
scratch: &mut Scratch<B>,
) where
Module<B>: GLWETensorKeyEncryptSkFamily<B>
+ ScalarZnxAllocBytes
+ VecZnxSwithcDegree
+ VecZnxAllocBytes
+ VecZnxAddScalarInplace,
Scratch<B>:
ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeGLWESecretExec<B> + TakeScalarZnx<B> + TakeVecZnx<B>,
Module<B>: GLWETensorKeyEncryptSkFamily<B> + VecZnxSwithcDegree + VecZnxAddScalarInplace,
Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeGLWESecretExec<B> + TakeScalarZnx + TakeVecZnx,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.n(), sk.n());
}
let n: usize = sk.n();
let rank: usize = self.rank();
let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(module, rank);
let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_exec(n, rank);
sk_dft_prep.prepare(module, &sk);
let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(module, rank, 1);
let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
(0..rank).for_each(|i| {
module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
});
let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(module, 1, 1);
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(module, 1);
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(module, 1, 1);
let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
let mut source_xa: Source = Source::new(seed_xa);