mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added tensor key & associated test
This commit is contained in:
@@ -18,6 +18,14 @@ fn automorphism() {
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
@@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
|
||||
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
auto_key.encrypt_sk(
|
||||
&module,
|
||||
p0,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s2}(s1) -> s1 -> s2
|
||||
auto_key_apply.encrypt_sk(
|
||||
&module,
|
||||
p1,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
|
||||
(0..rank).for_each(|i| {
|
||||
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
|
||||
});
|
||||
|
||||
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_auto_dft.dft(&module, &sk_auto);
|
||||
|
||||
(0..auto_key.rank_in()).for_each(|col_i| {
|
||||
(0..auto_key.rows()).for_each(|row_i| {
|
||||
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
let noise_want: f64 = noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ksk,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use base2k::{
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::{GetRow, Infos},
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
@@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank:
|
||||
});
|
||||
}
|
||||
|
||||
// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||
// let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
// let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||
//
|
||||
// let mut ct_ggsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
// let mut ct_ggsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
// let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
|
||||
//
|
||||
// let mut pt_ggsw_in: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
// let mut pt_ggsw_out: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
//
|
||||
// let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
// let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
// let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
//
|
||||
// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||
//
|
||||
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size())
|
||||
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size())
|
||||
// | GGSWCiphertext::automorphism_scratch_space(
|
||||
// &module,
|
||||
// ct_ggsw_out.size(),
|
||||
// ct_ggsw_in.size(),
|
||||
// auto_key.size(),
|
||||
// rank,
|
||||
// ),
|
||||
// );
|
||||
//
|
||||
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
// sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
//
|
||||
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
// sk_dft.dft(&module, &sk);
|
||||
//
|
||||
// ct_ggsw_in.encrypt_sk(
|
||||
// &module,
|
||||
// &pt_ggsw_in,
|
||||
// &sk_dft,
|
||||
// &mut source_xa,
|
||||
// &mut source_xe,
|
||||
// sigma,
|
||||
// scratch.borrow(),
|
||||
// );
|
||||
//
|
||||
// auto_key.encrypt_sk(
|
||||
// &module,
|
||||
// p,
|
||||
// &sk,
|
||||
// &mut source_xa,
|
||||
// &mut source_xe,
|
||||
// sigma,
|
||||
// scratch.borrow(),
|
||||
// );
|
||||
//
|
||||
// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow());
|
||||
//
|
||||
// let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
|
||||
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
|
||||
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
//
|
||||
// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||
//
|
||||
// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
|
||||
// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
|
||||
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||
//
|
||||
// if col_j > 0 {
|
||||
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
// }
|
||||
//
|
||||
// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
//
|
||||
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||
//
|
||||
// let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
//
|
||||
// let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
// let var_gct_err_rhs: f64 = 0f64;
|
||||
//
|
||||
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
// let var_a0_err: f64 = sigma * sigma;
|
||||
// let var_a1_err: f64 = 1f64 / 12f64;
|
||||
//
|
||||
// let noise_want: f64 = noise_ggsw_product(
|
||||
// module.n() as f64,
|
||||
// basek,
|
||||
// 0.5,
|
||||
// var_msg,
|
||||
// var_a0_err,
|
||||
// var_a1_err,
|
||||
// var_gct_err_lhs,
|
||||
// var_gct_err_rhs,
|
||||
// rank as f64,
|
||||
// k_ggsw,
|
||||
// k_ggsw,
|
||||
// );
|
||||
//
|
||||
// assert!(
|
||||
// (noise_have - noise_want).abs() <= 0.1,
|
||||
// "have: {} want: {}",
|
||||
// noise_have,
|
||||
// noise_want
|
||||
// );
|
||||
//
|
||||
// pt_want.data.zero();
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
@@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize,
|
||||
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
||||
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
|
||||
| GGSWCiphertext::external_product_scratch_space(
|
||||
&module,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use base2k::{
|
||||
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
|
||||
ZnxView, ZnxViewMut, ZnxZero,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
@@ -75,6 +75,22 @@ fn external_product_inplace() {
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism rank: {}", rank);
|
||||
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
@@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism rank: {}", rank);
|
||||
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism(
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
@@ -515,14 +523,6 @@ fn test_automorphism(
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
@@ -3,3 +3,4 @@ mod gglwe;
|
||||
mod ggsw;
|
||||
mod glwe;
|
||||
mod glwe_fourier;
|
||||
mod tensor_key;
|
||||
|
||||
77
core/src/test_fft64/tensor_key.rs
Normal file
77
core/src/test_fft64/tensor_key.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos},
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
tensor_key::TensorKey,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_sk rank: {}", rank);
|
||||
test_encrypt_sk(12, 16, 54, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
|
||||
&module,
|
||||
rank,
|
||||
tensor_key.size(),
|
||||
));
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
tensor_key.encrypt_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
|
||||
(0..rank).for_each(|i| {
|
||||
(0..rank).for_each(|j| {
|
||||
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
|
||||
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
|
||||
let sk_ij: ScalarZnx<Vec<u8>> = module
|
||||
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
|
||||
.to_vec_znx_small()
|
||||
.to_scalar_znx();
|
||||
|
||||
(0..tensor_key.rank_in()).for_each(|col_i| {
|
||||
(0..tensor_key.rows()).for_each(|row_i| {
|
||||
tensor_key
|
||||
.at(i, j)
|
||||
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
|
||||
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user