wip rlwe + some bug fixes in base2k

This commit is contained in:
Jean-Philippe Bossuat
2025-02-11 18:16:09 +01:00
parent ec6968d52a
commit 8f33442d5a
18 changed files with 801 additions and 86 deletions

View File

@@ -38,7 +38,7 @@ fn main() {
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs()); let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs());
// Applies buf_dft <- s * a // Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
// Alias scratch space // Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
@@ -67,11 +67,11 @@ fn main() {
//Decrypt //Decrypt
// buf_big <- a * s // buf_big <- a * s
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs()); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs());
// buf_big <- a * s + b // buf_big <- a * s + b
module.vec_znx_big_add_small_inplace(&mut buf_big, &b); module.vec_znx_big_add_small_inplace(&mut buf_big, &b, b.limbs());
// res <- normalize(buf_big) // res <- normalize(buf_big)
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry); module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);

View File

@@ -56,6 +56,8 @@ impl Encoding for VecZnx {
fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
let limbs: usize = (log_k + log_base2k - 1) / log_base2k; let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
println!("limbs: {}", limbs);
assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs()); assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs());
let size: usize = min(data.len(), self.n()); let size: usize = min(data.len(), self.n());
@@ -65,10 +67,10 @@ impl Encoding for VecZnx {
// values on the last limb. // values on the last limb.
// Else we decompose values base2k. // Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k { if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
(0..limbs - 1).for_each(|i| unsafe { (0..self.limbs()).for_each(|i| unsafe {
znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr());
}); });
self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]); self.at_mut(limbs - 1)[..size].copy_from_slice(&data[..size]);
} else { } else {
let mask: i64 = (1 << log_base2k) - 1; let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k); let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);

View File

@@ -91,6 +91,18 @@ unsafe extern "C" {
); );
} }
unsafe extern "C" {
pub unsafe fn vmp_prepare_row(
module: *const MODULE,
pmat: *mut VMP_PMAT,
row: *const i64,
row_i: u64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64; pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
} }

View File

@@ -91,7 +91,7 @@ pub trait SvpPPolOps {
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of /// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol]. /// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx); fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize);
} }
impl SvpPPolOps for Module { impl SvpPPolOps for Module {
@@ -107,14 +107,13 @@ impl SvpPPolOps for Module {
unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) } unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
} }
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) { fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize) {
let limbs: u64 = b.limbs() as u64;
assert!( assert!(
c.limbs() as u64 >= limbs, c.limbs() >= b_limbs,
"invalid c_vector: c_vector.limbs()={} < b.limbs()={}", "invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
c.limbs(), c.limbs(),
limbs b_limbs
); );
unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) } unsafe { svp::svp_apply_dft(self.0, c.0, b_limbs as u64, a.0, b.as_ptr(), b_limbs as u64, b.n() as u64) }
} }
} }

View File

