Add bivariate convolution

This commit is contained in:
Jean-Philippe Bossuat
2025-10-23 19:00:26 +02:00
parent 9bb6256fc4
commit af1c98c2c4
18 changed files with 454 additions and 26 deletions

View File

@@ -4,6 +4,7 @@ use std::{
ptr::NonNull,
};
use bytemuck::Pod;
use rand_distr::num_traits::Zero;
use crate::{
@@ -13,8 +14,8 @@ use crate::{
#[allow(clippy::missing_safety_doc)]
pub trait Backend: Sized {
type ScalarBig: Copy + Zero + Display + Debug;
type ScalarPrep: Copy + Zero + Display + Debug;
type ScalarBig: Copy + Zero + Display + Debug + Pod;
type ScalarPrep: Copy + Zero + Display + Debug + Pod;
type Handle: 'static;
fn layout_prep_word_count() -> usize;
fn layout_big_word_count() -> usize;

View File

@@ -6,8 +6,8 @@ use std::{
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ScalarZnx, ToOwnedDeep, WriterTo,
ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
@@ -25,6 +25,26 @@ pub struct VecZnx<D: Data> {
pub max_size: usize,
}
impl<D: DataRef> VecZnx<D> {
pub fn as_scalar_znx_ref(&self, col: usize, limb: usize) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: bytemuck::cast_slice(self.at(col, limb)),
n: self.n,
cols: 1,
}
}
}
impl<D: DataMut> VecZnx<D> {
pub fn as_scalar_znx_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
n: self.n,
cols: 1,
data: bytemuck::cast_slice_mut(self.at_mut(col, limb)),
}
}
}
impl<D: Data + Default> Default for VecZnx<D> {
fn default() -> Self {
Self {

View File

@@ -4,6 +4,7 @@ use crate::{
layouts::{Backend, Data, DataMut, DataRef},
source::Source,
};
use bytemuck::Pod;
use rand_distr::num_traits::Zero;
pub trait ZnxInfos {
@@ -50,7 +51,7 @@ pub trait DataViewMut: DataView {
}
pub trait ZnxView: ZnxInfos + DataView<D: DataRef> {
type Scalar: Copy + Zero + Display + Debug;
type Scalar: Copy + Zero + Display + Debug + Pod;
/// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar {