Added more serialization tests + generalize methods to any n

This commit is contained in:
Pro7ech
2025-08-13 15:28:52 +02:00
parent 068470783e
commit 940742ce6c
117 changed files with 3658 additions and 2577 deletions

View File

@@ -1,10 +1,10 @@
use backend::{
hal::{
api::{
ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx,
VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64,
VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
},
layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
},
@@ -27,17 +27,18 @@ fn main() {
let mut source: Source = Source::new(seed);
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
let mut s: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(n, s.cols());
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, 0, &s, 0);
// Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx<Vec<u8>> = module.vec_znx_alloc(
let mut ct: VecZnx<Vec<u8>> = VecZnx::alloc(
module.n(),
2, // Number of columns
ct_size, // Number of small poly per column
);
@@ -45,7 +46,7 @@ fn main() {
// Fill the second column with random values: ct = (0, a)
module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(n, 1, ct_size);
module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);
@@ -60,11 +61,12 @@ fn main() {
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(n, 1, ct_size);
module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column
let mut m = module.vec_znx_alloc(
let mut m = VecZnx::alloc(
module.n(),
1, // Number of columns
msg_size, // Number of small polynomials
);
@@ -125,7 +127,7 @@ fn main() {
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
// m + e <- BIG(ct[1] * s + ct[0])
let mut res = module.vec_znx_alloc(1, ct_size);
let mut res = VecZnx::alloc(module.n(), 1, ct_size);
module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());
// have = m * 2^{log_scale} + e

View File

@@ -1,17 +0,0 @@
use crate::hal::layouts::MatZnxOwned;
/// Allocates as [crate::hal::layouts::MatZnx].
pub trait MatZnxAlloc {
fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx].
pub trait MatZnxAllocBytes {
fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx].
/// User must ensure that bytes is memory aligned and that it length is equal to [MatZnxAllocBytes].
pub trait MatZnxFromBytes {
fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned;
}

View File

@@ -1,6 +1,4 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod svp_ppol;
mod vec_znx;
@@ -9,9 +7,7 @@ mod vec_znx_dft;
mod vmp_pmat;
mod znx_base;
pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;

View File

@@ -1,17 +0,0 @@
use crate::hal::layouts::ScalarZnxOwned;
/// Allocates as [crate::hal::layouts::ScalarZnx].
pub trait ScalarZnxAlloc {
fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx].
pub trait ScalarZnxAllocBytes {
fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::hal::layouts::ScalarZnx].
/// User must ensure that bytes is memory aligned and that it length is equal to [ScalarZnxAllocBytes].
pub trait ScalarZnxFromBytes {
fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}

View File

@@ -1,4 +1,4 @@
use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
use crate::hal::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
/// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
pub trait ScratchOwnedAlloc<B: Backend> {
@@ -27,44 +27,38 @@ pub trait TakeSlice {
/// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeScalarZnx<B: Backend> {
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
pub trait TakeScalarZnx {
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeSvpPPol<B: Backend> {
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnx<B: Backend> {
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
pub trait TakeVecZnx {
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] aand returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxSlice<B: Backend> {
fn take_vec_znx_slice(
&mut self,
len: usize,
module: &Module<B>,
cols: usize,
size: usize,
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
pub trait TakeVecZnxSlice {
fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxBig<B: Backend> {
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeVecZnxDft<B: Backend> {
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
}
/// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it
@@ -73,7 +67,7 @@ pub trait TakeVecZnxDftSlice<B: Backend> {
fn take_vec_znx_dft_slice(
&mut self,
len: usize,
module: &Module<B>,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
@@ -84,7 +78,7 @@ pub trait TakeVecZnxDftSlice<B: Backend> {
pub trait TakeVmpPMat<B: Backend> {
fn take_vmp_pmat(
&mut self,
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
@@ -94,10 +88,10 @@ pub trait TakeVmpPMat<B: Backend> {
/// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it
/// as well as a new [Scratch] minus the taken array of bytes.
pub trait TakeMatZnx<B: Backend> {
pub trait TakeMatZnx {
fn take_mat_znx(
&mut self,
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,

View File

@@ -2,18 +2,18 @@ use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, S
/// Allocates as [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
pub trait SvpPPolAllocBytes {
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx].
/// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes].
pub trait SvpPPolFromBytes<B: Backend> {
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}
/// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].

View File

@@ -2,33 +2,7 @@ use rand_distr::Distribution;
use rug::Float;
use sampling::source::Source;
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
pub trait VecZnxAlloc {
/// Allocates a new [crate::hal::layouts::VecZnx].
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column.
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned;
}
pub trait VecZnxFromBytes {
/// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes.
/// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes.
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column.
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
}
pub trait VecZnxAllocBytes {
/// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx].
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize;
}
use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
pub trait VecZnxNormalizeTmpBytes {
/// Returns the minimum number of bytes necessary for normalization.

View File

@@ -5,18 +5,18 @@ use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZ
/// Allocates as [crate::hal::layouts::VecZnxBig].
pub trait VecZnxBigAlloc<B: Backend> {
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
pub trait VecZnxBigAllocBytes {
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig].
/// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes].
pub trait VecZnxBigFromBytes<B: Backend> {
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
}
/// Add a discrete normal distribution on res.

View File

@@ -3,19 +3,19 @@ use crate::hal::layouts::{
};
pub trait VecZnxDftAlloc<B: Backend> {
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftFromBytes<B: Backend> {
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftAllocBytes {
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftToVecZnxBigTmpBytes {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize;
}
pub trait VecZnxDftToVecZnxBig<B: Backend> {

View File

@@ -3,19 +3,27 @@ use crate::hal::layouts::{
};
pub trait VmpPMatAlloc<B: Backend> {
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
}
pub trait VmpPMatAllocBytes {
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPMatFromBytes<B: Backend> {
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B>;
}
pub trait VmpPrepareTmpBytes {
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPMatPrepare<B: Backend> {
@@ -28,6 +36,7 @@ pub trait VmpPMatPrepare<B: Backend> {
pub trait VmpApplyTmpBytes {
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -72,6 +81,7 @@ pub trait VmpApply<B: Backend> {
pub trait VmpApplyAddTmpBytes {
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,

View File

@@ -113,3 +113,7 @@ where
pub trait FillUniform {
fn fill_uniform(&mut self, source: &mut Source);
}
pub trait Reset {
fn reset(&mut self);
}

View File

@@ -1,32 +0,0 @@
use crate::hal::{
api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes},
layouts::{Backend, MatZnxOwned, Module},
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
};
impl<B> MatZnxAlloc for Module<B>
where
B: Backend + MatZnxAllocImpl<B>,
{
fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size)
}
}
impl<B> MatZnxAllocBytes for Module<B>
where
B: Backend + MatZnxAllocBytesImpl<B>,
{
fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size)
}
}
impl<B> MatZnxFromBytes for Module<B>
where
B: Backend + MatZnxFromBytesImpl<B>,
{
fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned {
B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes)
}
}

View File

@@ -1,6 +1,4 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod svp_ppol;
mod vec_znx;

View File

@@ -1,23 +0,0 @@
use crate::hal::{
api::{ScalarZnxAlloc, ScalarZnxAllocBytes},
layouts::{Backend, Module, ScalarZnxOwned},
oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl},
};
impl<B> ScalarZnxAllocBytes for Module<B>
where
B: Backend + ScalarZnxAllocBytesImpl<B>,
{
fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize {
B::scalar_znx_alloc_bytes_impl(self.n(), cols)
}
}
impl<B> ScalarZnxAlloc for Module<B>
where
B: Backend + ScalarZnxAllocImpl<B>,
{
fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned {
B::scalar_znx_alloc_impl(self.n(), cols)
}
}

View File

@@ -3,9 +3,7 @@ use crate::hal::{
ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
},
layouts::{
Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat,
},
layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
oep::{
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
@@ -58,12 +56,12 @@ where
}
}
impl<B> TakeScalarZnx<B> for Scratch<B>
impl<B> TakeScalarZnx for Scratch<B>
where
B: Backend + TakeScalarZnxImpl<B>,
{
fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
B::take_scalar_znx_impl(self, module.n(), cols)
fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
B::take_scalar_znx_impl(self, n, cols)
}
}
@@ -71,32 +69,26 @@ impl<B> TakeSvpPPol<B> for Scratch<B>
where
B: Backend + TakeSvpPPolImpl<B>,
{
fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
B::take_svp_ppol_impl(self, module.n(), cols)
fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
B::take_svp_ppol_impl(self, n, cols)
}
}
impl<B> TakeVecZnx<B> for Scratch<B>
impl<B> TakeVecZnx for Scratch<B>
where
B: Backend + TakeVecZnxImpl<B>,
{
fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
B::take_vec_znx_impl(self, module.n(), cols, size)
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
B::take_vec_znx_impl(self, n, cols, size)
}
}
impl<B> TakeVecZnxSlice<B> for Scratch<B>
impl<B> TakeVecZnxSlice for Scratch<B>
where
B: Backend + TakeVecZnxSliceImpl<B>,
{
fn take_vec_znx_slice(
&mut self,
len: usize,
module: &Module<B>,
cols: usize,
size: usize,
) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
B::take_vec_znx_slice_impl(self, len, module.n(), cols, size)
fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
B::take_vec_znx_slice_impl(self, len, n, cols, size)
}
}
@@ -104,8 +96,8 @@ impl<B> TakeVecZnxBig<B> for Scratch<B>
where
B: Backend + TakeVecZnxBigImpl<B>,
{
fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
B::take_vec_znx_big_impl(self, module.n(), cols, size)
fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
B::take_vec_znx_big_impl(self, n, cols, size)
}
}
@@ -113,8 +105,8 @@ impl<B> TakeVecZnxDft<B> for Scratch<B>
where
B: Backend + TakeVecZnxDftImpl<B>,
{
fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
B::take_vec_znx_dft_impl(self, module.n(), cols, size)
fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
B::take_vec_znx_dft_impl(self, n, cols, size)
}
}
@@ -125,11 +117,11 @@ where
fn take_vec_znx_dft_slice(
&mut self,
len: usize,
module: &Module<B>,
n: usize,
cols: usize,
size: usize,
) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size)
B::take_vec_znx_dft_slice_impl(self, len, n, cols, size)
}
}
@@ -139,29 +131,29 @@ where
{
fn take_vmp_pmat(
&mut self,
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (VmpPMat<&mut [u8], B>, &mut Self) {
B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size)
B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size)
}
}
impl<B> TakeMatZnx<B> for Scratch<B>
impl<B> TakeMatZnx for Scratch<B>
where
B: Backend + TakeMatZnxImpl<B>,
{
fn take_mat_znx(
&mut self,
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> (MatZnx<&mut [u8]>, &mut Self) {
B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size)
B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size)
}
}

View File

@@ -8,8 +8,8 @@ impl<B> SvpPPolFromBytes<B> for Module<B>
where
B: Backend + SvpPPolFromBytesImpl<B>,
{
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
B::svp_ppol_from_bytes_impl(n, cols, bytes)
}
}
@@ -17,8 +17,8 @@ impl<B> SvpPPolAlloc<B> for Module<B>
where
B: Backend + SvpPPolAllocImpl<B>,
{
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
B::svp_ppol_alloc_impl(self.n(), cols)
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
B::svp_ppol_alloc_impl(n, cols)
}
}
@@ -26,8 +26,8 @@ impl<B> SvpPPolAllocBytes for Module<B>
where
B: Backend + SvpPPolAllocBytesImpl<B>,
{
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
B::svp_ppol_alloc_bytes_impl(n, cols)
}
}

View File

@@ -2,54 +2,26 @@ use sampling::source::Source;
use crate::hal::{
api::{
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes,
VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat,
VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform,
VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate,
VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace,
VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace,
VecZnxSwithcDegree,
VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, VecZnxDecodeVeci64,
VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef},
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
};
impl<B> VecZnxAlloc for Module<B>
where
B: Backend + VecZnxAllocImpl<B>,
{
fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned {
B::vec_znx_alloc_impl(self.n(), cols, size)
}
}
impl<B> VecZnxFromBytes for Module<B>
where
B: Backend + VecZnxFromBytesImpl<B>,
{
fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes)
}
}
impl<B> VecZnxAllocBytes for Module<B>
where
B: Backend + VecZnxAllocBytesImpl<B>,
{
fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize {
B::vec_znx_alloc_bytes_impl(self.n(), cols, size)
}
}
impl<B> VecZnxNormalizeTmpBytes for Module<B>
where
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,

View File

@@ -24,8 +24,8 @@ impl<B> VecZnxBigAlloc<B> for Module<B>
where
B: Backend + VecZnxBigAllocImpl<B>,
{
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
B::vec_znx_big_alloc_impl(self.n(), cols, size)
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
B::vec_znx_big_alloc_impl(n, cols, size)
}
}
@@ -33,8 +33,8 @@ impl<B> VecZnxBigFromBytes<B> for Module<B>
where
B: Backend + VecZnxBigFromBytesImpl<B>,
{
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
}
}
@@ -42,8 +42,8 @@ impl<B> VecZnxBigAllocBytes for Module<B>
where
B: Backend + VecZnxBigAllocBytesImpl<B>,
{
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_big_alloc_bytes_impl(n, cols, size)
}
}

View File

@@ -20,8 +20,8 @@ impl<B> VecZnxDftFromBytes<B> for Module<B>
where
B: Backend + VecZnxDftFromBytesImpl<B>,
{
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
}
}
@@ -29,8 +29,8 @@ impl<B> VecZnxDftAllocBytes for Module<B>
where
B: Backend + VecZnxDftAllocBytesImpl<B>,
{
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
}
}
@@ -38,8 +38,8 @@ impl<B> VecZnxDftAlloc<B> for Module<B>
where
B: Backend + VecZnxDftAllocImpl<B>,
{
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
B::vec_znx_dft_alloc_impl(self.n(), cols, size)
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
B::vec_znx_dft_alloc_impl(n, cols, size)
}
}
@@ -47,8 +47,8 @@ impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
}
}

View File

@@ -14,8 +14,8 @@ impl<B> VmpPMatAlloc<B> for Module<B>
where
B: Backend + VmpPMatAllocImpl<B>,
{
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
}
}
@@ -23,8 +23,8 @@ impl<B> VmpPMatAllocBytes for Module<B>
where
B: Backend + VmpPMatAllocBytesImpl<B>,
{
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
}
}
@@ -32,8 +32,16 @@ impl<B> VmpPMatFromBytes<B> for Module<B>
where
B: Backend + VmpPMatFromBytesImpl<B>,
{
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B> {
B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
}
}
@@ -41,8 +49,8 @@ impl<B> VmpPrepareTmpBytes for Module<B>
where
B: Backend + VmpPrepareTmpBytesImpl<B>,
{
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
}
}
@@ -65,6 +73,7 @@ where
{
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -73,7 +82,7 @@ where
b_size: usize,
) -> usize {
B::vmp_apply_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
@@ -98,6 +107,7 @@ where
{
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -106,7 +116,7 @@ where
b_size: usize,
) -> usize {
B::vmp_apply_add_tmp_bytes_impl(
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}

View File

@@ -1,7 +1,7 @@
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
},
};
@@ -78,15 +78,13 @@ impl<D: Data> MatZnx<D> {
}
}
impl<D: DataRef> MatZnx<D> {
pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size)
impl MatZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size)
}
}
impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
@@ -97,16 +95,9 @@ impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
}
}
pub(crate) fn from_bytes(
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n,
@@ -127,7 +118,7 @@ impl<D: DataRef> MatZnx<D> {
}
let self_ref: MatZnx<&[u8]> = self.to_ref();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(self.n, self.cols_out, self.size);
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(self.n, self.cols_out, self.size);
let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
let end: usize = start + nb_bytes;
@@ -155,7 +146,7 @@ impl<D: DataMut> MatZnx<D> {
let size: usize = self.size();
let self_ref: MatZnx<&mut [u8]> = self.to_mut();
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size);
let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size);
let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
let end: usize = start + nb_bytes;
@@ -175,6 +166,17 @@ impl<D: DataMut> FillUniform for MatZnx<D> {
}
}
impl<D: DataMut> Reset for MatZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.size = 0;
self.rows = 0;
self.cols_in = 0;
self.cols_out = 0;
}
}
pub type MatZnxOwned = MatZnx<Vec<u8>>;
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;