@@ -117,23 +117,22 @@ impl Module {
} }
// b <- b + a // b <- b + a
pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx, a_limbs: usize) {
let limbs: usize = a.limbs();
assert!( assert!(
b.limbs() >= limbs, b.limbs() >= a_limbs,
"invalid c_vector: b.limbs()={} < a.limbs()={}", "invalid c_vector: b.limbs()={} < a.limbs()={}",
b.limbs(), b.limbs(),
limbs a_limbs
); );
unsafe { unsafe {
vec_znx_big::vec_znx_big_add_small( vec_znx_big::vec_znx_big_add_small(
self.0, self.0,
b.0, b.0,
limbs as u64, a_limbs as u64,
b.0, b.0,
limbs as u64, a_limbs as u64,
a.as_ptr(), a.as_ptr(),
limbs as u64, a_limbs as u64,
a.n() as u64, a.n() as u64,
) )
} }

View File

@@ -1,7 +1,7 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft; use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::{Module, VecZnxBig}; use crate::{Module, VecZnx, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
@@ -30,6 +30,25 @@ impl Module {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) } unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
} }
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
///
/// * `limbs`: the number of limbs of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
assert!(
bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(),
self.bytes_of_vec_znx_dft(limbs)
);
VecZnxDft::from_bytes(limbs, bytes)
}
/// Returns the minimum number of bytes necessary to allocate /// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes]. /// a new [VecZnxDft] through [VecZnxDft::from_bytes].
pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize { pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
@@ -52,6 +71,29 @@ impl Module {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize } unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize }
} }
/// b <- DFT(a)
///
/// # Panics
/// If b.limbs < a_limbs
pub fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize) {
assert!(
b.limbs() >= a_limbs,
"invalid a_limbs: b.limbs()={} < a_limbs={}",
b.limbs(),
a_limbs
);
unsafe {
vec_znx_dft::vec_znx_dft(
self.0,
b.0,
a_limbs as u64,
a.as_ptr(),
a_limbs as u64,
a.n as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes]. // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
pub fn vec_znx_idft( pub fn vec_znx_idft(
&self, &self,

View File

@@ -169,6 +169,38 @@ pub trait VmpPMatOps {
/// ``` /// ```
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]); fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// /// # Example
/// ```
/// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
/// let module: Module = Module::new::<FFT64>(n);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx: module.new_vec_znx(cols);
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft]. /// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
/// ///
/// # Arguments /// # Arguments
@@ -404,6 +436,20 @@ impl VmpPMatOps for Module {
} }
} }
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
unsafe {
vmp::vmp_prepare_row(
self.0,
b.data(),
a.data.as_ptr(),
row_i as u64,
b.rows() as u64,
b.cols() as u64,
buf.as_mut_ptr(),
);
}
}
fn vmp_apply_dft_tmp_bytes( fn vmp_apply_dft_tmp_bytes(
&self, &self,
c_limbs: usize, c_limbs: usize,

View File

@@ -0,0 +1,77 @@
use base2k::{Encoding, FFT64, SvpPPolOps};
use rlwe::{
ciphertext::Ciphertext,
decryptor::{Decryptor, decrypt_rlwe_thread_safe_tmp_byte},
encryptor::{EncryptorSk, encrypt_rlwe_sk_tmp_bytes},
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use sampling::source::{Source, new_seed};
fn main() {
let params_lit: ParametersLiteral = ParametersLiteral {
log_n: 10,
log_q: 54,
log_p: 0,
log_base2k: 17,
log_scale: 20,
xe: 3.2,
xs: 128,
};
let params: Parameters = Parameters::new::<FFT64>(&params_lit);
let mut tmp_bytes: Vec<u8> = vec![
0u8;
params.decrypt_rlwe_thread_safe_tmp_byte(params.log_q())
| params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
];
let sk: SecretKey = SecretKey::new(params.module());
let mut want = vec![i64::default(); params.n()];
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let mut pt: Plaintext = params.new_plaintext(params.log_q() - 20);
let log_base2k = pt.log_base2k();
let log_k: usize = 17;
pt.0.value[0].encode_vec_i64(log_base2k, log_k, &want, 32);
pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
println!("log_k: {}", log_k);
pt.0.value[0].print_limbs(pt.limbs(), 16);
let mut ct: Ciphertext = params.new_ciphertext(params.log_q());
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut sk_svp_ppol: base2k::SvpPPol = params.module().svp_new_ppol();
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
params.encrypt_rlwe_sk_thread_safe(
&mut ct,
Some(&pt),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
&mut tmp_bytes,
);
params.decrypt_rlwe_thread_safe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
pt.0.value[0].print_limbs(pt.limbs(), 16);
let mut have = vec![i64::default(); params.n()];
println!("pt: {}", log_k);
pt.0.value[0].decode_vec_i64(pt.log_base2k(), log_k, &mut have);
println!("want: {:?}", &want[..16]);
println!("have: {:?}", &have[..16]);
}

View File

@@ -1,20 +1,19 @@
use crate::elem::Elem; use crate::elem::Elem;
use crate::parameters::Parameters;
use crate::plaintext::Plaintext; use crate::plaintext::Plaintext;
use base2k::VecZnx; use base2k::{Module, VecZnx, VmpPMat, VmpPMatOps};
pub struct Ciphertext(pub Elem); pub struct Ciphertext(pub Elem);
/*
impl Parameters {
pub fn new_ciphertext(&self, degree: usize, log_base2k: usize, log_q: usize) -> Ciphertext {
Ciphertext(self.new_elem(degree, log_base2k, log_q))
}
}
*/
impl Ciphertext { impl Ciphertext {
pub fn new(n: usize, log_base2k: usize, log_q: usize, degree: usize) -> Self { pub fn new(
Self(Elem::new(n, log_base2k, log_q, degree)) module: &Module,
log_base2k: usize,
log_q: usize,
degree: usize,
log_scale: usize,
) -> Self {
Self(Elem::new(module, log_base2k, log_q, degree, log_scale))
} }
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
@@ -45,7 +44,75 @@ impl Ciphertext {
self.0.log_base2k() self.0.log_base2k()
} }
pub fn log_scale(&self) -> usize {
self.0.log_scale
}
pub fn as_plaintext(&self) -> Plaintext { pub fn as_plaintext(&self) -> Plaintext {
unsafe { Plaintext(std::ptr::read(&self.0)) } unsafe { Plaintext(std::ptr::read(&self.0)) }
} }
} }
impl Parameters {
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext {
Ciphertext::new(self.module(), self.log_base2k(), log_q, self.log_scale(), 1)
}
}
pub struct GadgetCiphertext {
pub value: Vec<VmpPMat>,
pub log_base2k: usize,
pub log_q: usize,
pub log_scale: usize,
}
impl GadgetCiphertext {
pub fn new(
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
log_scale: usize,
) -> Self {
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut value: Vec<VmpPMat> = Vec::new();
(0..rows).for_each(|_| value.push(module.new_vmp_pmat(rows, cols)));
Self {
value,
log_base2k,
log_q,
log_scale,
}
}
pub fn n(&self) -> usize {
self.value[0].n
}
pub fn rows(&self) -> usize {
self.value[0].rows
}
pub fn cols(&self) -> usize {
self.value[0].cols
}
pub fn degree(&self) -> usize {
self.value.len() - 1
}
pub fn log_q(&self) -> usize {
self.log_q
}
pub fn log_base2k(&self) -> usize {
self.log_base2k
}
}
pub struct RGSWCiphertext {
pub value: [GadgetCiphertext; 2],
pub log_base2k: usize,
pub log_q: usize,
pub log_p: usize,
}

80
rlwe/src/decryptor.rs Normal file
View File

@@ -0,0 +1,80 @@
use crate::{
ciphertext::Ciphertext, keys::SecretKey, parameters::Parameters, plaintext::Plaintext,
};
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnxDft};
use std::cmp::min;
pub struct Decryptor {
sk: SvpPPol,
}
impl Decryptor {
pub fn new(params: &Parameters, sk: &SecretKey) -> Self {
let mut sk_svp_ppol: SvpPPol = params.module().svp_new_ppol();
sk.prepare(params.module(), &mut sk_svp_ppol);
Self { sk: sk_svp_ppol }
}
}
pub fn decrypt_rlwe_thread_safe_tmp_byte(module: &Module, limbs: usize) -> usize {
module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes()
}
impl Parameters {
pub fn decrypt_rlwe_thread_safe_tmp_byte(&self, log_q: usize) -> usize {
decrypt_rlwe_thread_safe_tmp_byte(
self.module(),
(log_q + self.log_base2k() - 1) / self.log_base2k(),
)
}
pub fn decrypt_rlwe_thread_safe(
&self,
res: &mut Plaintext,
ct: &Ciphertext,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) {
decrypt_rlwe_thread_safe(self.module(), res, ct, sk, tmp_bytes)
}
}
pub fn decrypt_rlwe_thread_safe(
module: &Module,
res: &mut Plaintext,
ct: &Ciphertext,
sk: &SvpPPol,
tmp_bytes: &mut [u8],
) {
let limbs: usize = min(res.limbs(), ct.limbs());
assert!(
tmp_bytes.len() >= decrypt_rlwe_thread_safe_tmp_byte(module, limbs),
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_thread_safe_tmp_byte={}",
tmp_bytes.len(),
decrypt_rlwe_thread_safe_tmp_byte(module, limbs)
);
let res_dft_bytes: usize = module.bytes_of_vec_znx_dft(limbs);
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes(limbs, tmp_bytes);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk)
module.svp_apply_dft(&mut res_dft, sk, &ct.0.value[1], limbs);
// res_big <- ct[1] x sk
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, limbs);
// res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, &ct.0.value[0], limbs);
// res <- normalize(ct[1] x sk + ct[0])
module.vec_znx_big_normalize(
ct.log_base2k(),
res.at_mut(0),
&res_big,
&mut tmp_bytes[res_dft_bytes..],
);
res.0.log_base2k = ct.log_base2k();
res.0.log_q = min(res.log_q(), ct.log_q());
res.0.log_scale = ct.log_scale();
}

View File

@@ -1,21 +1,59 @@
use crate::parameters::Parameters; use crate::parameters::Parameters;
use base2k::{Infos, VecZnx}; use base2k::{Infos, Module, VecZnx, VecZnxOps};
pub struct Elem { pub struct Elem {
pub value: Vec<VecZnx>, pub value: Vec<VecZnx>,
pub log_base2k: usize, pub log_base2k: usize,
pub log_q: usize, pub log_q: usize,
pub log_scale: usize,
} }
impl Elem { impl Elem {
pub fn new(n: usize, log_base2k: usize, log_q: usize, degree: usize) -> Self { pub fn new(
module: &Module,
log_base2k: usize,
log_q: usize,
degree: usize,
log_scale: usize,
) -> Self {
let limbs: usize = (log_q + log_base2k - 1) / log_base2k; let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let mut value: Vec<VecZnx> = Vec::new(); let mut value: Vec<VecZnx> = Vec::new();
(0..degree + 1).for_each(|_| value.push(VecZnx::new(n, limbs))); (0..degree + 1).for_each(|_| value.push(module.new_vec_znx(limbs)));
Self { Self {
value, value,
log_base2k,
log_q, log_q,
log_base2k,
log_scale: log_scale,
}
}
pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, degree: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
module.n() * cols * (degree + 1) * 8
}
pub fn from_bytes(
module: &Module,
log_base2k: usize,
log_q: usize,
degree: usize,
bytes: &mut [u8],
) -> Self {
let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, degree));
let mut value: Vec<VecZnx> = Vec::new();
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
let size = VecZnx::bytes(n, limbs);
let mut ptr: usize = 0;
(0..degree + 1).for_each(|_| {
value.push(VecZnx::from_bytes(n, limbs, &mut bytes[ptr..]));
ptr += size
});
Self {
value,
log_q,
log_base2k,
log_scale: 0,
} }
} }
@@ -35,6 +73,10 @@ impl Elem {
self.log_base2k self.log_base2k
} }
pub fn log_scale(&self) -> usize {
self.log_scale
}
pub fn log_q(&self) -> usize { pub fn log_q(&self) -> usize {
self.log_q self.log_q
} }
@@ -49,3 +91,13 @@ impl Elem {
&mut self.value[i] &mut self.value[i]
} }
} }
impl Parameters {
pub fn bytes_of_elem(&self, log_q: usize, degree: usize) -> usize {
Elem::bytes_of(self.module(), self.log_base2k(), log_q, degree)
}
pub fn elem_from_bytes(&self, log_q: usize, degree: usize, bytes: &mut [u8]) -> Elem {
Elem::from_bytes(self.module(), self.log_base2k(), log_q, degree, bytes)
}
}

