Updated arguments to get scratch space size for ops

This commit is contained in:
Jean-Philippe Bossuat
2025-05-28 18:46:24 +02:00
parent 8209fb4e40
commit f9440c5407
20 changed files with 599 additions and 529 deletions

View File

@@ -1,4 +1,5 @@
use crate::{
GLWEOps,
elem::Infos,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
@@ -79,16 +80,17 @@ fn test_keyswitch(
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe_in.size())
GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, k_ksk, rank_out)
| GLWECiphertext::decrypt_scratch_space(&module, basek, k_ct_out)
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, k_ct_in)
| GLWECiphertextFourier::keyswitch_scratch_space(
&module,
ct_glwe_out.size(),
basek,
ct_glwe_out.k(),
rank_out,
ct_glwe_in.size(),
ct_glwe_in.k(),
rank_in,
ksk.size(),
ksk.k(),
),
);
@@ -174,10 +176,10 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe.size())
| GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ksk.size(), rank),
GLWESwitchingKey::encrypt_sk_scratch_space(&module, basek, ksk.k(), rank)
| GLWECiphertext::decrypt_scratch_space(&module, basek, ct_glwe.k())
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_glwe.k())
| GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, basek, ct_rlwe_dft.k(), ksk.k(), rank),
);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
@@ -247,7 +249,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::alloc(&module, basek, k_ggsw, rows, rank);
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank);
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank);
let mut ct_in_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank);
@@ -267,15 +269,16 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
pt_want.data.at_mut(0, 0)[1] = 1;
let k: usize = 1;
let k: i64 = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
pt_rgsw.raw_mut()[0] = 1; // X^{0}
module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
| GLWECiphertextFourier::external_product_scratch_space(&module, ct_out.size(), ct_in.size(), ct_rgsw.size(), rank),
GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank)
| GLWECiphertext::decrypt_scratch_space(&module, basek, ct_out.k())
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_in.k())
| GLWECiphertextFourier::external_product_scratch_space(&module, basek, ct_out.k(), ct_in.k(), ct_ggsw.k(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
@@ -284,7 +287,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
ct_ggsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
@@ -305,14 +308,13 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
);
ct_in.dft(&module, &mut ct_in_dft);
ct_out_dft.external_product(&module, &ct_in_dft, &ct_rgsw, scratch.borrow());
ct_out_dft.external_product(&module, &ct_in_dft, &ct_ggsw, scratch.borrow());
ct_out_dft.idft(&module, &mut ct_out, scratch.borrow());
ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
pt_want.rotate_inplace(&module, k);
pt_have.sub_inplace_ab(&module, &pt_want);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
@@ -367,15 +369,16 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
pt_want.data.at_mut(0, 0)[1] = 1;
let k: usize = 1;
let k: i64 = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
pt_rgsw.raw_mut()[0] = 1; // X^{0}
module.vec_znx_rotate_inplace(k, &mut pt_rgsw, 0); // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertextFourier::external_product_inplace_scratch_space(&module, ct.size(), ct_ggsw.size(), rank),
GGSWCiphertext::encrypt_sk_scratch_space(&module, basek, ct_ggsw.k(), rank)
| GLWECiphertext::decrypt_scratch_space(&module, basek, ct.k())
| GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct.k())
| GLWECiphertextFourier::external_product_inplace_scratch_space(&module, basek, ct.k(), ct_ggsw.k(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
@@ -410,9 +413,8 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
pt_want.rotate_inplace(&module, k);
pt_have.sub_inplace_ab(&module, &pt_want);
let noise_have: f64 = pt_have.data.std(0, basek).log2();