Added tensor key & associated test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-19 18:06:14 +02:00
parent c5fe07188f
commit 8f2eac4928
12 changed files with 610 additions and 28 deletions

View File

@@ -18,6 +18,14 @@ fn automorphism() {
});
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
});
}
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
@@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
});
});
}
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
// gglwe_{s1}(s0) = s0 -> s1
auto_key.encrypt_sk(
&module,
p0,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s2}(s1) -> s1 -> s2
auto_key_apply.encrypt_sk(
&module,
p1,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
(0..rank).for_each(|i| {
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
});
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_auto_dft.dft(&module, &sk_auto);
(0..auto_key.rank_in()).for_each(|col_i| {
(0..auto_key.rows()).for_each(|row_i| {
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}

View File

@@ -5,6 +5,7 @@ use base2k::{
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos},
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
@@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank:
});
}
// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
// let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
// let rows: usize = (k_ggsw + basek - 1) / basek;
//
// let mut ct_ggsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
// let mut ct_ggsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
// let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
//
// let mut pt_ggsw_in: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
// let mut pt_ggsw_out: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
//
// let mut source_xs: Source = Source::new([0u8; 32]);
// let mut source_xe: Source = Source::new([0u8; 32]);
// let mut source_xa: Source = Source::new([0u8; 32]);
//
// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs);
//
// let mut scratch: ScratchOwned = ScratchOwned::new(
// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size())
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size())
// | GGSWCiphertext::automorphism_scratch_space(
// &module,
// ct_ggsw_out.size(),
// ct_ggsw_in.size(),
// auto_key.size(),
// rank,
// ),
// );
//
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
// sk.fill_ternary_prob(0.5, &mut source_xs);
//
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
// sk_dft.dft(&module, &sk);
//
// ct_ggsw_in.encrypt_sk(
// &module,
// &pt_ggsw_in,
// &sk_dft,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// auto_key.encrypt_sk(
// &module,
// p,
// &sk,
// &mut source_xa,
// &mut source_xe,
// sigma,
// scratch.borrow(),
// );
//
// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow());
//
// let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
//
// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
//
// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
//
// if col_j > 0 {
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
// }
//
// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
//
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
//
// let noise_have: f64 = pt.data.std(0, basek).log2();
//
// let var_gct_err_lhs: f64 = sigma * sigma;
// let var_gct_err_rhs: f64 = 0f64;
//
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
// let var_a0_err: f64 = sigma * sigma;
// let var_a1_err: f64 = 1f64 / 12f64;
//
// let noise_want: f64 = noise_ggsw_product(
// module.n() as f64,
// basek,
// 0.5,
// var_msg,
// var_a0_err,
// var_a1_err,
// var_gct_err_lhs,
// var_gct_err_rhs,
// rank as f64,
// k_ggsw,
// k_ggsw,
// );
//
// assert!(
// (noise_have - noise_want).abs() <= 0.1,
// "have: {} want: {}",
// noise_have,
// noise_want
// );
//
// pt_want.data.zero();
// });
// });
// }
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
@@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize,
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
| GGSWCiphertext::external_product_scratch_space(
&module,

View File

@@ -1,6 +1,6 @@
use base2k::{
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
ZnxView, ZnxViewMut, ZnxZero,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use sampling::source::Source;
@@ -75,6 +75,22 @@ fn external_product_inplace() {
});
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
});
}
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
@@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
);
}
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
});
}
fn test_automorphism(
log_n: usize,
basek: usize,
@@ -515,14 +523,6 @@ fn test_automorphism(
);
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
});
}
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;

View File

@@ -3,3 +3,4 @@ mod gglwe;
mod ggsw;
mod glwe;
mod glwe_fourier;
mod tensor_key;

View File

@@ -0,0 +1,77 @@
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos},
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
tensor_key::TensorKey,
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(12, 16, 54, 3.2, rank);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
&module,
rank,
tensor_key.size(),
));
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
tensor_key.encrypt_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
(0..rank).for_each(|i| {
(0..rank).for_each(|j| {
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
let sk_ij: ScalarZnx<Vec<u8>> = module
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
.to_vec_znx_small()
.to_scalar_znx();
(0..tensor_key.rank_in()).for_each(|col_i| {
(0..tensor_key.rows()).for_each(|row_i| {
tensor_key
.at(i, j)
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
});
});
})
})
}