View File

@@ -1,84 +1,141 @@
use crate::ciphertext::Ciphertext; use crate::ciphertext::{Ciphertext, GadgetCiphertext};
use crate::elem::Elem; use crate::elem::Elem;
use crate::keys::SecretKey; use crate::keys::SecretKey;
use crate::parameters::Parameters; use crate::parameters::Parameters;
use crate::plaintext::Plaintext; use crate::plaintext::Plaintext;
use base2k::ffi::znx::znx_zero_i64_ref;
use base2k::sampling::Sampling; use base2k::sampling::Sampling;
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft}; use base2k::{
use sampling::source::Source; Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxDft, VecZnxOps, VmpPMatOps,
};
use sampling::source::{Source, new_seed};
pub struct EncryptorSk { pub struct EncryptorSk {
pub sk: SvpPPol, sk: SvpPPol,
source_xa: Source,
source_xe: Source,
initialized: bool,
tmp_bytes: Vec<u8>,
} }
impl EncryptorSk { impl EncryptorSk {
pub fn new(params: &Parameters, sk: &SecretKey) -> Self { pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self {
let mut sk_svp_ppol: SvpPPol = params.module().svp_new_ppol(); let mut sk_svp_ppol: SvpPPol = params.module().svp_new_ppol();
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0); let mut initialized: bool = false;
Self { sk: sk_svp_ppol } if let Some(sk) = sk {
sk.prepare(params.module(), &mut sk_svp_ppol);
initialized = true;
}
Self {
sk: sk_svp_ppol,
initialized,
source_xa: Source::new(new_seed()),
source_xe: Source::new(new_seed()),
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.limbs_qp())],
}
}
pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) {
sk.prepare(module, &mut self.sk);
self.initialized = true;
}
pub fn seed_source_xa(&mut self, seed: [u8; 32]) {
self.source_xa = Source::new(seed)
}
pub fn seed_source_xe(&mut self, seed: [u8; 32]) {
self.source_xe = Source::new(seed)
} }
pub fn encrypt_rlwe_sk( pub fn encrypt_rlwe_sk(
&mut self,
params: &Parameters,
ct: &mut Ciphertext,
pt: Option<&Plaintext>,
) {
assert!(
self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
);
params.encrypt_rlwe_sk_thread_safe(
ct,
pt,
&self.sk,
&mut self.source_xa,
&mut self.source_xe,
&mut self.tmp_bytes,
);
}
pub fn encrypt_rlwe_sk_thread_safe(
&self, &self,
params: &Parameters, params: &Parameters,
ct: &mut Ciphertext, ct: &mut Ciphertext,
pt: Option<&Plaintext>, pt: Option<&Plaintext>,
xa_source: &mut Source, source_xa: &mut Source,
xe_source: &mut Source, source_xe: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
params.encrypt_rlwe_sk(ct, pt, &self.sk, xa_source, xe_source, tmp_bytes); assert!(
self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk_thread_safe]: [EncryptorSk] has not been initialized with a [SecretKey]"
);
params.encrypt_rlwe_sk_thread_safe(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes);
} }
} }
impl Parameters { impl Parameters {
pub fn encrypt_rlwe_sk_tmp_bytes(&self, limbs: usize) -> usize { pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(self.module(), limbs) encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
} }
pub fn encrypt_rlwe_sk( pub fn encrypt_rlwe_sk_thread_safe(
&self, &self,
ct: &mut Ciphertext, ct: &mut Ciphertext,
pt: Option<&Plaintext>, pt: Option<&Plaintext>,
sk: &SvpPPol, sk: &SvpPPol,
xa_source: &mut Source, source_xa: &mut Source,
xe_source: &mut Source, source_xe: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
encrypt_rlwe_sk( encrypt_rlwe_sk_thread_safe(
self.module(), self.module(),
&mut ct.0, &mut ct.0,
pt.map(|pt| &pt.0), pt.map(|pt| &pt.0),
sk, sk,
xa_source, source_xa,
xe_source, source_xe,
self.xe(), self.xe(),
tmp_bytes, tmp_bytes,
) )
} }
} }
pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, limbs: usize) -> usize { pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
module.bytes_of_vec_znx_dft(limbs) + module.vec_znx_big_normalize_tmp_bytes() module.bytes_of_vec_znx_dft((log_q + log_base2k - 1) / log_base2k)
+ module.vec_znx_big_normalize_tmp_bytes()
} }
pub fn encrypt_rlwe_sk( pub fn encrypt_rlwe_sk_thread_safe(
module: &Module, module: &Module,
ct: &mut Elem, ct: &mut Elem,
pt: Option<&Elem>, pt: Option<&Elem>,
sk: &SvpPPol, sk: &SvpPPol,
xa_source: &mut Source, source_xa: &mut Source,
xe_source: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) { ) {
let limbs: usize = ct.limbs(); let limbs: usize = ct.limbs();
let log_base2k: usize = ct.log_base2k();
let log_q: usize = ct.log_q();
assert!( assert!(
tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, limbs), tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q),
"invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}", "invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}",
tmp_bytes.len(), tmp_bytes.len(),
encrypt_rlwe_sk_tmp_bytes(module, limbs) encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
); );
let log_q: usize = ct.log_q(); let log_q: usize = ct.log_q();
@@ -86,7 +143,7 @@ pub fn encrypt_rlwe_sk(
let c1: &mut VecZnx = ct.at_mut(1); let c1: &mut VecZnx = ct.at_mut(1);
// c1 <- Z_{2^prec}[X]/(X^{N}+1) // c1 <- Z_{2^prec}[X]/(X^{N}+1)
c1.fill_uniform(limbs, log_base2k, xa_source); c1.fill_uniform(log_base2k, limbs, source_xa);
let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(limbs); let bytes_of_vec_znx_dft = module.bytes_of_vec_znx_dft(limbs);
@@ -95,7 +152,7 @@ pub fn encrypt_rlwe_sk(
VecZnxDft::from_bytes(limbs, &mut tmp_bytes[..bytes_of_vec_znx_dft]); VecZnxDft::from_bytes(limbs, &mut tmp_bytes[..bytes_of_vec_znx_dft]);
// Applies buf_dft <- s * c1 // Applies buf_dft <- s * c1
module.svp_apply_dft(&mut buf_dft, sk, c1); module.svp_apply_dft(&mut buf_dft, sk, c1, limbs);
// Alias scratch space // Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big(); let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
@@ -110,5 +167,90 @@ pub fn encrypt_rlwe_sk(
// c0 <- normalize(buf_big) + e // c0 <- normalize(buf_big) + e
let c0: &mut VecZnx = ct.at_mut(0); let c0: &mut VecZnx = ct.at_mut(0);
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry); module.vec_znx_big_normalize(log_base2k, c0, &buf_big, carry);
c0.add_normal(log_base2k, log_q, xe_source, sigma, (sigma * 6.0).ceil()); c0.add_normal(log_base2k, log_q, source_xe, sigma, (sigma * 6.0).ceil());
}
pub fn encrypt_grlwe_sk_tmp_bytes(
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::bytes_of(module, log_base2k, log_q, 1)
+ Plaintext::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, cols)
}
pub fn encrypt_grlwe_sk_thread_safe(
module: &Module,
ct: &mut GadgetCiphertext,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let rows: usize = ct.rows();
let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k();
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
assert!(
tmp_bytes.len() >= min_tmp_bytes_len,
"invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}",
tmp_bytes.len(),
min_tmp_bytes_len
);
let mut ptr: usize = 0;
let mut tmp_elem: Elem = Elem::from_bytes(module, log_base2k, ct.log_q(), 1, tmp_bytes);
let bytes_of_elem: usize = Elem::bytes_of(module, log_base2k, log_q, 1);
ptr += bytes_of_elem;
let mut tmp_pt: Plaintext =
Plaintext::from_bytes(module, log_base2k, log_q, &mut tmp_bytes[ptr..]);
ptr += Plaintext::bytes_of(module, log_base2k, log_q);
let (tmp_bytes_encrypt_sk, tmp_bytes_vmp_prepare_row) =
tmp_bytes[ptr..].split_at_mut(encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q));
(0..rows).for_each(|row_i| {
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
tmp_pt.0.value[0].at_mut(row_i).copy_from_slice(&m.0);
// Encrypts RLWE(m * 2^{-log_base2k*i})
encrypt_rlwe_sk_thread_safe(
module,
&mut tmp_elem,
Some(&tmp_pt.0),
sk,
source_xa,
source_xe,
sigma,
tmp_bytes_encrypt_sk,
);
// Zeroes the ith-row of tmp_pt
tmp_pt.0.value[0].at_mut(row_i).fill(0);
// GRLWE[row_i][0] = -as + m * 2^{-i*log_base2k} + e*2^{-log_q}
module.vmp_prepare_row(
&mut ct.value[0],
tmp_elem.at(0),
row_i,
tmp_bytes_vmp_prepare_row,
);
// GRLWE[row_i][1] = a
module.vmp_prepare_row(
&mut ct.value[1],
tmp_elem.at(1),
row_i,
tmp_bytes_vmp_prepare_row,
);
})
} }

