mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip rlwe + some bug fixes in base2k
This commit is contained in:
@@ -38,7 +38,7 @@ fn main() {
|
||||
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(a.limbs());
|
||||
|
||||
// Applies buf_dft <- s * a
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
|
||||
|
||||
// Alias scratch space
|
||||
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
|
||||
@@ -67,11 +67,11 @@ fn main() {
|
||||
//Decrypt
|
||||
|
||||
// buf_big <- a * s
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a, a.limbs());
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs());
|
||||
|
||||
// buf_big <- a * s + b
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, &b, b.limbs());
|
||||
|
||||
// res <- normalize(buf_big)
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
|
||||
|
||||
Submodule base2k/spqlios-arithmetic updated: 83555cc664...546113166e
@@ -56,6 +56,8 @@ impl Encoding for VecZnx {
|
||||
fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
|
||||
let limbs: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
|
||||
println!("limbs: {}", limbs);
|
||||
|
||||
assert!(limbs <= self.limbs(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.limbs()={}", limbs, self.limbs());
|
||||
|
||||
let size: usize = min(data.len(), self.n());
|
||||
@@ -65,10 +67,10 @@ impl Encoding for VecZnx {
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||
(0..limbs - 1).for_each(|i| unsafe {
|
||||
(0..self.limbs()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr());
|
||||
});
|
||||
self.at_mut(self.limbs() - 1)[..size].copy_from_slice(&data[..size]);
|
||||
self.at_mut(limbs - 1)[..size].copy_from_slice(&data[..size]);
|
||||
} else {
|
||||
let mask: i64 = (1 << log_base2k) - 1;
|
||||
let steps: usize = min(limbs, (log_max + log_base2k - 1) / log_base2k);
|
||||
|
||||
@@ -91,6 +91,18 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_row(
|
||||
module: *const MODULE,
|
||||
pmat: *mut VMP_PMAT,
|
||||
row: *const i64,
|
||||
row_i: u64,
|
||||
nrows: u64,
|
||||
ncols: u64,
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ pub trait SvpPPolOps {
|
||||
|
||||
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
|
||||
/// the [VecZnxDft] is multiplied with [SvpPPol].
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize);
|
||||
}
|
||||
|
||||
impl SvpPPolOps for Module {
|
||||
@@ -107,14 +107,13 @@ impl SvpPPolOps for Module {
|
||||
unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
|
||||
}
|
||||
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
|
||||
let limbs: u64 = b.limbs() as u64;
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_limbs: usize) {
|
||||
assert!(
|
||||
c.limbs() as u64 >= limbs,
|
||||
c.limbs() >= b_limbs,
|
||||
"invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
|
||||
c.limbs(),
|
||||
limbs
|
||||
b_limbs
|
||||
);
|
||||
unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
|
||||
unsafe { svp::svp_apply_dft(self.0, c.0, b_limbs as u64, a.0, b.as_ptr(), b_limbs as u64, b.n() as u64) }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,23 +117,22 @@ impl Module {
|
||||
}
|
||||
|
||||
// b <- b + a
|
||||
pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
|
||||
let limbs: usize = a.limbs();
|
||||
pub fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx, a_limbs: usize) {
|
||||
assert!(
|
||||
b.limbs() >= limbs,
|
||||
b.limbs() >= a_limbs,
|
||||
"invalid c_vector: b.limbs()={} < a.limbs()={}",
|
||||
b.limbs(),
|
||||
limbs
|
||||
a_limbs
|
||||
);
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.0,
|
||||
b.0,
|
||||
limbs as u64,
|
||||
a_limbs as u64,
|
||||
b.0,
|
||||
limbs as u64,
|
||||
a_limbs as u64,
|
||||
a.as_ptr(),
|
||||
limbs as u64,
|
||||
a_limbs as u64,
|
||||
a.n() as u64,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||
use crate::{Module, VecZnxBig};
|
||||
use crate::{Module, VecZnx, VecZnxBig};
|
||||
|
||||
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
|
||||
|
||||
@@ -30,6 +30,25 @@ impl Module {
|
||||
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
|
||||
}
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `limbs`: the number of limbs of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
pub fn new_vec_znx_from_bytes(&self, limbs: usize, bytes: &mut [u8]) -> VecZnxDft {
|
||||
assert!(
|
||||
bytes.len() >= self.bytes_of_vec_znx_dft(limbs),
|
||||
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
|
||||
bytes.len(),
|
||||
self.bytes_of_vec_znx_dft(limbs)
|
||||
);
|
||||
VecZnxDft::from_bytes(limbs, bytes)
|
||||
}
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
|
||||
@@ -52,6 +71,29 @@ impl Module {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.0) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.limbs < a_limbs
|
||||
pub fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_limbs: usize) {
|
||||
assert!(
|
||||
b.limbs() >= a_limbs,
|
||||
"invalid a_limbs: b.limbs()={} < a_limbs={}",
|
||||
b.limbs(),
|
||||
a_limbs
|
||||
);
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.0,
|
||||
b.0,
|
||||
a_limbs as u64,
|
||||
a.as_ptr(),
|
||||
a_limbs as u64,
|
||||
a.n as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
pub fn vec_znx_idft(
|
||||
&self,
|
||||
|
||||
@@ -169,6 +169,38 @@ pub trait VmpPMatOps {
|
||||
/// ```
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &Vec<VecZnx>, buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a vector of [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
/// /// # Example
|
||||
/// ```
|
||||
/// use base2k::{Module, FFT64, Matrix3D, VmpPMat, VmpPMatOps, VecZnx, VecZnxOps, Free};
|
||||
/// use std::cmp::min;
|
||||
///
|
||||
/// let n: usize = 1024;
|
||||
/// let module: Module = Module::new::<FFT64>(n);
|
||||
/// let rows: usize = 5;
|
||||
/// let cols: usize = 6;
|
||||
///
|
||||
/// let vecznx: module.new_vec_znx(cols);
|
||||
///
|
||||
/// let mut buf: Vec<u8> = vec![u8::default(); module.vmp_prepare_tmp_bytes(rows, cols)];
|
||||
///
|
||||
/// let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||
/// module.vmp_prepare_row(&mut vmp_pmat, &vecznx, 0, &mut buf);
|
||||
///
|
||||
/// vmp_pmat.free();
|
||||
/// module.free();
|
||||
/// ```
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -404,6 +436,20 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &VecZnx, row_i: usize, buf: &mut [u8]) {
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.0,
|
||||
b.data(),
|
||||
a.data.as_ptr(),
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
buf.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_limbs: usize,
|
||||
|
||||
Reference in New Issue
Block a user