Added tensor key & associated test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-19 18:06:14 +02:00
parent c5fe07188f
commit 8f2eac4928
12 changed files with 610 additions and 28 deletions

View File

@@ -9,9 +9,9 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
pub struct ScalarZnx<D> {
data: D,
n: usize,
cols: usize,
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
}
impl<D> ZnxInfos for ScalarZnx<D> {

View File

@@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::ffi::svp;
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64};
pub struct ScalarZnxDft<D, B: Backend> {
data: D,
@@ -92,6 +92,16 @@ impl<D, B: Backend> ScalarZnxDft<D, B> {
_phantom: PhantomData,
}
}
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B>{
VecZnxDft{
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
@@ -158,3 +168,63 @@ impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}

View File

@@ -1,5 +1,6 @@
use crate::DataView;
use crate::DataViewMut;
use crate::ScalarZnx;
use crate::ZnxSliceSize;
use crate::ZnxZero;
use crate::alloc_aligned;
@@ -128,6 +129,15 @@ impl<D> VecZnx<D> {
size,
}
}
pub fn to_scalar_znx(self) -> ScalarZnx<D>{
debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols);
ScalarZnx{
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
/// Copies the coefficients of `a` on the receiver.

View File

@@ -8,11 +8,11 @@ use crate::{
use std::fmt;
pub struct VecZnxDft<D, B: Backend> {
data: D,
n: usize,
cols: usize,
size: usize,
_phantom: PhantomData<B>,
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) _phantom: PhantomData<B>,
}
impl<D, B: Backend> VecZnxDft<D, B> {