82
rlwe/src/evaluator.rs Normal file
View File

@@ -0,0 +1,82 @@
use crate::ciphertext::{Ciphertext, GadgetCiphertext};
use base2k::{Module, VecZnxBig, VecZnxDft, VmpPMatOps};
pub fn gadget_product_tmp_bytes(
module: &Module,
log_base2k: usize,
out_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
let in_limbs: usize = (in_log_q + log_base2k - 1) / log_base2k;
let out_limbs: usize = (out_log_q + log_base2k - 1) / log_base2k;
module.vmp_apply_dft_to_dft_tmp_bytes(out_limbs, in_limbs, gct_rows, gct_cols)
+ 2 * module.bytes_of_vec_znx_dft(gct_cols)
}
pub fn gadget_product_inplace(
module: &Module,
a: &mut Ciphertext,
b: &GadgetCiphertext,
tmp_bytes: &mut [u8],
) {
// This is safe to do because the relevant values of a are copied to a buffer before being
// overwritten.
unsafe {
let a_ptr: *mut Ciphertext = a;
gadget_product(module, a, &*a_ptr, b, tmp_bytes)
}
}
pub fn gadget_product(
module: &Module,
res: &mut Ciphertext,
a: &Ciphertext,
b: &GadgetCiphertext,
tmp_bytes: &mut [u8],
) {
assert!(
a.log_base2k() == b.log_base2k(),
"invalid inputs: a.log_base2k={} != b.log_base2k={}",
a.log_base2k(),
b.log_base2k()
);
let log_base2k: usize = b.log_base2k();
let cols: usize = b.cols();
let (tmp_bytes_vmp_apply_dft, tmp_bytes) =
tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(cols));
let (tmp_bytes_c1_dft, tmp_bytes_res_dft) = tmp_bytes.split_at_mut(tmp_bytes.len() >> 1);
let mut c1_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_c1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_from_bytes(cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// c1_dft <- DFT(b[1])
module.vec_znx_dft(&mut c1_dft, a.at(1), a.limbs());
// res_dft <- DFT(c1) x GadgetCiphertext[0]
module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[0], tmp_bytes_vmp_apply_dft);
// res_big <- IDFT(DFT(c1) x GadgetCiphertext[0])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
// res_big <- c0 + c1_dft x GadgetCiphertext[0]
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0), cols);
// res[0] = normalize(c0 + c1_dft x GadgetCiphertext[0])
module.vec_znx_big_normalize(log_base2k, res.at_mut(0), &res_big, tmp_bytes_vmp_apply_dft);
// res_dft <- DFT(c1) x GadgetCiphertext[1]
module.vmp_apply_dft_to_dft(&mut res_dft, &c1_dft, &b.value[1], tmp_bytes_vmp_apply_dft);
// res_big <- IDFT(DFT(c1) x GadgetCiphertext[1])
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft, cols);
// res[1] = normalize(c1_dft x GadgetCiphertext[1])
module.vec_znx_big_normalize(log_base2k, res.at_mut(1), &res_big, tmp_bytes_vmp_apply_dft);
}

View File

@@ -1 +1,53 @@
use crate::keys::{PublicKey, SecretKey, SwitchingKey};
use crate::parameters::Parameters;
use base2k::SvpPPol;
use sampling::source::Source;
pub struct KeyGenerator {} pub struct KeyGenerator {}
impl KeyGenerator {
pub fn gen_secret_key_thread_safe(
&self,
params: &Parameters,
source: &mut Source,
) -> SecretKey {
let mut sk: SecretKey = SecretKey::new(params.module());
sk.fill_ternary_hw(params.xs(), source);
sk
}
pub fn gen_public_key_thread_safe(
&self,
params: &Parameters,
sk_ppol: &SvpPPol,
source: &mut Source,
tmp_bytes: &mut [u8],
) -> PublicKey {
let mut xa_source: Source = source.branch();
let mut xe_source: Source = source.branch();
let mut pk: PublicKey =
PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
pk.gen_thread_safe(
params.module(),
sk_ppol,
params.xe(),
&mut xa_source,
&mut xe_source,
tmp_bytes,
);
pk
}
pub fn gen_switching_key(
&self,
params: &Parameters,
sk_in: &SecretKey,
sk_out: &SecretKey,
rows: usize,
log_q: usize,
) -> SwitchingKey {
let swk = SwitchingKey::new(params.module(), params.log_base2k(), rows, log_q, 0);
swk
}
}

