Updated arguments to get scratch space size for ops

This commit is contained in:
Jean-Philippe Bossuat
2025-05-28 18:46:24 +02:00
parent 8209fb4e40
commit f9440c5407
20 changed files with 599 additions and 529 deletions

View File

@@ -10,7 +10,6 @@ use std::collections::HashMap;
use backend::{Encoding, FFT64, Module, ScratchOwned, Stats};
use sampling::source::Source;
use std::time::Instant;
#[test]
fn packing() {
@@ -22,20 +21,18 @@ fn packing() {
let mut source_xa: Source = Source::new([0u8; 32]);
let basek: usize = 18;
let k_ct: usize = 36;
let k_auto_key: usize = k_ct + basek;
let k_pt: usize = 18;
let ct_k: usize = 36;
let atk_k: usize = ct_k + basek;
let pt_k: usize = 18;
let rank: usize = 3;
let rows: usize = (k_ct + basek - 1) / basek;
let rows: usize = (ct_k + basek - 1) / basek;
let sigma: f64 = 3.2;
let ct_size: usize = rows;
let auto_key_size: usize = (k_auto_key + basek - 1) / basek;
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertext::encrypt_sk_scratch_space(&module, ct_size)
| GLWECiphertext::decrypt_scratch_space(&module, ct_size)
| AutomorphismKey::generate_from_sk_scratch_space(&module, rank, auto_key_size)
| StreamPacker::scratch_space(&module, ct_size, auto_key_size, rank),
GLWECiphertext::encrypt_sk_scratch_space(&module, basek, ct_k)
| GLWECiphertext::decrypt_scratch_space(&module, basek, ct_k)
| AutomorphismKey::generate_from_sk_scratch_space(&module, basek, atk_k, rank)
| StreamPacker::scratch_space(&module, basek, ct_k, atk_k, rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
@@ -44,18 +41,18 @@ fn packing() {
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
sk_dft.dft(&module, &sk);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, ct_k);
let mut data: Vec<i64> = vec![0i64; module.n()];
data.iter_mut().enumerate().for_each(|(i, x)| {
*x = i as i64;
});
pt.data.encode_vec_i64(0, basek, k_pt, &data, 32);
pt.data.encode_vec_i64(0, basek, pt_k, &data, 32);
let gal_els: Vec<i64> = StreamPacker::galois_elements(&module);
let mut auto_keys: HashMap<i64, AutomorphismKey<Vec<u8>, FFT64>> = HashMap::new();
gal_els.iter().for_each(|gal_el| {
let mut key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::alloc(&module, basek, k_auto_key, rows, rank);
let mut key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::alloc(&module, basek, atk_k, rows, rank);
key.generate_from_sk(
&module,
*gal_el,
@@ -70,9 +67,9 @@ fn packing() {
let log_batch: usize = 0;
let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, k_ct, rank);
let mut packer: StreamPacker = StreamPacker::new(&module, log_batch, basek, ct_k, rank);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct, rank);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, ct_k, rank);
ct.encrypt_sk(
&module,
@@ -86,9 +83,7 @@ fn packing() {
let mut res: Vec<GLWECiphertext<Vec<u8>>> = Vec::new();
let start = Instant::now();
(0..module.n() >> log_batch).for_each(|i| {
println!("pt {}", pt.data);
ct.encrypt_sk(
&module,
&pt,
@@ -113,15 +108,11 @@ fn packing() {
)
}
});
let duration = start.elapsed();
println!("Elapsed time: {} ms", duration.as_millis());
packer.flush(&module, &mut res, &auto_keys, scratch.borrow());
packer.reset();
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
println!("{}", res.len());
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, ct_k);
res.iter().enumerate().for_each(|(i, res_i)| {
let mut data: Vec<i64> = vec![0i64; module.n()];
@@ -130,12 +121,10 @@ fn packing() {
*x = reverse_bits_msb(i, log_n as u32) as i64;
}
});
pt_want.data.encode_vec_i64(0, basek, k_pt, &data, 32);
pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32);
res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
println!("{}", pt.data);
if i & 1 == 0 {
pt.sub_inplace_ab(&module, &pt_want);
} else {
@@ -143,9 +132,9 @@ fn packing() {
}
let noise_have = pt.data.std(0, basek).log2();
println!("noise_have: {}", noise_have);
// println!("noise_have: {}", noise_have);
assert!(
noise_have < -((k_ct - basek) as f64),
noise_have < -((ct_k - basek) as f64),
"noise: {}",
noise_have
);