Added basic key-switching + file formatting

This commit is contained in:
Jean-Philippe Bossuat
2025-04-24 10:43:51 +02:00
parent 4196477300
commit ad6e8169e5
33 changed files with 319 additions and 715 deletions

View File

@@ -1,9 +1,7 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{
alloc_aligned, assert_alignement, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, BACKEND,
};
use crate::{BACKEND, Infos, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
@@ -100,8 +98,7 @@ impl VmpPMat {
if self.n < 8 {
res.copy_from_slice(
&self.raw::<T>()[(row + col * self.rows()) * self.n()
..(row + col * self.rows()) * (self.n() + 1)],
&self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)],
);
} else {
(0..self.n >> 3).for_each(|blk| {
@@ -120,10 +117,7 @@ impl VmpPMat {
if col == (ncols - 1) && (ncols & 1 == 1) {
&self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
} else {
&self.raw::<T>()[blk * nrows * ncols * 8
+ (col / 2) * (2 * nrows) * 8
+ row * 2 * 8
+ (col % 2) * 8..]
&self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
}
}
}
@@ -220,13 +214,7 @@ pub trait VmpPMatOps {
/// * `a_cols`: number of cols of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_tmp_bytes(
&self,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
///
@@ -288,13 +276,7 @@ pub trait VmpPMatOps {
/// * `a_cols`: number of cols of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
c_cols: usize,
a_cols: usize,
rows: usize,
cols: usize,
) -> usize;
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -348,13 +330,7 @@ pub trait VmpPMatOps {
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(
&self,
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
buf: &mut [u8],
);
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -521,13 +497,7 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_tmp_bytes(
&self,
res_cols: usize,
a_cols: usize,
gct_rows: usize,
gct_cols: usize,
) -> usize {
fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
@@ -540,9 +510,7 @@ impl VmpPMatOps for Module {
}
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
@@ -564,9 +532,7 @@ impl VmpPMatOps for Module {
}
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
@@ -587,13 +553,7 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
res_cols: usize,
a_cols: usize,
gct_rows: usize,
gct_cols: usize,
) -> usize {
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
@@ -605,17 +565,8 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_to_dft(
&self,
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
@@ -635,17 +586,8 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_to_dft_add(
&self,
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
@@ -666,10 +608,7 @@ impl VmpPMatOps for Module {
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols())
);
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
@@ -693,8 +632,7 @@ impl VmpPMatOps for Module {
#[cfg(test)]
mod tests {
use crate::{
alloc_aligned, Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps,
VecZnxOps, VmpPMat, VmpPMatOps,
Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
};
use sampling::source::Source;
@@ -712,8 +650,7 @@ mod tests {
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(vpmat_rows, vpmat_cols);
let mut tmp_bytes: Vec<u8> =
alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);