mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added basic key-switching + file formatting
This commit is contained in:
@@ -2,12 +2,13 @@ use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes},
|
||||
keys::SecretKey,
|
||||
parameters::Parameters,
|
||||
};
|
||||
use base2k::{
|
||||
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, assert_alignement,
|
||||
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
VmpPMatOps, assert_alignement,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
use std::{cmp::min, collections::HashMap};
|
||||
@@ -18,15 +19,8 @@ pub struct AutomorphismKey {
|
||||
pub p: i64,
|
||||
}
|
||||
|
||||
pub fn automorphis_key_new_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_scalar()
|
||||
+ module.bytes_of_svp_ppol()
|
||||
+ encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
@@ -34,12 +28,7 @@ impl Parameters {
|
||||
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
|
||||
pub fn automorphism_tmp_bytes(
|
||||
&self,
|
||||
res_logq: usize,
|
||||
in_logq: usize,
|
||||
gct_logq: usize,
|
||||
) -> usize {
|
||||
pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
automorphism_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
@@ -122,8 +111,7 @@ impl AutomorphismKey {
|
||||
let mut keys: Vec<AutomorphismKey> = Vec::new();
|
||||
|
||||
p.iter().for_each(|pi| {
|
||||
let mut value: Ciphertext<VmpPMat> =
|
||||
new_gadget_ciphertext(module, log_base2k, rows, log_q);
|
||||
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
|
||||
|
||||
let p_inv: i64 = module.galois_element_inv(*pi);
|
||||
|
||||
@@ -143,19 +131,8 @@ impl AutomorphismKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn automorphism_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
res_logq: usize,
|
||||
in_logq: usize,
|
||||
gct_logq: usize,
|
||||
) -> usize {
|
||||
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
|
||||
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
|
||||
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
|
||||
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
|
||||
+ module.bytes_of_vec_znx_dft(gct_cols);
|
||||
pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
|
||||
}
|
||||
|
||||
pub fn automorphism(
|
||||
@@ -166,67 +143,14 @@ pub fn automorphism(
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(min(c.cols(), a.cols()), b.value.rows());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(b_cols <= b.value.cols());
|
||||
assert!(
|
||||
tmp_bytes.len()
|
||||
>= automorphism_tmp_bytes(
|
||||
module,
|
||||
c.cols(),
|
||||
a.cols(),
|
||||
b.value.rows(),
|
||||
b.value.cols()
|
||||
)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) =
|
||||
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
|
||||
let mut res_dft: VecZnxDft =
|
||||
module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
|
||||
// c[0] = NORMALIZE([-b*AUTO(s, -p) + m + e])
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
|
||||
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
// c[1] = b
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
|
||||
|
||||
// c[1] = AUTO(b, p)
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_tmp_bytes(
|
||||
module: &Module,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
b_rows: usize,
|
||||
b_cols: usize,
|
||||
) -> usize {
|
||||
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
|
||||
}
|
||||
@@ -238,60 +162,9 @@ pub fn automorphism_inplace(
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(a.cols(), b.value.rows());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(b_cols <= b.value.cols());
|
||||
assert!(
|
||||
tmp_bytes.len()
|
||||
>= automorphism_inplace_tmp_bytes(
|
||||
module,
|
||||
a.cols(),
|
||||
a.cols(),
|
||||
b.value.rows(),
|
||||
b.value.cols()
|
||||
)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft =
|
||||
module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
|
||||
// a[0] = NORMALIZE([-b*AUTO(s, -p) + m + e])
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
|
||||
// a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
(0..b_cols).for_each(|col_i| {
|
||||
let raw: &[i64] = res_big.raw::<i64>(module);
|
||||
println!("{:?}", &raw[col_i * module.n()..(col_i + 1) * module.n()])
|
||||
});
|
||||
|
||||
// a[1] = b
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
|
||||
|
||||
// a[1] = AUTO(b, p)
|
||||
module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
|
||||
}
|
||||
@@ -307,16 +180,7 @@ pub fn automorphism_big(
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len()
|
||||
>= automorphism_tmp_bytes(
|
||||
module,
|
||||
c.cols(),
|
||||
a.cols(),
|
||||
b.value.rows(),
|
||||
b.value.cols()
|
||||
)
|
||||
);
|
||||
assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
@@ -359,9 +223,7 @@ mod test {
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{
|
||||
BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned,
|
||||
};
|
||||
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
@@ -470,15 +332,14 @@ mod test {
|
||||
|
||||
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
|
||||
|
||||
//pt.at(0).print(pt.cols(), 16);
|
||||
// pt.at(0).print(pt.cols(), 16);
|
||||
|
||||
let noise_have: f64 = pt.at(0).std(log_base2k).log2();
|
||||
|
||||
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
|
||||
let var_a_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_pred: f64 =
|
||||
params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
|
||||
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
|
||||
|
||||
println!("noise_pred: {}", noise_pred);
|
||||
println!("noise_have: {}", noise_have);
|
||||
|
||||
Reference in New Issue
Block a user