Added size and memory layout to VecZnxBig, VecZnxDft and VmpPmat

This commit is contained in:
Jean-Philippe Bossuat
2025-04-25 09:19:47 +02:00
parent f0eaddb63e
commit 3bdddd3857
22 changed files with 195 additions and 119 deletions

View File

@@ -104,8 +104,8 @@ fn bench_gadget_product_inplace(c: &mut Criterion) {
&mut tmp_bytes,
);
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols());
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
let mut a: VecZnx = params.module().new_vec_znx(0, params.cols_q());
params

View File

@@ -152,7 +152,7 @@ pub fn automorphism(
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
}
pub fn automorphism_inplace(
@@ -184,11 +184,11 @@ pub fn automorphism_big(
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));

View File

@@ -1,6 +1,6 @@
use crate::elem::{Elem, ElemCommon};
use crate::parameters::Parameters;
use base2k::{Infos, Module, VecZnx, VmpPMat};
use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat};
pub struct Ciphertext<T>(pub Elem<T>);
@@ -38,6 +38,10 @@ where
self.elem().size()
}
fn layout(&self) -> LAYOUT {
self.elem().layout()
}
fn rows(&self) -> usize {
self.elem().rows()
}

View File

@@ -20,8 +20,8 @@ impl Decryptor {
}
}
pub fn decrypt_rlwe_tmp_byte(module: &Module, limbs: usize) -> usize {
module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes()
pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize {
module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes()
}
impl Parameters {
@@ -47,9 +47,9 @@ pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, s
decrypt_rlwe_tmp_byte(module, cols)
);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk)

View File

@@ -1,4 +1,4 @@
use base2k::{Infos, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
pub struct Elem<T> {
pub value: Vec<T>,
@@ -71,6 +71,7 @@ pub trait ElemCommon<T> {
fn elem(&self) -> &Elem<T>;
fn elem_mut(&mut self) -> &mut Elem<T>;
fn size(&self) -> usize;
fn layout(&self) -> LAYOUT;
fn rows(&self) -> usize;
fn cols(&self) -> usize;
fn log_base2k(&self) -> usize;
@@ -101,6 +102,10 @@ impl<T: Infos> ElemCommon<T> for Elem<T> {
self.value.len()
}
fn layout(&self) -> LAYOUT {
self.value[0].layout()
}
fn rows(&self) -> usize {
self.value[0].rows()
}
@@ -152,7 +157,7 @@ impl Elem<VmpPMat> {
assert!(rows > 0);
assert!(cols > 0);
let mut value: Vec<VmpPMat> = Vec::new();
(0..size).for_each(|_| value.push(module.new_vmp_pmat(rows, cols)));
(0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols)));
Self {
value: value,
log_q: 0,

View File

@@ -108,7 +108,7 @@ impl EncryptorSk {
}
pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
}
pub fn encrypt_rlwe_sk(
module: &Module,
@@ -151,10 +151,10 @@ fn encrypt_rlwe_sk_core<const PT_POS: u8>(
// c1 <- Z_{2^prec}[X]/(X^{N}+1)
module.fill_uniform(log_base2k, c1, cols, source_xa);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
// Scratch space for DFT values
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, cols, tmp_bytes_vec_znx_dft);
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
// Applies buf_dft <- DFT(s) * DFT(c1)
module.svp_apply_dft(&mut buf_dft, sk, c1);

View File

@@ -46,7 +46,7 @@ pub fn gadget_product_core(
pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(min(c_cols, a_cols));
+ 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols));
}
/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
@@ -66,11 +66,11 @@ pub fn gadget_product_big(
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
@@ -99,11 +99,11 @@ pub fn gadget_product(
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_res_dft);
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// a1_dft = DFT(a[1])
@@ -215,7 +215,7 @@ mod test {
let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
// Ideal output = a * s
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols());
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols());
let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols());
@@ -236,8 +236,8 @@ mod test {
a_trunc.copy_from(&a);
(1..gadget_ct.cols() + 1).for_each(|b_cols| {
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(b_cols);
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();

View File

@@ -8,8 +8,8 @@ pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize,
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
+ module.bytes_of_vec_znx_dft(gct_cols);
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+ module.bytes_of_vec_znx_dft(1, gct_cols);
}
pub fn key_switch_rlwe(
@@ -54,11 +54,11 @@ fn key_switch_rlwe_core(
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_a1_dft);
let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
let mut res_big = res_dft.as_vec_znx_big();
module.vec_znx_dft(&mut a1_dft, a.at(1));

View File

@@ -1,7 +1,7 @@
use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemCommon, ElemVecZnx};
use crate::parameters::Parameters;
use base2k::{Module, VecZnx};
use base2k::{LAYOUT, Module, VecZnx};
pub struct Plaintext(pub Elem<VecZnx>);
@@ -79,6 +79,10 @@ impl ElemCommon<VecZnx> for Plaintext {
self.elem().size()
}
fn layout(&self) -> LAYOUT {
self.elem().layout()
}
fn rows(&self) -> usize {
self.0.rows()
}

View File

@@ -18,8 +18,8 @@ pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usiz
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(std::cmp::min(res_cols, in_cols))
+ 2 * module.bytes_of_vec_znx_dft(gct_cols);
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+ 2 * module.bytes_of_vec_znx_dft(1, gct_cols);
}
pub fn rgsw_product(
@@ -40,13 +40,13 @@ pub fn rgsw_product(
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
@@ -82,13 +82,13 @@ pub fn rgsw_product_inplace(
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_c1_dft);
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();

View File

@@ -22,7 +22,7 @@ impl Parameters {
pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(std::cmp::min(c_cols, a_cols));
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
}
pub fn trace_inplace(
@@ -59,11 +59,11 @@ pub fn trace_inplace(
let cols: usize = std::cmp::min(b_cols, a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(b_cols));
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(b_cols, tmp_bytes_res_dft);
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
let log_base2k: usize = a.log_base2k();