From 783a763ac92671a4b37eecd75fab0684270e3106 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Wed, 29 Jan 2025 10:32:10 +0100
Subject: [PATCH] added vmp

---
 base2k/examples/vector_matrix_product.rs |  60 +++++++++
 base2k/src/lib.rs                        |   8 +-
 base2k/src/vec_znx.rs                    |  31 +++++
 base2k/src/vec_znx_big_arithmetic.rs     |   2 +-
 base2k/src/vector_matrix_product.rs      | 163 +++++++++++++++++++++++
 5 files changed, 261 insertions(+), 3 deletions(-)
 create mode 100644 base2k/examples/vector_matrix_product.rs
 create mode 100644 base2k/src/vector_matrix_product.rs

diff --git a/base2k/examples/vector_matrix_product.rs b/base2k/examples/vector_matrix_product.rs
new file mode 100644
index 0000000..7c4bccc
--- /dev/null
+++ b/base2k/examples/vector_matrix_product.rs
@@ -0,0 +1,60 @@
+use base2k::{Matrix3D, Module, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, FFT64};
+
+fn main() {
+    let log_n = 4;
+    let n = 1 << log_n;
+
+    let module: Module = Module::new::<FFT64>(n);
+    let log_base2k: usize = 15;
+    let log_q: usize = 60;
+    let limbs: usize = (log_q + log_base2k - 1) / log_base2k;
+
+    let rows: usize = limbs + 1;
+    let cols: usize = limbs + 1;
+
+    // Maximum size of the byte scratch needed: a | b is always >= max(a, b),
+    // so one buffer of this size is large enough for both calls below.
+    let tmp_bytes: usize = module.vmp_prepare_contiguous_tmp_bytes(rows, cols)
+        | module.vmp_apply_dft_to_dft_tmp_bytes(limbs, limbs, rows, cols);
+
+    let mut buf: Vec<u8> = vec![0; tmp_bytes];
+
+    let mut a_values: Vec<i64> = vec![i64::default(); n];
+    a_values[1] = (1 << log_base2k) + 1;
+
+    let mut a: VecZnx = module.new_vec_znx(log_base2k, log_q);
+    a.set_i64(&a_values, 32);
+    a.normalize(&mut buf);
+
+    (0..a.limbs()).for_each(|i| println!("{}: {:?}", i, a.at(i)));
+
+    let mut b_mat: Matrix3D<i64> = Matrix3D::new(rows, cols, n);
+
+    // Diagonal matrix; same magic constant as a_values[1], scaled by one
+    // limb (log_base2k == 15 here, so identical to ((1 << 15) + 1) << 15).
+    (0..rows).for_each(|i| {
+        b_mat.at_mut(i, i)[0] = ((1 << log_base2k) + 1) << log_base2k;
+    });
+
+    println!();
+    (0..rows).for_each(|i| {
+        (0..cols).for_each(|j| println!("{} {}: {:?}", i, j, b_mat.at(i, j)));
+        println!();
+    });
+
+    let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
+    
module.vmp_prepare_contiguous(&mut vmp_pmat, &b_mat.data, &mut buf);
+
+    let mut c_dft: VecZnxDft = module.new_vec_znx_dft(limbs);
+    module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
+
+    let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
+    module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft, limbs);
+
+    let mut res: VecZnx = module.new_vec_znx(log_base2k, log_q);
+    module.vec_znx_big_normalize(&mut res, &c_big, &mut buf);
+
+    let mut values_res: Vec<i64> = vec![i64::default(); n];
+    res.get_i64(&mut values_res);
+
+    (0..res.limbs()).for_each(|i| println!("{}: {:?}", i, res.at(i)));
+
+    println!("{:?}", values_res)
+}
diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 4708abb..02129e6 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -37,16 +37,20 @@ pub mod scalar_vector_product;
 #[allow(unused_imports)]
 pub use scalar_vector_product::*;
 
+pub mod vector_matrix_product;
+#[allow(unused_imports)]
+pub use vector_matrix_product::*;
+
 pub const GALOISGENERATOR: u64 = 5;
 
 #[allow(dead_code)]
-fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
+pub fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
     let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
     let len: usize = data.len() * std::mem::size_of::<u64>();
     unsafe { std::slice::from_raw_parts_mut(ptr, len) }
 }
 
-fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
+pub fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
     let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
     let len: usize = data.len() / std::mem::size_of::<i64>();
     unsafe { std::slice::from_raw_parts_mut(ptr, len) }
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index d8414c9..d2413b0 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -131,6 +131,37 @@ impl VecZnx {
         }
     }
 
+    pub fn set_single_i64(&mut self, i: usize, value: i64, log_max: usize) {
+        assert!(i < self.n());
+        let k_rem: usize = self.log_base2k - (self.log_q % self.log_base2k);
+
+        // If 2^{log_base2k} * 2^{k_rem} < 
2^{63}-1, then we can simply copy
+        // values on the last limb.
+        // Else we decompose values base2k.
+        if log_max + k_rem < 63 || k_rem == self.log_base2k {
+            self.at_mut(self.limbs() - 1)[i] = value;
+        } else {
+            let mask: i64 = (1 << self.log_base2k) - 1;
+            let limbs = self.limbs();
+            let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
+            (limbs - steps..limbs)
+                .rev()
+                .enumerate()
+                .for_each(|(j, j_rev)| {
+                    self.at_mut(j_rev)[i] = (value >> (j * self.log_base2k)) & mask;
+                })
+        }
+
+        // Case where self.prec % self.k != 0.
+        if k_rem != self.log_base2k {
+            let limbs = self.limbs();
+            let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
+            (limbs - steps..limbs).rev().for_each(|j| {
+                self.at_mut(j)[i] <<= k_rem;
+            })
+        }
+    }
+
     pub fn normalize(&mut self, carry: &mut [u8]) {
         assert!(
             carry.len() >= self.n * 8,
diff --git a/base2k/src/vec_znx_big_arithmetic.rs b/base2k/src/vec_znx_big_arithmetic.rs
index 0db64a4..dcca5e3 100644
--- a/base2k/src/vec_znx_big_arithmetic.rs
+++ b/base2k/src/vec_znx_big_arithmetic.rs
@@ -128,7 +128,7 @@ impl Module {
             limbs
         );
         assert!(
-            tmp_bytes.len() <= self.vec_znx_big_normalize_tmp_bytes(),
-            "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
+            tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(),
+            "invalid tmp_bytes: tmp_bytes.len()={} >= self.vec_znx_big_normalize_tmp_bytes()={}",
             tmp_bytes.len(),
             self.vec_znx_big_normalize_tmp_bytes()
diff --git a/base2k/src/vector_matrix_product.rs b/base2k/src/vector_matrix_product.rs
new file mode 100644
index 0000000..9b0f89b
--- /dev/null
+++ b/base2k/src/vector_matrix_product.rs
@@ -0,0 +1,163 @@
+use crate::bindings::{
+    new_vmp_pmat, vmp_apply_dft, vmp_apply_dft_tmp_bytes, vmp_apply_dft_to_dft,
+    vmp_apply_dft_to_dft_tmp_bytes, vmp_pmat_t, vmp_prepare_contiguous,
+    vmp_prepare_contiguous_tmp_bytes, vmp_prepare_dblptr,
+};
+use crate::{Module, VecZnx, VecZnxDft};
+
+pub struct VmpPMat(pub *mut vmp_pmat_t, pub usize, pub usize);
+
+impl VmpPMat {
+    pub fn rows(&self) -> usize 
{
+        self.1
+    }
+
+    pub fn cols(&self) -> usize {
+        self.2
+    }
+}
+
+impl Module {
+    /// Allocates a new backend-owned prepared matrix of `rows` x `cols`.
+    /// NOTE(review): `VmpPMat` has no `Drop`, so the handle returned by
+    /// `new_vmp_pmat` is never freed — confirm whether the bindings expose
+    /// a matching delete function.
+    pub fn new_vmp_pmat(&self, rows: usize, cols: usize) -> VmpPMat {
+        unsafe { VmpPMat(new_vmp_pmat(self.0, rows as u64, cols as u64), rows, cols) }
+    }
+
+    /// Scratch-buffer size (bytes) required by `vmp_prepare_contiguous`.
+    pub fn vmp_prepare_contiguous_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
+        unsafe { vmp_prepare_contiguous_tmp_bytes(self.0, rows as u64, cols as u64) as usize }
+    }
+
+    /// Prepares the contiguous coefficients `a` into `b` (rows/cols taken
+    /// from `b`); `buf` is scratch space.
+    pub fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]) {
+        unsafe {
+            vmp_prepare_contiguous(
+                self.0,
+                b.0,
+                a.as_ptr(),
+                b.1 as u64,
+                b.2 as u64,
+                buf.as_mut_ptr(),
+            );
+        }
+    }
+
+    /// Scratch-buffer size (bytes) required by `vmp_apply_dft`.
+    pub fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_limbs: usize,
+        a_limbs: usize,
+        rows: usize,
+        cols: usize,
+    ) -> usize {
+        unsafe {
+            vmp_apply_dft_tmp_bytes(
+                self.0,
+                c_limbs as u64,
+                a_limbs as u64,
+                rows as u64,
+                cols as u64,
+            ) as usize
+        }
+    }
+
+    /// Vector-matrix product with `a` in the coefficient domain:
+    /// the backend DFTs `a` and multiplies by the prepared matrix `b`.
+    pub fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]) {
+        unsafe {
+            vmp_apply_dft(
+                self.0,
+                c.0,
+                c.limbs() as u64,
+                a.as_ptr(),
+                a.limbs() as u64,
+                a.n() as u64,
+                b.0,
+                b.rows() as u64,
+                b.cols() as u64,
+                buf.as_mut_ptr(),
+            )
+        }
+    }
+
+    /// Scratch-buffer size (bytes) required by `vmp_apply_dft_to_dft`.
+    pub fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_limbs: usize,
+        a_limbs: usize,
+        rows: usize,
+        cols: usize,
+    ) -> usize {
+        unsafe {
+            vmp_apply_dft_to_dft_tmp_bytes(
+                self.0,
+                c_limbs as u64,
+                a_limbs as u64,
+                rows as u64,
+                cols as u64,
+            ) as usize
+        }
+    }
+
+    /// Vector-matrix product with `a` already in the DFT domain.
+    pub fn vmp_apply_dft_to_dft(
+        &self,
+        c: &mut VecZnxDft,
+        a: &VecZnxDft,
+        b: &VmpPMat,
+        buf: &mut [u8],
+    ) {
+        unsafe {
+            vmp_apply_dft_to_dft(
+                self.0,
+                c.0,
+                c.limbs() as u64,
+                a.0,
+                a.limbs() as u64,
+                b.0,
+                b.rows() as u64,
+                b.cols() as u64,
+                buf.as_mut_ptr(),
+            )
+        }
+    }
+
+    /// In-place variant: `b` is both the input vector and the result.
+    /// NOTE(review): this passes `b.0` as both res and input pointers —
+    /// assumes the backend supports aliased operands; confirm.
+    pub fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]) {
+        unsafe {
+            vmp_apply_dft_to_dft(
+                self.0,
+                b.0,
+                b.limbs() as u64,
+                b.0,
+                b.limbs() as u64,
+                a.0,
+                a.rows() as u64,
+                a.cols() as u64,
+                
buf.as_mut_ptr(),
+            )
+        }
+    }
+}
+
+/// Dense (rows x cols) matrix of degree-n polynomial coefficient vectors,
+/// stored contiguously column-major: all `rows` polynomials of column 0,
+/// then column 1, etc. (see the index formula in `at`).
+pub struct Matrix3D<T> {
+    pub data: Vec<T>,
+    pub rows: usize,
+    pub cols: usize,
+    pub n: usize,
+}
+
+impl<T: Default + Clone> Matrix3D<T> {
+    /// Allocates a rows*cols*n matrix filled with `T::default()`.
+    pub fn new(rows: usize, cols: usize, n: usize) -> Self {
+        let size = rows * cols * n;
+        Self {
+            data: vec![T::default(); size],
+            rows,
+            cols,
+            n,
+        }
+    }
+
+    /// Borrows the n coefficients at (row, col).
+    pub fn at(&self, row: usize, col: usize) -> &[T] {
+        // Strict `<`: row == rows or col == cols would silently read the
+        // next column's data, or panic slicing past the end of `data`.
+        assert!(row < self.rows && col < self.cols);
+        let idx: usize = col * (self.n * self.rows) + row * self.n;
+        &self.data[idx..idx + self.n]
+    }
+
+    /// Mutably borrows the n coefficients at (row, col).
+    pub fn at_mut(&mut self, row: usize, col: usize) -> &mut [T] {
+        assert!(row < self.rows && col < self.cols);
+        let idx: usize = col * (self.n * self.rows) + row * self.n;
+        &mut self.data[idx..idx + self.n]
+    }
+}