mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
added vmp
This commit is contained in:
60
base2k/examples/vector_matrix_product.rs
Normal file
60
base2k/examples/vector_matrix_product.rs
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let log_n = 4;
|
||||||
|
let n = 1 << log_n;
|
||||||
|
|
||||||
|
let module: Module = Module::new::<FFT64>(n);
|
||||||
|
let log_base2k: usize = 15;
|
||||||
|
let log_q: usize = 60;
|
||||||
|
let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||||
|
|
||||||
|
let rows: usize = limbs + 1;
|
||||||
|
let cols: usize = limbs + 1;
|
||||||
|
|
||||||
|
// Maximum size of the byte scratch needed
|
||||||
|
let tmp_bytes: usize = module.vmp_prepare_contiguous_tmp_bytes(rows, cols)
|
||||||
|
| module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
|
||||||
|
|
||||||
|
let mut buf: Vec<u8> = vec![0; tmp_bytes];
|
||||||
|
|
||||||
|
let mut a_values: Vec<i64> = vec![i64::default(); n];
|
||||||
|
a_values[1] = (1 << log_base2k) + 1;
|
||||||
|
|
||||||
|
let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q);
|
||||||
|
a.set_i64(&a_values, 32);
|
||||||
|
a.normalize(&mut buf);
|
||||||
|
|
||||||
|
(0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i)));
|
||||||
|
|
||||||
|
let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
|
||||||
|
|
||||||
|
(0..rows).for_each(|i| {
|
||||||
|
b_mat.at_mut(i, i)[0] = ((1 << 15) + 1) << 15;
|
||||||
|
});
|
||||||
|
|
||||||
|
println!();
|
||||||
|
(0..rows).for_each(|i| {
|
||||||
|
(0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j)));
|
||||||
|
println!();
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
|
||||||
|
module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);
|
||||||
|
|
||||||
|
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
|
||||||
|
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
|
||||||
|
|
||||||
|
let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
|
||||||
|
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, limbs);
|
||||||
|
|
||||||
|
let mut res: VecZnx = module.new_vec_znx(log_base2k, log_q);
|
||||||
|
module.vec_znx_big_normalize(&mut res, &c_big, &mut buf);
|
||||||
|
|
||||||
|
let mut values_res: Vec<i64> = vec![i64::default(); n];
|
||||||
|
res.get_i64(&mut values_res);
|
||||||
|
|
||||||
|
(0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i)));
|
||||||
|
|
||||||
|
println!("{:?}", values_res)
|
||||||
|
}
|
||||||
@@ -37,16 +37,20 @@ pub mod scalar_vector_product;
|
|||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
pub use scalar_vector_product::*;
|
pub use scalar_vector_product::*;
|
||||||
|
|
||||||
|
pub mod vector_matrix_product;
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
pub use vector_matrix_product::*;
|
||||||
|
|
||||||
pub const GALOISGENERATOR: u64 = 5;
|
pub const GALOISGENERATOR: u64 = 5;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
|
pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
|
||||||
let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
|
let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
|
||||||
let len: usize = data.len() * std::mem::size_of::<u64>();
|
let len: usize = data.len() * std::mem::size_of::<u64>();
|
||||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
|
pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
|
||||||
let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
|
let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
|
||||||
let len: usize = data.len() / std::mem::size_of::<i64>();
|
let len: usize = data.len() / std::mem::size_of::<i64>();
|
||||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||||
|
|||||||
@@ -131,6 +131,37 @@ impl VecZnx {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_single_i64(&mut self, i: usize, value: i64, log_max: usize) {
|
||||||
|
assert!(i < self.n());
|
||||||
|
let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
|
||||||
|
|
||||||
|
// If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||||
|
// values on the last limb.
|
||||||
|
// Else we decompose values base2k.
|
||||||
|
if log_max + k_rem < 63 || k_rem == self.log_base2k {
|
||||||
|
self.at_mut(self.limbs() - 1)[i] = value;
|
||||||
|
} else {
|
||||||
|
let mask: i64 = (1 << self.log_base2k) - 1;
|
||||||
|
let limbs = self.limbs();
|
||||||
|
let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
|
||||||
|
(limbs - steps..limbs)
|
||||||
|
.rev()
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(j, j_rev)| {
|
||||||
|
self.at_mut(j_rev)[i] = (value >> (j * self.log_base2k)) & mask;
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case where self.prec % self.k != 0.
|
||||||
|
if k_rem != self.log_base2k {
|
||||||
|
let limbs = self.limbs();
|
||||||
|
let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
|
||||||
|
(limbs - steps..limbs).rev().for_each(|j| {
|
||||||
|
self.at_mut(j)[i] <<= k_rem;
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn normalize(&mut self, carry: &mut [u8]) {
|
pub fn normalize(&mut self, carry: &mut [u8]) {
|
||||||
assert!(
|
assert!(
|
||||||
carry.len() >= self.n * 8,
|
carry.len() >= self.n * 8,
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ impl Module {
|
|||||||
limbs
|
limbs
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
tmp_bytes.len() <= self.vec_znx_big_normalize_tmp_bytes(),
|
tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(),
|
||||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
||||||
tmp_bytes.len(),
|
tmp_bytes.len(),
|
||||||
self.vec_znx_big_normalize_tmp_bytes()
|
self.vec_znx_big_normalize_tmp_bytes()
|
||||||
|
|||||||
163
base2k/src/vector_matrix_product.rs
Normal file
163
base2k/src/vector_matrix_product.rs
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
use crate::bindings::{
|
||||||
|
new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
|
||||||
|
vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
|
||||||
|
vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr,
|
||||||
|
};
|
||||||
|
use crate::{Module, VecZnx, VecZnxDft};
|
||||||
|
|
||||||
|
pub struct VmpPMat(pub *mut vmp_pmat_t, pub usize, pub usize);
|
||||||
|
|
||||||
|
impl VmpPMat {
|
||||||
|
pub fn rows(&self) -> usize {
|
||||||
|
self.1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cols(&self) -> usize {
|
||||||
|
self.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module {
|
||||||
|
pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
|
||||||
|
unsafe { VmpPMat(new_vmp_pmat(self.0, rows as u64, cols as u64), rows, cols) }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
|
||||||
|
unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
|
||||||
|
unsafe {
|
||||||
|
vmp_prepare_contiguous(
|
||||||
|
self.0,
|
||||||
|
b.0,
|
||||||
|
a.as_ptr(),
|
||||||
|
b.1 as u64,
|
||||||
|
b.2 as u64,
|
||||||
|
buf.as_mut_ptr(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_apply_dft_tmp_bytes(
|
||||||
|
&self,
|
||||||
|
c_limbs: usize,
|
||||||
|
a_limbs: usize,
|
||||||
|
rows: usize,
|
||||||
|
cols: usize,
|
||||||
|
) -> usize {
|
||||||
|
unsafe {
|
||||||
|
vmp_apply_dft_tmp_bytes(
|
||||||
|
self.0,
|
||||||
|
c_limbs as u64,
|
||||||
|
a_limbs as u64,
|
||||||
|
rows as u64,
|
||||||
|
cols as u64,
|
||||||
|
) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
|
||||||
|
unsafe {
|
||||||
|
vmp_apply_dft(
|
||||||
|
self.0,
|
||||||
|
c.0,
|
||||||
|
c.limbs() as u64,
|
||||||
|
a.as_ptr(),
|
||||||
|
a.limbs() as u64,
|
||||||
|
a.n() as u64,
|
||||||
|
b.0,
|
||||||
|
b.rows() as u64,
|
||||||
|
b.cols() as u64,
|
||||||
|
buf.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||||
|
&self,
|
||||||
|
c_limbs: usize,
|
||||||
|
a_limbs: usize,
|
||||||
|
rows: usize,
|
||||||
|
cols: usize,
|
||||||
|
) -> usize {
|
||||||
|
unsafe {
|
||||||
|
vmp_apply_dft_to_dft_tmp_bytes(
|
||||||
|
self.0,
|
||||||
|
c_limbs as u64,
|
||||||
|
a_limbs as u64,
|
||||||
|
rows as u64,
|
||||||
|
cols as u64,
|
||||||
|
) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_apply_dft_to_dft(
|
||||||
|
&self,
|
||||||
|
c: &mut VecZnxDft,
|
||||||
|
a: &VecZnxDft,
|
||||||
|
b: &VmpPMat,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) {
|
||||||
|
unsafe {
|
||||||
|
vmp_apply_dft_to_dft(
|
||||||
|
self.0,
|
||||||
|
c.0,
|
||||||
|
c.limbs() as u64,
|
||||||
|
a.0,
|
||||||
|
a.limbs() as u64,
|
||||||
|
b.0,
|
||||||
|
b.rows() as u64,
|
||||||
|
b.cols() as u64,
|
||||||
|
buf.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
|
||||||
|
unsafe {
|
||||||
|
vmp_apply_dft_to_dft(
|
||||||
|
self.0,
|
||||||
|
b.0,
|
||||||
|
b.limbs() as u64,
|
||||||
|
b.0,
|
||||||
|
b.limbs() as u64,
|
||||||
|
a.0,
|
||||||
|
a.rows() as u64,
|
||||||
|
a.cols() as u64,
|
||||||
|
buf.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Matrix3D<T> {
|
||||||
|
pub data: Vec<T>,
|
||||||
|
pub rows: usize,
|
||||||
|
pub cols: usize,
|
||||||
|
pub n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Default + Clone> Matrix3D<T> {
|
||||||
|
pub fn new(rows: usize, cols: usize, n: usize) -> Self {
|
||||||
|
let size = rows * cols * n;
|
||||||
|
Self {
|
||||||
|
data: vec![T::default(); size],
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
n,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn at(&self, row: usize, col: usize) -> &[T] {
|
||||||
|
assert!(row <= self.rows && col <= self.cols);
|
||||||
|
let idx: usize = col * (self.n * self.rows) + row * self.n;
|
||||||
|
&self.data[idx..idx + self.n]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
|
||||||
|
assert!(row <= self.rows && col <= self.cols);
|
||||||
|
let idx: usize = col * (self.n * self.rows) + row * self.n;
|
||||||
|
&mut self.data[idx..idx + self.n]
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user