trace working

This commit is contained in:
Jean-Philippe Bossuat
2025-04-23 11:32:52 +02:00
parent 9695761ff1
commit 09981b78b5
11 changed files with 301 additions and 105 deletions

View File

@@ -10,7 +10,7 @@ use base2k::{
VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, assert_alignement,
};
use sampling::source::Source;
use std::cmp::min;
use std::{cmp::min, collections::HashMap};
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
pub struct AutomorphismKey {
@@ -33,6 +33,21 @@ impl Parameters {
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
pub fn automorphism_tmp_bytes(
&self,
res_logq: usize,
in_logq: usize,
gct_logq: usize,
) -> usize {
automorphism_tmp_bytes(
self.module(),
self.log_base2k(),
res_logq,
in_logq,
gct_logq,
)
}
}
impl AutomorphismKey {
@@ -48,34 +63,68 @@ impl AutomorphismKey {
sigma: f64,
tmp_bytes: &mut [u8],
) -> Self {
Self::new_many_core(module, &vec![p], sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes).into_iter().next().unwrap()
}
pub fn new_many(module: &Module, p: &Vec<i64>, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> HashMap<i64, AutomorphismKey>{
Self::new_many_core(
module,
p,
sk,
log_base2k,
rows,
log_q,
source_xa,
source_xe,
sigma,
tmp_bytes,
)
.into_iter()
.zip(p.iter().cloned())
.map(|(key, pi)| (pi, key))
.collect()
}
fn new_many_core(module: &Module, p: &Vec<i64>, sk: &SecretKey, log_base2k: usize, rows: usize, log_q: usize, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, tmp_bytes: &mut [u8]) -> Vec<Self>{
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
let p_inv: i64 = module.galois_element_inv(p);
let mut keys: Vec<AutomorphismKey> = Vec::new();
p.iter().for_each(|pi|{
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
let p_inv: i64 = module.galois_element_inv(*pi);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
Self { value: value, p: p }
keys.push(Self { value: value, p: *pi })
});
keys
}
}
pub fn automorphism_tmp_bytes(
module: &Module,
c_cols: usize,
a_cols: usize,
b_rows: usize,
b_cols: usize,
log_base2k: usize,
res_logq: usize,
in_logq: usize,
gct_logq: usize,
) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
+ module.bytes_of_vec_znx_dft(gct_cols);
}
pub fn automorphism(
@@ -83,12 +132,14 @@ pub fn automorphism(
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(min(c.cols(), a.cols()), b.value.rows());
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.value.cols());
assert!(
tmp_bytes.len()
>= automorphism_tmp_bytes(
@@ -102,11 +153,13 @@ pub fn automorphism(
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) =
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
let mut res_dft: VecZnxDft =
module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// a1_dft = DFT(a[1])
@@ -151,12 +204,14 @@ pub fn automorphism_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(a.cols(), b.value.rows());
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.value.cols());
assert!(
tmp_bytes.len()
>= automorphism_inplace_tmp_bytes(
@@ -174,7 +229,8 @@ pub fn automorphism_inplace(
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut res_dft: VecZnxDft =
module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// a1_dft = DFT(a[1])
@@ -197,6 +253,11 @@ pub fn automorphism_inplace(
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
(0..b_cols).for_each(|col_i| {
let raw: &[i64] = res_big.raw::<i64>(module);
println!("{:?}", &raw[col_i * module.n()..(col_i + 1) * module.n()])
});
// a[1] = b
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
@@ -257,28 +318,33 @@ pub fn automorphism_big(
#[cfg(test)]
mod test {
use super::{AutomorphismKey, automorphism};
use crate::{
ciphertext::{new_gadget_ciphertext, Ciphertext}, decryptor::decrypt_rlwe, elem::{Elem, ElemCommon, ElemVecZnx}, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext
ciphertext::Ciphertext,
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::encrypt_rlwe_sk,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{
alloc_aligned, Encoding, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, BACKEND
BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned,
};
use sampling::source::{Source, new_seed};
use super::{automorphis_key_new_tmp_bytes, automorphism, AutomorphismKey};
#[test]
fn test_automorphism() {
let log_base2k: usize = 10;
let q_cols: usize = 4;
let p_cols: usize = 1;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: q_cols * log_base2k,
log_p: p_cols * log_base2k,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
@@ -287,23 +353,18 @@ mod test {
let params: Parameters = Parameters::new(&params_lit);
let module: &base2k::Module = params.module();
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let rows: usize = params.cols_q();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.gadget_product_tmp_bytes(
log_qp,
log_qp,
rows,
log_qp,
)
| params.encrypt_grlwe_sk_tmp_bytes(rows, log_qp)
| params.automorphism_key_new_tmp_bytes(rows, log_qp),
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
);
// Samplers for public and private randomness
@@ -311,10 +372,9 @@ mod test {
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
// Two secret keys
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: base2k::SvpPPol = module.new_svp_ppol();
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let p: i64 = -5;
@@ -324,7 +384,7 @@ mod test {
p,
&sk,
log_base2k,
rows,
gct_rows,
log_qp,
&mut source_xa,
&mut source_xe,
@@ -334,50 +394,64 @@ mod test {
let mut data: Vec<i64> = vec![0i64; params.n()];
data.iter_mut().enumerate().for_each(|(i, x)|{
*x = i as i64
});
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let log_k: usize = 2*log_base2k;
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = Ciphertext::new(module, log_base2k, log_q, 2);
let mut pt: Plaintext = Plaintext::new(module, log_base2k, log_q);
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
let mut pt_auto: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32);
module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
encrypt_rlwe_sk(module, &mut ct.elem_mut(), Some(pt.elem()), &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes);
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.elem()),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
module.vec_znx_automorphism_inplace(p, pt.at_mut(0));
let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut ct_auto: Ciphertext<VecZnx> = Ciphertext::new(module, log_base2k, log_q, 2);
// ct <- AUTO(ct)
automorphism(
module,
&mut ct_auto,
&ct,
&auto_key,
gct_cols,
&mut tmp_bytes,
);
automorphism(module, &mut ct_auto, &ct, &auto_key, &mut tmp_bytes);
// pt = dec(auto(ct)) - auto(pt)
decrypt_rlwe(
module,
pt.elem_mut(),
ct_auto.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
module.vec_znx_sub_inplace(ct_auto.at_mut(0), pt.at(0));
ct_auto.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
//pt.at(0).print(pt.cols(), 16);
decrypt_rlwe(module, pt.elem_mut(), ct_auto.elem(), &sk_svp_ppol, &mut tmp_bytes);
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
let var_a_err: f64;
if ct_auto.cols() < ct.cols() {
var_a_err = 1f64 / 12f64;
} else {
var_a_err = 0f64;
}
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
let var_a_err: f64 = 1f64 / 12f64;
let noise_pred: f64 =
params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
println!("noise_pred: {}", noise_have);
println!("noise_have: {}", noise_pred);
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
assert!(noise_have <= noise_pred + 1.0);
}
}