Added tensor key & associated test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-19 18:06:14 +02:00
parent c5fe07188f
commit 8f2eac4928
12 changed files with 610 additions and 28 deletions

View File

@@ -9,9 +9,9 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source; use sampling::source::Source;
pub struct ScalarZnx<D> { pub struct ScalarZnx<D> {
data: D, pub(crate) data: D,
n: usize, pub(crate) n: usize,
cols: usize, pub(crate) cols: usize,
} }
impl<D> ZnxInfos for ScalarZnx<D> { impl<D> ZnxInfos for ScalarZnx<D> {

View File

@@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::ffi::svp; use crate::ffi::svp;
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64};
pub struct ScalarZnxDft<D, B: Backend> { pub struct ScalarZnxDft<D, B: Backend> {
data: D, data: D,
@@ -92,6 +92,16 @@ impl<D, B: Backend> ScalarZnxDft<D, B> {
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B>{
VecZnxDft{
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
} }
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>; pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
@@ -158,3 +168,63 @@ impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
} }
} }
} }
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::DataView; use crate::DataView;
use crate::DataViewMut; use crate::DataViewMut;
use crate::ScalarZnx;
use crate::ZnxSliceSize; use crate::ZnxSliceSize;
use crate::ZnxZero; use crate::ZnxZero;
use crate::alloc_aligned; use crate::alloc_aligned;
@@ -128,6 +129,15 @@ impl<D> VecZnx<D> {
size, size,
} }
} }
pub fn to_scalar_znx(self) -> ScalarZnx<D>{
debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols);
ScalarZnx{
data: self.data,
n: self.n,
cols: self.cols,
}
}
} }
/// Copies the coefficients of `a` on the receiver. /// Copies the coefficients of `a` on the receiver.

View File

@@ -8,11 +8,11 @@ use crate::{
use std::fmt; use std::fmt;
pub struct VecZnxDft<D, B: Backend> { pub struct VecZnxDft<D, B: Backend> {
data: D, pub(crate) data: D,
n: usize, pub(crate) n: usize,
cols: usize, pub(crate) cols: usize,
size: usize, pub(crate) size: usize,
_phantom: PhantomData<B>, pub(crate) _phantom: PhantomData<B>,
} }
impl<D, B: Backend> VecZnxDft<D, B> { impl<D, B: Backend> VecZnxDft<D, B> {

View File

@@ -6,6 +6,7 @@ use base2k::{
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos, SetRow}, elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext, glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
@@ -78,6 +79,20 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx_dft(rank + 1, size) + module.bytes_of_vec_znx_dft(rank + 1, size)
} }
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
auto_key_size: usize,
rank: usize,
) -> usize {
let size: usize = in_size.min(out_size);
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size);
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size);
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size);
tmp_dft + tmp_idft + vmp
}
pub fn external_product_scratch_space( pub fn external_product_scratch_space(
module: &Module<FFT64>, module: &Module<FFT64>,
out_size: usize, out_size: usize,
@@ -182,6 +197,73 @@ where
}); });
} }
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_in rank: {} != auto_key rank: {}",
self.rank(),
rhs.rank()
);
}
let size: usize = self.size().min(lhs.size());
let cols: usize = self.rank() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size);
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size);
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: tmp_idft_data,
basek: self.basek(),
k: self.k(),
};
(0..cols).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2);
(0..cols).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
});
self.set_row(module, row_j, col_i, &tmp_dft);
});
});
tmp_dft.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank() + 1).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_dft);
});
});
}
pub fn external_product<DataLhs, DataRhs>( pub fn external_product<DataLhs, DataRhs>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,

View File

@@ -7,6 +7,7 @@ pub mod glwe_ciphertext_fourier;
pub mod glwe_plaintext; pub mod glwe_plaintext;
pub mod keys; pub mod keys;
pub mod keyswitch_key; pub mod keyswitch_key;
pub mod tensor_key;
#[cfg(test)] #[cfg(test)]
mod test_fft64; mod test_fft64;
mod utils; mod utils;

125
core/src/tensor_key.rs Normal file
View File

@@ -0,0 +1,125 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc,
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef,
};
use sampling::source::Source;
use crate::{
elem::Infos,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
};
pub struct TensorKey<C, B: Backend> {
pub(crate) keys: Vec<GLWESwitchingKey<C, B>>,
}
impl TensorKey<Vec<u8>, FFT64> {
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
let pairs: usize = ((rank + 1) * rank) >> 1;
(0..pairs).for_each(|_| {
keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank));
});
Self { keys: keys }
}
}
impl<T, B: Backend> Infos for TensorKey<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.keys[0].inner()
}
fn basek(&self) -> usize {
self.keys[0].basek()
}
fn k(&self) -> usize {
self.keys[0].k()
}
}
impl<T, B: Backend> TensorKey<T, B> {
pub fn rank(&self) -> usize {
self.keys[0].rank()
}
pub fn rank_in(&self) -> usize {
self.keys[0].rank_in()
}
pub fn rank_out(&self) -> usize {
self.keys[0].rank_out()
}
}
impl TensorKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size)
}
}
impl<DataSelf> TensorKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataSk>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnxDft<DataSk, FFT64>: VecZnxDftToRef<FFT64> + ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
}
let rank: usize = self.rank();
(0..rank).for_each(|i| {
(i..rank).for_each(|j| {
let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1);
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
let sk_ij: ScalarZnx<&mut [u8]> = module
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
.to_vec_znx_small()
.to_scalar_znx();
let sk_ij: SecretKey<&mut [u8]> = SecretKey {
data: sk_ij,
dist: sk_dft.dist,
};
self.at_mut(i, j).encrypt_sk(
module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1,
);
});
})
}
// Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j])
pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey<DataSelf, FFT64> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank();
&self.keys[i * rank + j - (i * (i + 1) / 2)]
}
// Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j])
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey<DataSelf, FFT64> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}

