wip rlwe + some bug fixes in base2k

This commit is contained in:
Jean-Philippe Bossuat
2025-02-11 18:16:09 +01:00
parent ec6968d52a
commit 8f33442d5a
18 changed files with 801 additions and 86 deletions

View File

@@ -38,7 +38,7 @@ fn main() {
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs());
// Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
// Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
@@ -67,11 +67,11 @@ fn main() {
//Decrypt
// buf_big <- a * s
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs());
// buf_big <- a * s + b
module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
module.vec_znx_big_add_small_inplace(&mut buf_big, &b, b.limbs());
// res <- normalize(buf_big)
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);

View File

@@ -56,6 +56,8 @@ impl Encoding for VecZnx {
fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
println!("limbs: {}", limbs);
assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs());
let size: usize = min(data.len(), self.n());
@@ -65,10 +67,10 @@ impl Encoding for VecZnx {
// values on the last limb.
// Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
(0..limbs - 1).for_each(|i| unsafe {
(0..self.limbs()).for_each(|i| unsafe {
znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr());
});
self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]);
self.at_mut(limbs - 1)[..size].copy_from_slice(&data[..size]);
} else {
let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);

View File

@@ -91,6 +91,18 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_row(
module: *const MODULE,
pmat: *mut VMP_PMAT,
row: *const i64,
row_i: u64,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
}

View File

@@ -91,7 +91,7 @@ pub trait SvpPPolOps {
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize);
}
impl SvpPPolOps for Module {
@@ -107,14 +107,13 @@ impl SvpPPolOps for Module {
unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
}
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
let limbs: u64 = b.limbs() as u64;
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize) {
assert!(
c.limbs() as u64 >= limbs,
c.limbs() >= b_limbs,
"invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
c.limbs(),
limbs
b_limbs
);
unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
unsafe { svp::svp_apply_dft(self.0, c.0, b_limbs as u64, a.0, b.as_ptr(), b_limbs as u64, b.n() as u64) }
}
}

View File

@@ -117,23 +117,22 @@ impl Module {
}
// b <- b + a
pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
let limbs: usize = a.limbs();
pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx, a_limbs: usize) {
assert!(
b.limbs() >= limbs,
b.limbs() >= a_limbs,
"invalid c_vector: b.limbs()={} < a.limbs()={}",
b.limbs(),
limbs
a_limbs
);
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.0,
b.0,
limbs as u64,
a_limbs as u64,
b.0,
limbs as u64,
a_limbs as u64,
a.as_ptr(),
limbs as u64,
a_limbs as u64,
a.n() as u64,
)
}

View File

@@ -1,7 +1,7 @@
use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::{Module, VecZnxBig};
use crate::{Module, VecZnx, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
@@ -30,6 +30,25 @@ impl Module {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
}
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
///
/// * `limbs`: the number of limbs of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
assert!(
bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(),
self.bytes_of_vec_znx_dft(limbs)
);
VecZnxDft::from_bytes(limbs, bytes)
}
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
@@ -52,6 +71,29 @@ impl Module {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If b.limbs < a_limbs
pub fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize) {
assert!(
b.limbs() >= a_limbs,
"invalid a_limbs: b.limbs()={} < a_limbs={}",
b.limbs(),
a_limbs
);
unsafe {
vec_znx_dft::vec_znx_dft(
self.0,
b.0,
a_limbs as u64,
a.as_ptr(),
a_limbs as u64,
a.n as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
pub fn vec_znx_idft(
&self,

View File

@@ -169,6 +169,38 @@ pub trait VmpPMatOps {
/// ```
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
/// /// # Example
/// ```
/// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
/// use std::cmp::min;
///
/// let n: usize = 1024;
/// let module: Module = Module::new::<FFT64>(n);
/// let rows: usize = 5;
/// let cols: usize = 6;
///
/// let vecznx: module.new_vec_znx(cols);
///
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
///
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
/// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf);
///
/// vmp_pmat.free();
/// module.free();
/// ```
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
/// # Arguments
@@ -404,6 +436,20 @@ impl VmpPMatOps for Module {
}
}
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
unsafe {
vmp::vmp_prepare_row(
self.0,
b.data(),
a.data.as_ptr(),
row_i as u64,
b.rows() as u64,
b.cols() as u64,
buf.as_mut_ptr(),
);
}
}
fn vmp_apply_dft_tmp_bytes(
&self,
c_limbs: usize,