This commit is contained in:
Jean-Philippe Bossuat
2025-04-25 11:04:17 +02:00
parent 79eee00974
commit 0cca56755b
7 changed files with 130 additions and 128 deletions

View File

@@ -5,27 +5,25 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each row of the [VmpPMat] can be seen as a [VecZnxDft].
/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft].
///
/// The backend array of [VmpPMat] is allocate in C,
/// and thus must be manually freed.
///
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat].
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat].
/// See the trait [VmpPMatOps] for additional information.
pub struct VmpPMat {
/// Raw data, is empty if borrowing scratch space.
data: Vec<u8>,
/// Pointer to data. Can point to scratch space.
ptr: *mut u8,
/// The number of [VecZnxDft].
/// The size of the decomposition basis (i.e. nb. [VecZnxDft]).
rows: usize,
/// The number of cols in each [VecZnxDft].
/// The size of each [VecZnxDft].
cols: usize,
/// The ring degree of each [VecZnxDft].
n: usize,
/// The number of stacked [VmpPMat], must be a square.
size: usize,
/// The memory layout of the stacked [VmpPMat].
/// 1nd dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension).
/// A value greater than one enables to compute a sum of [VecZnx] x [VmpPMat].
/// 2st dim: the number of stacked [VecZnxDft] (col-dimension).
/// A value greater than one enables to compute multiple [VecZnx] x [VmpPMat] in parallel.
layout: LAYOUT,
/// The backend fft or ntt.
backend: BACKEND,
@@ -531,6 +529,7 @@ impl VmpPMatOps for Module {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(a.size()*a.size(), b.size());
}
unsafe {
vmp::vmp_apply_dft(
@@ -539,7 +538,7 @@ impl VmpPMatOps for Module {
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
(a.n()*a.size()) as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
@@ -561,7 +560,7 @@ impl VmpPMatOps for Module {
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
(a.n()*a.size()) as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,