View File

@@ -1,44 +1,86 @@
use crate::ciphertext::GadgetCiphertext;
use crate::elem::Elem; use crate::elem::Elem;
use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes}; use crate::encryptor::{encrypt_rlwe_sk_thread_safe, encrypt_rlwe_sk_tmp_bytes};
use crate::parameters::Parameters; use crate::parameters::Parameters;
use base2k::{Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx}; use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VmpPMat, VmpPMatOps};
use sampling::source::Source; use sampling::source::Source;
pub struct SecretKey(pub Scalar); pub struct SecretKey(pub Scalar);
impl SecretKey { impl SecretKey {
pub fn new_ternary_prob(module: &Module, limbs: usize, prob: f64, source: &mut Source) -> Self { pub fn new(params: &Module) -> Self {
let mut sk: Scalar = Scalar::new(module.n()); SecretKey(Scalar::new(params.n()))
sk.fill_ternary_prob(prob, source); }
SecretKey(sk)
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
self.0.fill_ternary_prob(prob, source);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
self.0.fill_ternary_hw(hw, source);
}
pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) {
module.svp_prepare(sk_ppol, &self.0)
} }
} }
pub struct PublicKey(pub Elem); pub struct PublicKey(pub Elem);
impl PublicKey { impl PublicKey {
pub fn new( pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey {
params: &Parameters, PublicKey(Elem::new(module, log_base2k, log_q, 1, 0))
}
pub fn gen_thread_safe(
&mut self,
module: &Module,
sk: &SvpPPol, sk: &SvpPPol,
xe: f64,
xa_source: &mut Source, xa_source: &mut Source,
xe_source: &mut Source, xe_source: &mut Source,
tmp_bytes: &mut [u8], tmp_bytes: &mut [u8],
) -> Self { ) {
let mut pk: Elem = Elem::new(params.n(), params.log_base2k(), params.log_qp(), 1); encrypt_rlwe_sk_thread_safe(
encrypt_rlwe_sk( module,
params.module(), &mut self.0,
&mut pk,
None, None,
sk, sk,
xa_source, xa_source,
xe_source, xe_source,
params.xe(), xe,
tmp_bytes, tmp_bytes,
); );
PublicKey(pk)
} }
pub fn new_tmp_bytes(params: &Parameters) -> usize { pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(params.module(), params.limbs_qp()) encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
}
}
pub struct SwitchingKey(GadgetCiphertext);
impl SwitchingKey {
pub fn new(
module: &Module,
log_base2k: usize,
rows: usize,
log_q: usize,
log_scale: usize,
) -> SwitchingKey {
SwitchingKey(GadgetCiphertext::new(
module, log_base2k, rows, log_q, log_scale,
))
}
pub fn gen_thread_safe(
&mut self,
params: &mut Parameters,
sk_in: &SvpPPol,
sk_out: &SvpPPol,
xa_source: &mut Source,
xe_source: &mut Source,
tmp_bytes: &mut [u8],
) {
} }
} }

