added vmp

This commit is contained in:
Jean-Philippe Bossuat
2025-01-29 10:32:10 +01:00
parent 6fcd5c743d
commit 783a763ac9
5 changed files with 261 additions and 3 deletions

View File

@@ -0,0 +1,60 @@
use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64};
/// Example: base-2^k vector-matrix product over Z[X]/(X^n + 1).
///
/// Builds a vector `a`, a diagonal matrix `b_mat`, prepares the matrix for
/// the DFT-domain product, applies it, and prints the normalized result.
fn main() {
    // Ring degree: n = 2^log_n.
    let log_n = 4;
    let n = 1 << log_n;
    let module: Module = Module::new::<FFT64>(n);

    // Base-2^k decomposition: `limbs` limbs of `log_base2k` bits cover a
    // `log_q`-bit modulus (ceiling division).
    let log_base2k: usize = 15;
    let log_q: usize = 60;
    let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
    let rows: usize = limbs + 1;
    let cols: usize = limbs + 1;

    // Maximum size of the byte scratch needed across all calls below.
    // Fix 1: the original combined the two sizes with bitwise `|`; that
    // happens to be >= both operands but obscures intent (and can
    // over-allocate) — take the explicit maximum instead.
    // Fix 2: size the buffer for `vmp_apply_dft` (the call actually made
    // below), not `vmp_apply_dft_to_dft`.
    let tmp_bytes: usize = module
        .vmp_prepare_contiguous_tmp_bytes(rows, cols)
        .max(module.vmp_apply_dft_tmp_bytes(limbs, limbs, rows, cols));
    let mut buf: Vec<u8> = vec![0; tmp_bytes];

    // a(X) = (2^log_base2k + 1) * X
    let mut a_values: Vec<i64> = vec![i64::default(); n];
    a_values[1] = (1 << log_base2k) + 1;
    let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q);
    // NOTE(review): 32 is presumably the log_max bound on |a_values| — confirm
    // against VecZnx::set_i64's contract.
    a.set_i64(&a_values, 32);
    a.normalize(&mut buf);
    (0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i)));

    // Diagonal matrix: constant coefficient ((2^15)+1) << 15 on each diagonal cell.
    let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
    (0..rows).for_each(|i| {
        b_mat.at_mut(i, i)[0] = ((1 << 15) + 1) << 15;
    });
    println!();
    (0..rows).for_each(|i| {
        (0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j)));
        println!();
    });

    // Prepare the matrix, apply the product in the DFT domain, switch back to
    // the coefficient domain (in-place over c_dft's buffer), then normalize.
    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
    module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);

    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
    module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
    let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, limbs);

    let mut res: VecZnx = module.new_vec_znx(log_base2k, log_q);
    module.vec_znx_big_normalize(&mut res, &c_big, &mut buf);

    let mut values_res: Vec<i64> = vec![i64::default(); n];
    res.get_i64(&mut values_res);
    (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i)));
    println!("{:?}", values_res)
}

View File

@@ -37,16 +37,20 @@ pub mod scalar_vector_product;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use scalar_vector_product::*; pub use scalar_vector_product::*;
pub mod vector_matrix_product;
#[allow(unused_imports)]
pub use vector_matrix_product::*;
pub const GALOISGENERATOR: u64 = 5;

/// Reinterprets a mutable `u64` slice as its underlying bytes.
///
/// The returned slice borrows `data` and covers exactly the same memory
/// (`data.len() * 8` bytes, native endianness).
#[allow(dead_code)]
pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
    let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
    // Total byte size of the slice (== data.len() * size_of::<u64>()).
    let len: usize = std::mem::size_of_val(data);
    // SAFETY: ptr/len exactly cover the memory owned by `data`, u8 has
    // alignment 1, and the result reborrows `data`'s unique borrow for the
    // same lifetime, so no aliasing is introduced.
    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
let len: usize = data.len() / std::mem::size_of::<i64>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }

View File