View File

@@ -6,7 +6,7 @@ use sampling::source::Source;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
},
};
@@ -107,15 +107,13 @@ impl<D: DataMut> ScalarZnx<D> {
}
}
impl<D: DataRef> ScalarZnx<D> {
pub fn bytes_of(n: usize, cols: usize) -> usize {
impl ScalarZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, cols: usize) -> usize {
n * cols * size_of::<i64>()
}
}
impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
pub fn alloc(n: usize, cols: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols));
Self {
data: data.into(),
n,
@@ -123,9 +121,9 @@ impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
}
}
pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
pub fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, cols));
assert!(data.len() == Self::alloc_bytes(n, cols));
Self {
data: data.into(),
n,
@@ -149,6 +147,14 @@ impl<D: DataMut> FillUniform for ScalarZnx<D> {
}
}
impl<D: DataMut> Reset for ScalarZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
}
}
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
impl<D: Data> ScalarZnx<D> {

View File

@@ -3,7 +3,7 @@ use std::fmt;
use crate::{
alloc_aligned,
hal::{
api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo},
},
};
@@ -79,15 +79,13 @@ impl<D: DataMut> ZnxZero for VecZnx<D> {
}
}
impl<D: DataRef> VecZnx<D> {
pub fn alloc_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<Scalar>()
impl VecZnx<Vec<u8>> {
pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<i64>()
}
}
impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {
pub fn alloc<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes::<Scalar>(n, cols, size));
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols, size));
Self {
data: data.into(),
n,
@@ -99,7 +97,7 @@ impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::alloc_bytes::<Scalar>(n, cols, size));
assert!(data.len() == Self::alloc_bytes(n, cols, size));
Self {
data: data.into(),
n,
@@ -163,6 +161,16 @@ impl<D: DataMut> FillUniform for VecZnx<D> {
}
}
impl<D: DataMut> Reset for VecZnx<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type VecZnxOwned = VecZnx<Vec<u8>>;
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;