View File

@@ -1,6 +1,8 @@
pub mod ciphertext; pub mod ciphertext;
pub mod decryptor;
pub mod elem; pub mod elem;
pub mod encryptor; pub mod encryptor;
pub mod evaluator;
pub mod key_generator; pub mod key_generator;
pub mod keys; pub mod keys;
pub mod parameters; pub mod parameters;

View File

@@ -1,20 +1,35 @@
use crate::ciphertext::Ciphertext; use crate::ciphertext::Ciphertext;
use crate::elem::Elem; use crate::elem::Elem;
use base2k::VecZnx; use crate::parameters::Parameters;
use base2k::{Module, VecZnx};
pub struct Plaintext(pub Elem); pub struct Plaintext(pub Elem);
/*
impl Parameters { impl Parameters {
pub fn new_plaintext(&self, log_q: usize) -> Plaintext { pub fn new_plaintext(&self, log_q: usize) -> Plaintext {
Plaintext(self.new_elem(0, log_q)) Plaintext::new(self.module(), self.log_base2k(), log_q, self.log_scale())
}
pub fn bytes_of_plaintext(&self, log_q: usize) -> usize {
Elem::bytes_of(self.module(), self.log_base2k(), log_q, 0)
}
pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext {
Plaintext(self.elem_from_bytes(log_q, 0, bytes))
} }
} }
*/
impl Plaintext { impl Plaintext {
pub fn new(n: usize, log_base2k: usize, log_q: usize) -> Self { pub fn new(module: &Module, log_base2k: usize, log_q: usize, log_scale: usize) -> Self {
Self(Elem::new(n, log_base2k, log_q, 0)) Self(Elem::new(module, log_base2k, log_q, 0, log_scale))
}
pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize {
Elem::bytes_of(module, log_base2k, log_q, 0)
}
pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
Self(Elem::from_bytes(module, log_base2k, log_q, 0, bytes))
} }
pub fn n(&self) -> usize { pub fn n(&self) -> usize {
@@ -45,6 +60,10 @@ impl Plaintext {
self.0.log_base2k() self.0.log_base2k()
} }
pub fn log_scale(&self) -> usize {
self.0.log_scale()
}
pub fn as_ciphertext(&self) -> Ciphertext { pub fn as_ciphertext(&self) -> Ciphertext {
unsafe { Ciphertext(std::ptr::read(&self.0)) } unsafe { Ciphertext(std::ptr::read(&self.0)) }
} }