@@ -131,6 +131,37 @@ impl VecZnx {
} }
} }
/// Encodes `value` into the `i`-th coefficient of the polynomial.
///
/// `log_max` is an upper bound (in bits) on |value|; it selects between a
/// direct write on the last limb and a base-2^log_base2k decomposition
/// across several limbs.
/// NOTE(review): limbs may hold values larger than 2^log_base2k after this
/// call — presumably the caller runs `normalize` afterwards (as the example
/// code does); confirm.
pub fn set_single_i64(&mut self, i: usize, value: i64, log_max: usize) {
    assert!(i < self.n());
    // Unused bits on the last limb. When log_q is an exact multiple of
    // log_base2k this evaluates to log_base2k (not 0), which is what the
    // `k_rem == self.log_base2k` tests below rely on.
    let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
    // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
    // values on the last limb.
    // Else we decompose values base2k.
    if log_max + k_rem < 63 || k_rem == self.log_base2k {
        // Fast path: the value (including the `<< k_rem` applied below when
        // k_rem != log_base2k) cannot overflow an i64 limb.
        self.at_mut(self.limbs() - 1)[i] = value;
    } else {
        let mask: i64 = (1 << self.log_base2k) - 1;
        let limbs = self.limbs();
        // Number of limbs actually needed for log_max bits, capped at limbs.
        let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
        // Walk the touched limbs from the last (least significant digit,
        // j == 0) to the first: j indexes the base-2^k digit of `value`,
        // j_rev the limb it lands in.
        (limbs - steps..limbs)
            .rev()
            .enumerate()
            .for_each(|(j, j_rev)| {
                self.at_mut(j_rev)[i] = (value >> (j * self.log_base2k)) & mask;
            })
    }
    // Case where self.prec % self.k != 0.
    if k_rem != self.log_base2k {
        let limbs = self.limbs();
        let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
        // Shift every touched limb up by the unused bit count so the digits
        // sit at the top of the last (partial) limb.
        (limbs - steps..limbs).rev().for_each(|j| {
            self.at_mut(j)[i] <<= k_rem;
        })
    }
}
pub fn normalize(&mut self, carry: &mut [u8]) { pub fn normalize(&mut self, carry: &mut [u8]) {
assert!( assert!(
carry.len() >= self.n * 8, carry.len() >= self.n * 8,

View File

@@ -128,7 +128,7 @@ impl Module {
limbs limbs
); );
assert!(
    tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(),
    "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_normalize_tmp_bytes()={}",
    tmp_bytes.len(),
    self.vec_znx_big_normalize_tmp_bytes()

View File

@@ -0,0 +1,163 @@
use crate::bindings::{
new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr,
};
use crate::{Module, VecZnx, VecZnxDft};
/// Prepared vector-matrix-product matrix: a raw backend handle
/// (`vmp_pmat_t`) paired with its dimensions (rows, cols).
pub struct VmpPMat(pub *mut vmp_pmat_t, pub usize, pub usize);

impl VmpPMat {
    /// Number of rows of the prepared matrix.
    pub fn rows(&self) -> usize {
        let &VmpPMat(_, rows, _) = self;
        rows
    }

    /// Number of columns of the prepared matrix.
    pub fn cols(&self) -> usize {
        let &VmpPMat(_, _, cols) = self;
        cols
    }
}
impl Module {
    /// Allocates a new backend-owned prepared matrix of dimension `rows` x `cols`.
    pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
        // SAFETY: self.0 is the backend module handle; the returned raw
        // pointer is stored alongside its dimensions in the wrapper.
        unsafe { VmpPMat(new_vmp_pmat(self.0, rows as u64, cols as u64), rows, cols) }
    }

    /// Scratch size in bytes required by [`Module::vmp_prepare_contiguous`].
    pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
        unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
    }

    /// Prepares (encodes) the contiguous coefficient matrix `a` into `b`.
    ///
    /// NOTE(review): the expected element layout of `a` (row-major vs the
    /// `Matrix3D` col*(n*rows)+row*n layout) is fixed by the C backend —
    /// confirm against its documentation.
    ///
    /// # Panics
    /// Panics when `buf` is smaller than `vmp_prepare_contiguous_tmp_bytes`.
    pub fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
        // Guard the FFI call: an undersized scratch buffer is undefined
        // behaviour in the C backend (same convention as the
        // vec_znx_big_normalize tmp_bytes check elsewhere in this crate).
        assert!(
            buf.len() >= self.vmp_prepare_contiguous_tmp_bytes(b.rows(), b.cols()),
            "invalid buf: buf.len()={} < vmp_prepare_contiguous_tmp_bytes={}",
            buf.len(),
            self.vmp_prepare_contiguous_tmp_bytes(b.rows(), b.cols())
        );
        unsafe {
            vmp_prepare_contiguous(
                self.0,
                b.0,
                a.as_ptr(),
                // Consistency: use the rows()/cols() accessors rather than the
                // raw tuple fields b.1/b.2.
                b.rows() as u64,
                b.cols() as u64,
                buf.as_mut_ptr(),
            );
        }
    }

    /// Scratch size in bytes required by [`Module::vmp_apply_dft`].
    pub fn vmp_apply_dft_tmp_bytes(
        &self,
        c_limbs: usize,
        a_limbs: usize,
        rows: usize,
        cols: usize,
    ) -> usize {
        unsafe {
            vmp_apply_dft_tmp_bytes(
                self.0,
                c_limbs as u64,
                a_limbs as u64,
                rows as u64,
                cols as u64,
            ) as usize
        }
    }

    /// Computes c = DFT(a) x b, taking a coefficient-domain vector `a`.
    ///
    /// # Panics
    /// Panics when `buf` is smaller than `vmp_apply_dft_tmp_bytes`.
    pub fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
        assert!(
            buf.len() >= self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.cols()),
            "invalid buf: buf.len()={} < vmp_apply_dft_tmp_bytes={}",
            buf.len(),
            self.vmp_apply_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.cols())
        );
        unsafe {
            vmp_apply_dft(
                self.0,
                c.0,
                c.limbs() as u64,
                a.as_ptr(),
                a.limbs() as u64,
                a.n() as u64,
                b.0,
                b.rows() as u64,
                b.cols() as u64,
                buf.as_mut_ptr(),
            )
        }
    }

    /// Scratch size in bytes required by [`Module::vmp_apply_dft_to_dft`].
    pub fn vmp_apply_dft_to_dft_tmp_bytes(
        &self,
        c_limbs: usize,
        a_limbs: usize,
        rows: usize,
        cols: usize,
    ) -> usize {
        unsafe {
            vmp_apply_dft_to_dft_tmp_bytes(
                self.0,
                c_limbs as u64,
                a_limbs as u64,
                rows as u64,
                cols as u64,
            ) as usize
        }
    }

    /// Computes c = a x b where `a` is already in the DFT domain.
    ///
    /// # Panics
    /// Panics when `buf` is smaller than `vmp_apply_dft_to_dft_tmp_bytes`.
    pub fn vmp_apply_dft_to_dft(
        &self,
        c: &mut VecZnxDft,
        a: &VecZnxDft,
        b: &VmpPMat,
        buf: &mut [u8],
    ) {
        assert!(
            buf.len()
                >= self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.cols()),
            "invalid buf: buf.len()={} < vmp_apply_dft_to_dft_tmp_bytes={}",
            buf.len(),
            self.vmp_apply_dft_to_dft_tmp_bytes(c.limbs(), a.limbs(), b.rows(), b.cols())
        );
        unsafe {
            vmp_apply_dft_to_dft(
                self.0,
                c.0,
                c.limbs() as u64,
                a.0,
                a.limbs() as u64,
                b.0,
                b.rows() as u64,
                b.cols() as u64,
                buf.as_mut_ptr(),
            )
        }
    }

    /// Computes b = b x a in place.
    ///
    /// NOTE(review): passes `b.0` as both output and input operand — this
    /// assumes the C backend supports in-place application; confirm.
    ///
    /// # Panics
    /// Panics when `buf` is smaller than `vmp_apply_dft_to_dft_tmp_bytes`.
    pub fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
        assert!(
            buf.len()
                >= self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.cols()),
            "invalid buf: buf.len()={} < vmp_apply_dft_to_dft_tmp_bytes={}",
            buf.len(),
            self.vmp_apply_dft_to_dft_tmp_bytes(b.limbs(), b.limbs(), a.rows(), a.cols())
        );
        unsafe {
            vmp_apply_dft_to_dft(
                self.0,
                b.0,
                b.limbs() as u64,
                b.0,
                b.limbs() as u64,
                a.0,
                a.rows() as u64,
                a.cols() as u64,
                buf.as_mut_ptr(),
            )
        }
    }
}
/// Dense 3-D array of shape (rows, cols, n) in one contiguous buffer.
///
/// Each (row, col) cell is a contiguous run of `n` elements; cells are laid
/// out column by column: `idx = col * (n * rows) + row * n`.
pub struct Matrix3D<T> {
    pub data: Vec<T>,
    pub rows: usize,
    pub cols: usize,
    pub n: usize,
}

impl<T: Default + Clone> Matrix3D<T> {
    /// Allocates a rows x cols matrix of n-element cells, default-initialized.
    pub fn new(rows: usize, cols: usize, n: usize) -> Self {
        let size = rows * cols * n;
        Self {
            data: vec![T::default(); size],
            rows,
            cols,
            n,
        }
    }

    /// Returns the n-element cell at (row, col).
    ///
    /// # Panics
    /// Panics when `row >= rows` or `col >= cols`.
    pub fn at(&self, row: usize, col: usize) -> &[T] {
        // Bug fix: the bounds checks must be strict. The previous `<=` checks
        // accepted row == rows / col == cols, which silently aliased a
        // neighbouring cell (or panicked on the slice) instead of failing the
        // precondition.
        assert!(row < self.rows && col < self.cols);
        let idx: usize = col * (self.n * self.rows) + row * self.n;
        &self.data[idx..idx + self.n]
    }

    /// Returns the mutable n-element cell at (row, col).
    ///
    /// # Panics
    /// Panics when `row >= rows` or `col >= cols`.
    pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
        assert!(row < self.rows && col < self.cols);
        let idx: usize = col * (self.n * self.rows) + row * self.n;
        &mut self.data[idx..idx + self.n]
    }
}