From f64d7868196d448aa24679a53b60c24053d2d722 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Thu, 24 Apr 2025 21:53:06 +0200 Subject: [PATCH] fixed rlwe package --- rlwe/benches/gadget_product.rs | 6 ++---- rlwe/examples/encryption.rs | 8 ++++---- rlwe/src/automorphism.rs | 6 +++--- rlwe/src/decryptor.rs | 1 - rlwe/src/elem.rs | 16 ++++++++-------- rlwe/src/gadget_product.rs | 8 ++++---- rlwe/src/key_switching.rs | 2 +- rlwe/src/rgsw_product.rs | 4 ++-- rlwe/src/trace.rs | 10 +++++----- 9 files changed, 29 insertions(+), 32 deletions(-) diff --git a/rlwe/benches/gadget_product.rs b/rlwe/benches/gadget_product.rs index 94df0b6..5af41f1 100644 --- a/rlwe/benches/gadget_product.rs +++ b/rlwe/benches/gadget_product.rs @@ -1,6 +1,4 @@ -use base2k::{ - BACKEND, Infos, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8, -}; +use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8}; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use rlwe::{ ciphertext::{Ciphertext, new_gadget_ciphertext}, @@ -109,7 +107,7 @@ fn bench_gadget_product_inplace(c: &mut Criterion) { let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(gadget_ct.cols()); - let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); + let mut a: VecZnx = params.module().new_vec_znx(0, params.cols_q()); params .module() .fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa); diff --git a/rlwe/examples/encryption.rs b/rlwe/examples/encryption.rs index cd9c7a1..b9d66cd 100644 --- a/rlwe/examples/encryption.rs +++ b/rlwe/examples/encryption.rs @@ -39,11 +39,11 @@ fn main() { let log_k: usize = params.log_q() - 20; - pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32); + pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32); pt.0.value[0].normalize(log_base2k, &mut tmp_bytes); println!("log_k: {}", log_k); - pt.0.value[0].print(pt.cols(), 16); + pt.0.value[0].print(0, pt.cols(), 16); println!(); let mut ct: Ciphertext = params.new_ciphertext(params.log_q()); @@ -64,12 +64,12 @@ fn main() { ); params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes); - pt.0.value[0].print(pt.cols(), 16); + pt.0.value[0].print(0, pt.cols(), 16); let mut have = vec![i64::default(); params.n()]; println!("pt: {}", log_k); - pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have); + pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have); println!("want: {:?}", &want[..16]); println!("have: {:?}", &have[..16]); diff --git a/rlwe/src/automorphism.rs b/rlwe/src/automorphism.rs index 46a4fc5..6390c11 100644 --- a/rlwe/src/automorphism.rs +++ b/rlwe/src/automorphism.rs @@ -11,7 +11,7 @@ use base2k::{ VmpPMatOps, assert_alignement, }; use sampling::source::Source; -use std::{cmp::min, collections::HashMap}; +use std::collections::HashMap; /// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p} pub struct AutomorphismKey { @@ -295,7 +295,7 @@ mod test { let mut pt: Plaintext = params.new_plaintext(log_q); let mut pt_auto: Plaintext = params.new_plaintext(log_q); - pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0)); encrypt_rlwe_sk( @@ -334,7 +334,7 @@ mod test { // pt.at(0).print(pt.cols(), 16); - let noise_have: f64 = pt.at(0).std(log_base2k).log2(); + let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); let var_msg: f64 = (params.xs() as f64) / params.n() as f64; let var_a_err: f64 = 1f64 / 12f64; diff --git a/rlwe/src/decryptor.rs b/rlwe/src/decryptor.rs index 04d56bc..8f0ff76 100644 --- a/rlwe/src/decryptor.rs +++ b/rlwe/src/decryptor.rs @@ -9,7 +9,6 @@ use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZn use std::cmp::min; pub struct Decryptor { - #[warn(dead_code)] sk: SvpPPol, } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index e0252a6..96d11f1 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -25,11 +25,11 @@ impl ElemVecZnx for Elem { let n: usize = module.n(); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); let mut value: Vec = Vec::new(); - let limbs: usize = (log_q + log_base2k - 1) / log_base2k; - let elem_size = VecZnx::bytes_of(n, limbs); + let cols: usize = (log_q + log_base2k - 1) / log_base2k; + let elem_size = VecZnx::bytes_of(n, size, cols); let mut ptr: usize = 0; (0..size).for_each(|_| { - value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..])); + value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..])); ptr += elem_size }); Self { @@ -45,11 +45,11 @@ impl ElemVecZnx for Elem { let n: usize = module.n(); assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size)); let mut value: Vec = Vec::new(); - let limbs: usize = (log_q + log_base2k - 1) / log_base2k; - let elem_size = VecZnx::bytes_of(n, limbs); + let cols: usize = (log_q + log_base2k - 1) / log_base2k; + let elem_size = VecZnx::bytes_of(n, 1, cols); let mut ptr: usize = 0; (0..size).for_each(|_| { - value.push(VecZnx::from_bytes_borrow(n, limbs, &mut bytes[ptr..])); + value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..])); ptr += elem_size }); Self { @@ -135,9 +135,9 @@ impl ElemCommon for Elem { impl Elem { pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self { assert!(rows > 0); - let limbs: usize = (log_q + log_base2k - 1) / log_base2k; + let cols: usize = (log_q + log_base2k - 1) / log_base2k; let mut value: Vec = Vec::new(); - (0..rows).for_each(|_| value.push(module.new_vec_znx(limbs))); + (0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols))); Self { value, log_q, diff --git a/rlwe/src/gadget_product.rs b/rlwe/src/gadget_product.rs index 85b10e6..87df6f1 100644 --- a/rlwe/src/gadget_product.rs +++ b/rlwe/src/gadget_product.rs @@ -206,7 +206,7 @@ mod test { // Intermediate buffers // Input polynopmial, uniformly distributed - let mut a: VecZnx = params.module().new_vec_znx(params.cols_q()); + let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q()); params .module() .fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa); @@ -217,7 +217,7 @@ mod test { // Ideal output = a * s let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(a.cols()); let mut a_big: VecZnxBig = a_dft.as_vec_znx_big(); - let mut a_times_s: VecZnx = params.module().new_vec_znx(a.cols()); + let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols()); // a * sk0 params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a); @@ -232,7 +232,7 @@ mod test { // Iterates over all possible cols values for input/output polynomials and gadget ciphertext. (1..a.cols() + 1).for_each(|a_cols| { - let mut a_trunc: VecZnx = params.module().new_vec_znx(a_cols); + let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols); a_trunc.copy_from(&a); (1..gadget_ct.cols() + 1).for_each(|b_cols| { @@ -296,7 +296,7 @@ mod test { // pt.at(0).print(pt.elem().cols(), 16); - let noise_have: f64 = pt.at(0).std(log_base2k).log2(); + let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); let var_a_err: f64; diff --git a/rlwe/src/key_switching.rs b/rlwe/src/key_switching.rs index 78d48d5..46b557a 100644 --- a/rlwe/src/key_switching.rs +++ b/rlwe/src/key_switching.rs @@ -1,6 +1,6 @@ use crate::ciphertext::Ciphertext; use crate::elem::ElemCommon; -use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; +use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement}; use std::cmp::min; pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize { diff --git a/rlwe/src/rgsw_product.rs b/rlwe/src/rgsw_product.rs index 0c1bdba..71a3cee 100644 --- a/rlwe/src/rgsw_product.rs +++ b/rlwe/src/rgsw_product.rs @@ -193,7 +193,7 @@ mod test { let mut pt: Plaintext = params.new_plaintext(log_q); let mut pt_rotate: Plaintext = params.new_plaintext(log_q); - pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0)); @@ -222,7 +222,7 @@ mod test { // pt.at(0).print(pt.cols(), 16); - let noise_have: f64 = pt.at(0).std(log_base2k).log2(); + let noise_have: f64 = pt.at(0).std(0, log_base2k).log2(); let var_msg: f64 = 1f64 / params.n() as f64; // X^{k} let var_a0_err: f64 = params.xe() * params.xe(); diff --git a/rlwe/src/trace.rs b/rlwe/src/trace.rs index 70bb92d..8fc0dc8 100644 --- a/rlwe/src/trace.rs +++ b/rlwe/src/trace.rs @@ -189,12 +189,12 @@ mod test { let mut ct: Ciphertext = params.new_ciphertext(log_q); let mut pt: Plaintext = params.new_plaintext(log_q); - pt.at_mut(0).encode_vec_i64(log_base2k, log_k, &data, 32); + pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32); pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes); - pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data); + pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); - pt.at(0).print(pt.cols(), 16); + pt.at(0).print(0, pt.cols(), 16); encrypt_rlwe_sk( module, @@ -227,9 +227,9 @@ mod test { &mut tmp_bytes, ); - pt.at(0).print(pt.cols(), 16); + pt.at(0).print(0, pt.cols(), 16); - pt.at(0).decode_vec_i64(log_base2k, log_k, &mut data); + pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data); println!("trace: {:?}", &data[..16]); }