mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
This commit is contained in:
272
rlwe/src/automorphism.rs
Normal file
272
rlwe/src/automorphism.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
keys::SecretKey,
|
||||
parameters::Parameters,
|
||||
};
|
||||
use base2k::{
|
||||
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, assert_alignement,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
|
||||
pub struct AutomorphismKey {
|
||||
pub value: Ciphertext<VmpPMat>,
|
||||
pub p: i64,
|
||||
}
|
||||
|
||||
pub fn automorphis_key_new_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_scalar()
|
||||
+ module.bytes_of_svp_ppol()
|
||||
+ encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
}
|
||||
|
||||
impl AutomorphismKey {
|
||||
pub fn new(
|
||||
module: &Module,
|
||||
p: i64,
|
||||
sk: &SecretKey,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> Self {
|
||||
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
|
||||
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
|
||||
|
||||
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
|
||||
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
|
||||
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
|
||||
|
||||
let p_inv: i64 = module.galois_element_inv(p);
|
||||
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
|
||||
module.svp_prepare(&mut sk_out, &sk_auto);
|
||||
encrypt_grlwe_sk(
|
||||
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||
);
|
||||
|
||||
Self { value: value, p: p }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn automorphism_tmp_bytes(
|
||||
module: &Module,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
b_rows: usize,
|
||||
b_cols: usize,
|
||||
) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
pub fn automorphism(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &AutomorphismKey,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols = std::cmp::min(c.cols(), a.cols());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len()
|
||||
>= automorphism_tmp_bytes(
|
||||
module,
|
||||
c.cols(),
|
||||
a.cols(),
|
||||
b.value.rows(),
|
||||
b.value.cols()
|
||||
)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
|
||||
// c[0] = NORMALIZE([-b*AUTO(s, -p) + m + e])
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
// c[1] = b
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
|
||||
|
||||
// c[1] = AUTO(b, p)
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
|
||||
}
|
||||
|
||||
pub fn automorphism_big(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnxBig>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &AutomorphismKey,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols = std::cmp::min(c.cols(), a.cols());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len()
|
||||
>= automorphism_tmp_bytes(
|
||||
module,
|
||||
c.cols(),
|
||||
a.cols(),
|
||||
b.value.rows(),
|
||||
b.value.cols()
|
||||
)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
|
||||
|
||||
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||
module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
|
||||
|
||||
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
|
||||
|
||||
// c[1] = AUTO(b, p)
|
||||
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
ciphertext::{new_gadget_ciphertext, Ciphertext}, decryptor::decrypt_rlwe, elem::{Elem, ElemCommon, ElemVecZnx}, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral}, plaintext::Plaintext
|
||||
};
|
||||
use base2k::{
|
||||
alloc_aligned, Encoding, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, BACKEND
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
use super::{AutomorphismKey, automorphis_key_new_tmp_bytes};
|
||||
|
||||
#[test]
|
||||
fn test_automorphism() {
|
||||
let log_base2k: usize = 10;
|
||||
let q_cols: usize = 4;
|
||||
let p_cols: usize = 1;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: 12,
|
||||
log_q: q_cols * log_base2k,
|
||||
log_p: p_cols * log_base2k,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let rows: usize = params.cols_q();
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||
params.decrypt_rlwe_tmp_byte(params.log_q())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
| params.gadget_product_tmp_bytes(
|
||||
params.log_qp(),
|
||||
params.log_qp(),
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
)
|
||||
| params.encrypt_grlwe_sk_tmp_bytes(rows, params.log_qp())
|
||||
| params.automorphism_key_new_tmp_bytes(rows, params.log_qp()),
|
||||
);
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
// Two secret keys
|
||||
let mut sk: SecretKey = SecretKey::new(params.module());
|
||||
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
let auto_key: AutomorphismKey = AutomorphismKey::new(
|
||||
params.module(),
|
||||
p,
|
||||
&sk,
|
||||
params.log_base2k(),
|
||||
rows,
|
||||
params.log_qp(),
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let data: Vec<i64> = vec![0i64; params.n()];
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = Ciphertext::new(params.module(), params.log_base2k(), params.log_q(), 2);
|
||||
let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_q());
|
||||
|
||||
pt.at_mut(0).encode_vec_i64(params.log_base2k(), 2*params.log_base2k(), &data, 32);
|
||||
|
||||
encrypt_rlwe_sk(params.module(), &mut ct.elem_mut(), Some(&pt.elem()), &sk_svp_ppol, &mut source_xa, &mut source_xe, params.xe(), &mut tmp_bytes);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -62,14 +62,13 @@ pub fn decrypt_rlwe(
|
||||
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) =
|
||||
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut res_dft: VecZnxDft =
|
||||
VecZnxDft::from_bytes_borrow(module, a.cols(), tmp_bytes_vec_znx_dft);
|
||||
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
|
||||
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// res_dft <- DFT(ct[1]) * DFT(sk)
|
||||
module.svp_apply_dft(&mut res_dft, sk, a.at(1), cols);
|
||||
module.svp_apply_dft(&mut res_dft, sk, a.at(1));
|
||||
// res_big <- ct[1] x sk
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
// res_big <- ct[1] x sk + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
// res <- normalize(ct[1] x sk + ct[0])
|
||||
|
||||
@@ -153,13 +153,13 @@ pub fn encrypt_rlwe_sk(
|
||||
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
|
||||
|
||||
// Applies buf_dft <- DFT(s) * DFT(c1)
|
||||
module.svp_apply_dft(&mut buf_dft, sk, c1, cols);
|
||||
module.svp_apply_dft(&mut buf_dft, sk, c1);
|
||||
|
||||
// Alias scratch space
|
||||
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
|
||||
|
||||
// buf_big = s x c1
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, cols);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
|
||||
// c0 <- -s x c1 + m
|
||||
let c0: &mut VecZnx = ct.at_mut(0);
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||
use base2k::{Module, VecZnx, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
|
||||
use base2k::{
|
||||
Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps,
|
||||
};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn gadget_product_tmp_bytes(
|
||||
pub fn gadget_product_core_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
res_log_q: usize,
|
||||
@@ -24,7 +26,7 @@ impl Parameters {
|
||||
gct_rows: usize,
|
||||
gct_log_q: usize,
|
||||
) -> usize {
|
||||
gadget_product_tmp_bytes(
|
||||
gadget_product_core_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
res_log_q,
|
||||
@@ -35,54 +37,99 @@ impl Parameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product res <- a x b.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `res`: an [Elem] to store (-cs + m * a + e, c) with res_ncols cols.
|
||||
/// * `a`: a [VecZnx] of a_ncols cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] as a vector of (-Bs + m * 2^{-k} + E, B)
|
||||
/// containing b_nrows [VecZnx], each of b_ncols cols.
|
||||
///
|
||||
/// # Computation
|
||||
///
|
||||
/// res = sum[min(a_ncols, b_nrows)] decomp(a, i) * (-B[i]s + m * 2^{-k*i} + E[i], B[i])
|
||||
/// = (cs + m * a + e, c) with min(res_cols, b_cols) cols.
|
||||
pub fn gadget_product_core(
|
||||
module: &Module,
|
||||
res_dft_0: &mut VecZnxDft,
|
||||
res_dft_1: &mut VecZnxDft,
|
||||
a: &VecZnx,
|
||||
a_cols: usize,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(b_cols <= b.cols());
|
||||
module.vec_znx_dft(res_dft_1, a, min(a_cols, b_cols));
|
||||
module.vec_znx_dft(res_dft_1, a);
|
||||
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
|
||||
}
|
||||
|
||||
/*
|
||||
// res_big[a * (G0|G1)] <- IDFT(res_dft[a * (G0|G1)])
|
||||
module.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
||||
|
||||
// res_big <- res[0] + res_big[a*G0]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big_0, res.at(0));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big_0, tmp_bytes_carry);
|
||||
|
||||
if OVERWRITE {
|
||||
// res[1] = normalize(res_big[a*G1])
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
||||
} else {
|
||||
// res[1] = normalize(res_big[a*G1] + res[1])
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big_1, res.at(1));
|
||||
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big_1, tmp_bytes_carry);
|
||||
pub fn gadget_product_big_tmp_bytes(
|
||||
module: &Module,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
b_rows: usize,
|
||||
b_cols: usize,
|
||||
) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `c`: a [Ciphertext<VecZnxBig>] with cols_c cols.
|
||||
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||
pub fn gadget_product_big(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnxBig>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(c.cols(), a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||
(0..2).for_each(|i| {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `c`: a [Ciphertext<VecZnx>] with cols_c cols.
|
||||
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||
pub fn gadget_product(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(c.cols(), a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||
(0..2).for_each(|i| {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
|
||||
})
|
||||
}
|
||||
*/
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
@@ -97,7 +144,7 @@ mod test {
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{
|
||||
Infos, BACKEND, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
@@ -125,7 +172,6 @@ mod test {
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||
params.decrypt_rlwe_tmp_byte(params.log_qp())
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(params.log_qp())
|
||||
| params.gadget_product_tmp_bytes(
|
||||
params.log_qp(),
|
||||
params.log_qp(),
|
||||
@@ -193,12 +239,8 @@ mod test {
|
||||
let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols());
|
||||
|
||||
// a * sk0
|
||||
params
|
||||
.module()
|
||||
.svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a, a.cols());
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut a_big, &mut a_dft, a.cols());
|
||||
params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
|
||||
params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
|
||||
params.module().vec_znx_big_normalize(
|
||||
params.log_base2k(),
|
||||
&mut a_times_s,
|
||||
@@ -228,7 +270,6 @@ mod test {
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
&a,
|
||||
a_cols,
|
||||
&gadget_ct,
|
||||
b_cols,
|
||||
&mut tmp_bytes,
|
||||
@@ -237,11 +278,11 @@ mod test {
|
||||
// res_big_0 = IDFT(res_dft_0)
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0, b_cols);
|
||||
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
|
||||
// res_big_1 = IDFT(res_dft_1);
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1, b_cols);
|
||||
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
|
||||
|
||||
// res_big_0 = normalize(res_big_0)
|
||||
params.module().vec_znx_big_normalize(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod automorphism;
|
||||
pub mod ciphertext;
|
||||
pub mod decryptor;
|
||||
pub mod elem;
|
||||
@@ -8,3 +9,4 @@ pub mod keys;
|
||||
pub mod parameters;
|
||||
pub mod plaintext;
|
||||
pub mod rgsw_product;
|
||||
pub mod trace;
|
||||
|
||||
@@ -39,7 +39,7 @@ pub fn rgsw_product(
|
||||
let mut _r2_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes(cols, tmp_bytes_r2_dft);
|
||||
|
||||
// c0_dft <- DFT(a[0])
|
||||
module.vec_znx_dft(&mut c0_dft, a.at(0), in_cols);
|
||||
module.vec_znx_dft(&mut c0_dft, a.at(0));
|
||||
|
||||
// r_dft <- sum[rows] c0_dft[cols] x RGSW[0][cols]
|
||||
module.vmp_apply_dft_to_dft(
|
||||
@@ -50,5 +50,5 @@ pub fn rgsw_product(
|
||||
);
|
||||
|
||||
// c1_dft <- DFT(a[1])
|
||||
module.vec_znx_dft(&mut c1_dft, a.at(1), in_cols);
|
||||
module.vec_znx_dft(&mut c1_dft, a.at(1));
|
||||
}
|
||||
|
||||
112
rlwe/src/trace.rs
Normal file
112
rlwe/src/trace.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon};
|
||||
use base2k::{
|
||||
Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMatOps,
|
||||
assert_alignement,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub fn trace_tmp_bytes(
|
||||
module: &Module,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
b_rows: usize,
|
||||
b_cols: usize,
|
||||
) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
pub fn trace_inplace(
|
||||
module: &Module,
|
||||
a: &mut Ciphertext<VecZnx>,
|
||||
start: usize,
|
||||
end: usize,
|
||||
b: HashMap<i64, AutomorphismKey>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = a.cols();
|
||||
|
||||
let b_rows: usize;
|
||||
let b_cols: usize;
|
||||
|
||||
if let Some((_, key)) = b.iter().next() {
|
||||
b_rows = key.value.rows();
|
||||
b_cols = key.value.cols();
|
||||
} else {
|
||||
panic!("b: HashMap<i64, AutomorphismKey>, is empty")
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(start <= end);
|
||||
assert!(end <= module.n());
|
||||
assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let cols: usize = std::cmp::min(b_cols, a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
let log_base2k: usize = a.log_base2k();
|
||||
|
||||
(start..end).for_each(|i| {
|
||||
a.at_mut(0).rsh(log_base2k, 1, tmp_bytes);
|
||||
a.at_mut(1).rsh(log_base2k, 1, tmp_bytes);
|
||||
|
||||
let p: i64;
|
||||
if i == 0 {
|
||||
p = -1;
|
||||
} else {
|
||||
p = module.galois_element(1 << (i - 1));
|
||||
}
|
||||
|
||||
if let Some(key) = b.get(&p) {
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(<DFT(a[1]), key[0]>)))
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
// a[1] = NORMALIZE(a[1] + AUTO(IDFT(<DFT(a[1]), key[1]>)))
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1));
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
|
||||
} else {
|
||||
panic!("b[{}] is empty", p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::{Elem, ElemCommon, ElemVecZnx},
|
||||
encryptor::encrypt_grlwe_sk,
|
||||
gadget_product::gadget_product_core,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{
|
||||
BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8,
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_trace_inplace() {}
|
||||
}
|
||||
Reference in New Issue
Block a user