View File

@@ -1,20 +0,0 @@
use crate::hal::layouts::{Backend, MatZnxOwned, Module};
pub unsafe trait MatZnxAllocImpl<B: Backend> {
fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
}
pub unsafe trait MatZnxAllocBytesImpl<B: Backend> {
fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub unsafe trait MatZnxFromBytesImpl<B: Backend> {
fn mat_znx_from_bytes_impl(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxOwned;
}

View File

@@ -1,6 +1,4 @@
mod mat_znx;
mod module;
mod scalar_znx;
mod scratch;
mod svp_ppol;
mod vec_znx;
@@ -8,9 +6,7 @@ mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
pub use mat_znx::*;
pub use module::*;
pub use scalar_znx::*;
pub use scratch::*;
pub use svp_ppol::*;
pub use vec_znx::*;

View File

@@ -1,13 +0,0 @@
use crate::hal::layouts::{Backend, ScalarZnxOwned};
pub unsafe trait ScalarZnxFromBytesImpl<B: Backend> {
fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}
pub unsafe trait ScalarZnxAllocBytesImpl<B: Backend> {
fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize;
}
pub unsafe trait ScalarZnxAllocImpl<B: Backend> {
fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned;
}

View File

@@ -2,32 +2,7 @@ use rand_distr::Distribution;
use rug::Float;
use sampling::source::Source;
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::layouts::VecZnx::new] for reference code.
/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
/// * See test \[TODO\]
pub unsafe trait VecZnxAllocImpl<B: Backend> {
fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code.
/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxFromBytesImpl<B: Backend> {
fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code.
/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API.
/// * See [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxAllocBytesImpl<B: Backend> {
fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
}
use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.

View File

@@ -16,7 +16,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
}
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
}
pub unsafe trait VecZnxDftToVecZnxBigImpl<B: Backend> {

View File

@@ -22,7 +22,14 @@ pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
}
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn vmp_prepare_tmp_bytes_impl(
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize;
}
pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
@@ -35,6 +42,7 @@ pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
fn vmp_apply_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -55,6 +63,7 @@ pub unsafe trait VmpApplyImpl<B: Backend> {
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,

View File

@@ -3,7 +3,7 @@ use std::fmt::Debug;
use sampling::source::Source;
use crate::hal::{
api::{FillUniform, ZnxZero},
api::{FillUniform, Reset},
layouts::{ReaderFrom, WriterTo},
};
@@ -12,7 +12,7 @@ use crate::hal::{
/// - `T` must implement I/O traits, zeroing, cloning, and random filling.
pub fn test_reader_writer_interface<T>(mut original: T)
where
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + ZnxZero + FillUniform,
T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
{
// Fill original with uniform random data
let mut source = Source::new([0u8; 32]);
@@ -24,7 +24,7 @@ where
// Prepare receiver: same shape, but zeroed
let mut receiver = original.clone();
receiver.zero();
receiver.reset();
// Deserialize from buffer
let mut reader: &[u8] = &buffer;
@@ -45,7 +45,7 @@ fn scalar_znx_serialize() {
#[test]
fn vec_znx_serialize() {
let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc::<i64>(1024, 3, 4);
let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4);
test_reader_writer_interface(original);
}

View File

@@ -2,25 +2,23 @@ use itertools::izip;
use sampling::source::Source;
use crate::hal::{
api::{
VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
ZnxViewMut,
},
api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut},
layouts::{Backend, Module, VecZnx},
};
pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
Module<B>: VecZnxFillUniform + VecZnxStd,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; module.n()];
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -42,8 +40,9 @@ where
pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
Module<B>: VecZnxAddNormal + VecZnxStd,
{
let n: usize = module.n();
let basek: usize = 17;
let k: usize = 2 * 17;
let size: usize = 5;
@@ -51,10 +50,10 @@ where
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; module.n()];
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
@@ -71,21 +70,22 @@ where
pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
let k: usize = size * basek - 5;
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); module.n()];
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); module.n()];
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
});
@@ -93,17 +93,18 @@ where
pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
where
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
{
let n: usize = module.n();
let basek: usize = 17;
let size: usize = 5;
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); module.n()];
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| {
if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
@@ -112,7 +113,7 @@ where
}
});
module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
let mut want: Vec<i64> = vec![i64::default(); module.n()];
let mut want: Vec<i64> = vec![i64::default(); n];
module.decode_vec_i64(basek, &a, col_i, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})