View File

@@ -18,6 +18,14 @@ fn automorphism() {
}); });
} }
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
});
}
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) { fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n); let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek; let rows = (k_ksk + basek - 1) / basek;
@@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
}); });
}); });
} }
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
// gglwe_{s1}(s0) = s0 -> s1
auto_key.encrypt_sk(
&module,
p0,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s2}(s1) -> s1 -> s2
auto_key_apply.encrypt_sk(
&module,
p1,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
(0..rank).for_each(|i| {
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
});
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_auto_dft.dft(&module, &sk_auto);
(0..auto_key.rank_in()).for_each(|col_i| {
(0..auto_key.rows()).for_each(|row_i| {
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}

View File

@@ -5,6 +5,7 @@ use base2k::{
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos}, elem::{GetRow, Infos},
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
@@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank:
}); });
} }
// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
// let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
// let rows: usize = (k_ggsw + basek - 1) / basek;
//
// let mut ct_ggsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
// let mut ct_ggsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
// let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
//
// let mut pt_ggsw_in: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
// let mut pt_ggsw_out: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size())
// | GGSWCiphertext::automorphism_scratch_space(
// &module,
// ct_ggsw_out.size(),
// ct_ggsw_in.size(),
// auto_key.size(),
// rank,
// ),
// );
//
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk_dft.dft(&module, &sk);
//
// ct_ggsw_in.encrypt_sk(
// &module,
// &pt_ggsw_in,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// auto_key.encrypt_sk(
// &module,
// p,
// &sk,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow());
//
// let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
//
// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
//
// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
//
// if col_j > 0 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, basek).log2();
//
// let var_gct_err_lhs: f64 = sigma * sigma;
// let var_gct_err_rhs: f64 = 0f64;
//
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
// let var_a0_err: f64 = sigma * sigma;
// let var_a1_err: f64 = 1f64 / 12f64;
//
// let noise_want: f64 = noise_ggsw_product(
// module.n() as f64,
// basek,
// 0.5,
// var_msg,
// var_a0_err,
// var_a1_err,
// var_gct_err_lhs,
// var_gct_err_rhs,
// rank as f64,
// k_ggsw,
// k_ggsw,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.1,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) { fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n); let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
@@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize,
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new( let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size()) GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size()) | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
| GGSWCiphertext::external_product_scratch_space( | GGSWCiphertext::external_product_scratch_space(
&module, &module,

View File

@@ -1,6 +1,6 @@
use base2k::{ use base2k::{
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
ZnxView, ZnxViewMut, ZnxZero, ZnxViewMut, ZnxZero,
}; };
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;
@@ -75,6 +75,22 @@ fn external_product_inplace() {
}); });
} }
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
});
}
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n); let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
@@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
); );
} }
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
});
}
fn test_automorphism( fn test_automorphism(
log_n: usize, log_n: usize,
basek: usize, basek: usize,
@@ -515,14 +523,6 @@ fn test_automorphism(
); );
} }
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
});
}
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) { fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n); let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek; let rows: usize = (k_ct + basek - 1) / basek;

View File

@@ -3,3 +3,4 @@ mod gglwe;
mod ggsw; mod ggsw;
mod glwe; mod glwe;
mod glwe_fourier; mod glwe_fourier;
mod tensor_key;

View File

@@ -0,0 +1,77 @@
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos},
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
tensor_key::TensorKey,
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(12, 16, 54, 3.2, rank);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
&module,
rank,
tensor_key.size(),
));
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
tensor_key.encrypt_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
(0..rank).for_each(|i| {
(0..rank).for_each(|j| {
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
let sk_ij: ScalarZnx<Vec<u8>> = module
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
.to_vec_znx_small()
.to_scalar_znx();
(0..tensor_key.rank_in()).for_each(|col_i| {
(0..tensor_key.rows()).for_each(|row_i| {
tensor_key
.at(i, j)
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
});
});
})
})
}