mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip
This commit is contained in:
@@ -5,27 +5,25 @@ use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each row of the [VmpPMat] can be seen as a [VecZnxDft].
|
||||
/// Each col of the [VmpPMat] can be seen as a collection of [VecZnxDft].
|
||||
///
|
||||
/// The backend array of [VmpPMat] is allocate in C,
|
||||
/// and thus must be manually freed.
|
||||
///
|
||||
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat].
|
||||
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [VmpPMat].
|
||||
/// See the trait [VmpPMatOps] for additional information.
|
||||
pub struct VmpPMat {
|
||||
/// Raw data, is empty if borrowing scratch space.
|
||||
data: Vec<u8>,
|
||||
/// Pointer to data. Can point to scratch space.
|
||||
ptr: *mut u8,
|
||||
/// The number of [VecZnxDft].
|
||||
/// The size of the decomposition basis (i.e. nb. [VecZnxDft]).
|
||||
rows: usize,
|
||||
/// The number of cols in each [VecZnxDft].
|
||||
/// The size of each [VecZnxDft].
|
||||
cols: usize,
|
||||
/// The ring degree of each [VecZnxDft].
|
||||
n: usize,
|
||||
/// The number of stacked [VmpPMat], must be a square.
|
||||
size: usize,
|
||||
/// The memory layout of the stacked [VmpPMat].
|
||||
/// 1nd dim: the number of stacked [VecZnxDft] per decomposition basis (row-dimension).
|
||||
/// A value greater than one enables to compute a sum of [VecZnx] x [VmpPMat].
|
||||
/// 2st dim: the number of stacked [VecZnxDft] (col-dimension).
|
||||
/// A value greater than one enables to compute multiple [VecZnx] x [VmpPMat] in parallel.
|
||||
layout: LAYOUT,
|
||||
/// The backend fft or ntt.
|
||||
backend: BACKEND,
|
||||
@@ -531,6 +529,7 @@ impl VmpPMatOps for Module {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
assert_eq!(a.size()*a.size(), b.size());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
@@ -539,7 +538,7 @@ impl VmpPMatOps for Module {
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
(a.n()*a.size()) as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
@@ -561,7 +560,7 @@ impl VmpPMatOps for Module {
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
(a.n()*a.size()) as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
|
||||
Reference in New Issue
Block a user