View File

@@ -56,7 +56,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_idft_tmp_a(
@@ -67,19 +67,3 @@ unsafe extern "C" {
a_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism(
module: *const MODULE,
d: i64,
res_dft: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
tmp: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
}

View File

@@ -86,6 +86,7 @@ unsafe extern "C" {
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE,
nn: u64,
res_size: u64,
a_size: u64,
nrows: u64,
@@ -109,5 +110,5 @@ unsafe extern "C" {
}
unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
}

View File

@@ -1,41 +0,0 @@
use crate::{
hal::{
layouts::{Backend, MatZnxOwned, Module},
oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
},
implementation::cpu_spqlios::CPUAVX,
};
unsafe impl<B: Backend> MatZnxAllocImpl<B> for B
where
B: CPUAVX,
{
fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
MatZnxOwned::alloc(module.n(), rows, cols_in, cols_out, size)
}
}
unsafe impl<B: Backend> MatZnxAllocBytesImpl<B> for B
where
B: CPUAVX,
{
fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size)
}
}
unsafe impl<B: Backend> MatZnxFromBytesImpl<B> for B
where
B: CPUAVX,
{
fn mat_znx_from_bytes_impl(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxOwned {
MatZnxOwned::from_bytes(module.n(), rows, cols_in, cols_out, size, bytes)
}
}

View File

@@ -1,8 +1,6 @@
mod ffi;
mod mat_znx;
mod module_fft64;
mod module_ntt120;
mod scalar_znx;
mod scratch;
mod svp_ppol_fft64;
mod svp_ppol_ntt120;

View File

@@ -1,34 +0,0 @@
use crate::{
hal::{
layouts::{Backend, ScalarZnxOwned},
oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxFromBytesImpl},
},
implementation::cpu_spqlios::CPUAVX,
};
unsafe impl<B: Backend> ScalarZnxAllocBytesImpl<B> for B
where
B: CPUAVX,
{
fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize {
ScalarZnxOwned::bytes_of(n, cols)
}
}
unsafe impl<B: Backend> ScalarZnxAllocImpl<B> for B
where
B: CPUAVX,
{
fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned {
ScalarZnxOwned::alloc(n, cols)
}
}
unsafe impl<B: Backend> ScalarZnxFromBytesImpl<B> for B
where
B: CPUAVX,
{
fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
ScalarZnxOwned::from_bytes(n, cols, bytes)
}
}

