mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added basic key-switching + file formatting
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp::{self, vmp_pmat_t};
|
||||
use crate::{
|
||||
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
|
||||
};
|
||||
use crate::{BACKEND, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
@@ -100,8 +98,7 @@ impl VmpPMat {
|
||||
|
||||
if self.n < 8 {
|
||||
res.copy_from_slice(
|
||||
&self.raw::<T>()[(row + col * self.rows()) * self.n()
|
||||
..(row + col * self.rows()) * (self.n() + 1)],
|
||||
&self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)],
|
||||
);
|
||||
} else {
|
||||
(0..self.n >> 3).for_each(|blk| {
|
||||
@@ -120,10 +117,7 @@ impl VmpPMat {
|
||||
if col == (ncols - 1) && (ncols & 1 == 1) {
|
||||
&self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
&self.raw::<T>()[blk * nrows * ncols * 8
|
||||
+ (col / 2) * (2 * nrows) * 8
|
||||
+ row * 2 * 8
|
||||
+ (col % 2) * 8..]
|
||||
&self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,13 +214,7 @@ pub trait VmpPMatOps {
|
||||
/// * `a_cols`: number of cols of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize;
|
||||
fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
///
|
||||
@@ -288,13 +276,7 @@ pub trait VmpPMatOps {
|
||||
/// * `a_cols`: number of cols of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
c_cols: usize,
|
||||
a_cols: usize,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
) -> usize;
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
@@ -348,13 +330,7 @@ pub trait VmpPMatOps {
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft_add(
|
||||
&self,
|
||||
c: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
b: &VmpPMat,
|
||||
buf: &mut [u8],
|
||||
);
|
||||
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
@@ -521,13 +497,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(
|
||||
&self,
|
||||
res_cols: usize,
|
||||
a_cols: usize,
|
||||
gct_rows: usize,
|
||||
gct_cols: usize,
|
||||
) -> usize {
|
||||
fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
@@ -540,9 +510,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -564,9 +532,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -587,13 +553,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(
|
||||
&self,
|
||||
res_cols: usize,
|
||||
a_cols: usize,
|
||||
gct_rows: usize,
|
||||
gct_cols: usize,
|
||||
) -> usize {
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
@@ -605,17 +565,8 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(
|
||||
&self,
|
||||
c: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
b: &VmpPMat,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -635,17 +586,8 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_add(
|
||||
&self,
|
||||
c: &mut VecZnxDft,
|
||||
a: &VecZnxDft,
|
||||
b: &VmpPMat,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
|
||||
);
|
||||
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -666,10 +608,7 @@ impl VmpPMatOps for Module {
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len()
|
||||
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
|
||||
);
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
@@ -693,8 +632,7 @@ impl VmpPMatOps for Module {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
|
||||
VecZnxOps, VmpPMat, VmpPMatOps,
|
||||
Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -712,8 +650,7 @@ mod tests {
|
||||
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||
|
||||
for row_i in 0..vpmat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
|
||||
Reference in New Issue
Block a user