updated ggsw product noise prediction & added test for ggsw x glwe of rank > 1

This commit is contained in:
Jean-Philippe Bossuat
2025-05-14 16:57:57 +02:00
parent f517a730a3
commit 4c55a7df44
2 changed files with 62 additions and 57 deletions

View File

@@ -576,6 +576,7 @@ pub(crate) fn noise_rgsw_product(
var_a1_err: f64, var_a1_err: f64,
var_gct_err_lhs: f64, var_gct_err_lhs: f64,
var_gct_err_rhs: f64, var_gct_err_rhs: f64,
rank: f64,
a_logq: usize, a_logq: usize,
b_logq: usize, b_logq: usize,
) -> f64 { ) -> f64 {
@@ -590,9 +591,9 @@ pub(crate) fn noise_rgsw_product(
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2) // lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs // rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a0_err * a_scale * a_scale * n; noise += var_msg * var_a0_err * a_scale * a_scale * n;
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs; noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank;
noise = noise.sqrt(); noise = noise.sqrt();
noise /= b_scale; noise /= b_scale;
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}] noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]

View File

@@ -40,6 +40,24 @@ fn encrypt_pk() {
}); });
} }
#[test]
fn keyswitch() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
});
});
}
#[test]
fn keyswitch_inplace() {
(1..4).for_each(|rank| {
println!("test keyswitch_inplace rank: {}", rank);
test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) { fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n); let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
@@ -195,16 +213,6 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma:
); );
} }
#[test]
fn keyswitch() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
});
});
}
fn test_keyswitch( fn test_keyswitch(
log_n: usize, log_n: usize,
basek: usize, basek: usize,
@@ -307,21 +315,14 @@ fn test_keyswitch(
); );
} }
#[test] fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) {
fn keyswich_inplace() { let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let module: Module<FFT64> = Module::<FFT64>::new(2048); let rows: usize = (k_ct + basek - 1) / basek;
let basek: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe: usize = 45;
let rows: usize = (log_k_rlwe + basek - 1) / basek;
let rank: usize = 1;
let sigma: f64 = 3.2; let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, log_k_grlwe, rows, rank, rank); let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, log_k_rlwe, rank); let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
@@ -387,8 +388,8 @@ fn keyswich_inplace() {
sigma * sigma, sigma * sigma,
0f64, 0f64,
rank as f64, rank as f64,
log_k_rlwe, k_ct,
log_k_grlwe, k_ksk,
); );
assert!( assert!(
@@ -401,22 +402,23 @@ fn keyswich_inplace() {
#[test] #[test]
fn external_product() { fn external_product() {
let module: Module<FFT64> = Module::<FFT64>::new(2048); (1..4).for_each(|rank| {
let basek: usize = 12; println!("test external_product rank: {}", rank);
let log_k_grlwe: usize = 60; test_external_product(12, 12, 60, 45, 60, rank, 3.2);
let log_k_rlwe_in: usize = 45; });
let log_k_rlwe_out: usize = 60; }
let rows: usize = (log_k_rlwe_in + basek - 1) / basek;
let rank: usize = 1;
let sigma: f64 = 3.2; fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank);
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, log_k_rlwe_out, rank); let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1); let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
@@ -498,8 +500,9 @@ fn external_product() {
var_a1_err, var_a1_err,
var_gct_err_lhs, var_gct_err_lhs,
var_gct_err_rhs, var_gct_err_rhs,
log_k_rlwe_in, rank as f64,
log_k_grlwe, k_ct_in,
k_ggsw,
); );
assert!( assert!(
@@ -512,21 +515,21 @@ fn external_product() {
#[test] #[test]
fn external_product_inplace() { fn external_product_inplace() {
let module: Module<FFT64> = Module::<FFT64>::new(2048); (1..4).for_each(|rank| {
let basek: usize = 12; println!("test external_product rank: {}", rank);
let log_k_grlwe: usize = 60; test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
let log_k_rlwe_in: usize = 45; });
let log_k_rlwe_out: usize = 60; }
let rows: usize = (log_k_rlwe_in + basek - 1) / basek;
let rank: usize = 1;
let sigma: f64 = 3.2; fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, log_k_grlwe, rows, rank); let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, log_k_rlwe_in, rank); let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1); let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe_in); let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, log_k_rlwe_out); let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
@@ -602,8 +605,9 @@ fn external_product_inplace() {
var_a1_err, var_a1_err,
var_gct_err_lhs, var_gct_err_lhs,
var_gct_err_rhs, var_gct_err_rhs,
log_k_rlwe_in, rank as f64,
log_k_grlwe, k_ct,
k_ggsw,
); );
assert!( assert!(