View File

@@ -6,10 +6,10 @@ use crate::{
api::ScratchFromBytes,
layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
oep::{
ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl,
SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl,
TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
},
},
implementation::cpu_spqlios::CPUAVX,
@@ -76,10 +76,10 @@ where
unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
where
B: CPUAVX + ScalarZnxAllocBytesImpl<B>,
B: CPUAVX,
{
fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols));
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
(
ScalarZnx::from_data(take_slice, n, cols),
Scratch::from_bytes(rem_slice),
@@ -102,13 +102,10 @@ where
unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
where
B: CPUAVX + VecZnxAllocBytesImpl<B>,
B: CPUAVX,
{
fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
B::vec_znx_alloc_bytes_impl(n, cols, size),
);
let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
(
VecZnx::from_data(take_slice, n, cols, size),
Scratch::from_bytes(rem_slice),
@@ -240,7 +237,7 @@ where
) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
let (take_slice, rem_slice) = take_slice_aligned(
&mut scratch.data,
MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size),
MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
);
(
MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),

View File

@@ -14,16 +14,16 @@ use crate::{
VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
oep::{
VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl,
VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
},
},
implementation::cpu_spqlios::{
@@ -32,33 +32,6 @@ use crate::{
},
};
unsafe impl<B: Backend> VecZnxAllocImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
VecZnxOwned::alloc::<i64>(n, cols, size)
}
}
unsafe impl<B: Backend> VecZnxFromBytesImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
VecZnxOwned::from_bytes::<i64>(n, cols, size, bytes)
}
}
unsafe impl<B: Backend> VecZnxAllocBytesImpl<B> for B
where
B: CPUAVX,
{
fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
VecZnxOwned::alloc_bytes::<i64>(n, cols, size)
}
}
unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
where
B: CPUAVX,
@@ -156,9 +129,8 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -192,8 +164,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
@@ -232,8 +203,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
@@ -269,9 +239,8 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -304,8 +273,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -337,8 +305,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -377,8 +344,7 @@ where
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
@@ -411,8 +377,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_negate(
@@ -437,10 +402,6 @@ where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
}
unsafe {
vec_znx::vec_znx_negate(
module.ptr() as *const module_info_t,
@@ -604,8 +565,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
@@ -633,7 +593,6 @@ where
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
@@ -668,8 +627,8 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(res.n(), res.n());
}
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
@@ -697,7 +656,7 @@ where
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), module.n());
assert_eq!(res.n(), res.n());
}
unsafe {
vec_znx::vec_znx_mul_xp_minus_one(
@@ -749,7 +708,7 @@ pub fn vec_znx_split_ref<R, A, B: Backend>(
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size());
let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());
debug_assert!(
n_out < n_in,

View File

@@ -210,9 +210,8 @@ unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -244,8 +243,7 @@ unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
@@ -285,9 +283,8 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -319,8 +316,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_add(
@@ -360,9 +356,8 @@ unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -394,8 +389,7 @@ unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -426,8 +420,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -467,9 +460,8 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -501,8 +493,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -542,9 +533,8 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(b.n(), res.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
@@ -576,8 +566,7 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_sub(
@@ -602,10 +591,6 @@ unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
}
unsafe {
vec_znx::vec_znx_negate(
module.ptr(),
@@ -677,8 +662,7 @@ unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), res.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
@@ -702,11 +686,6 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
module.ptr(),

View File

@@ -57,8 +57,8 @@ unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
}
unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
}
}
@@ -74,26 +74,31 @@ unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes());
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), a.n())
}
let min_size: usize = res_mut.size().min(a_ref.size());
let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n()));
let min_size: usize = res.size().min(a.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft(
module.ptr(),
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1 as u64,
tmp_bytes.as_mut_ptr(),
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
(min_size..res.size()).for_each(|j| {
res.zero_at(res_col, j);
});
}
}

View File

@@ -57,10 +57,18 @@ unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
}
unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
fn vmp_prepare_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize {
unsafe {
vmp::vmp_prepare_tmp_bytes(
module.ptr(),
n as u64,
(rows * cols_in) as u64,
(cols_out * size) as u64,
) as usize
@@ -79,8 +87,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols_in(),
a.cols_in(),
@@ -111,7 +118,8 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
);
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));
let (tmp_bytes, _) =
scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));
unsafe {
vmp::vmp_prepare_contiguous(
@@ -129,6 +137,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -139,6 +148,7 @@ unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
@@ -161,9 +171,8 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), res.n());
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols(),
b.cols_out(),
@@ -181,6 +190,7 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),
@@ -207,6 +217,7 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<FFT64>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -217,6 +228,7 @@ unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
module.ptr(),
n as u64,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
@@ -241,9 +253,8 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
{
use crate::hal::api::ZnxInfos;
assert_eq!(res.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), res.n());
assert_eq!(a.n(), res.n());
assert_eq!(
res.cols(),
b.cols_out(),
@@ -261,6 +272,7 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
}
let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
res.n(),
res.size(),
a.size(),
b.rows(),