mirror of https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00

Added more serialization tests + generalize methods to any n
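The common thread across every hunk below: allocation, scratch-taking, and tmp-bytes methods stop reading the ring degree from the `Module` and instead take `n` as an explicit argument, while the plain layouts (ScalarZnx, VecZnx, MatZnx) gain public constructors of their own. A minimal sketch of the two call styles, with the module setup assumed from the example that follows:

    // Before this commit: degree implicit in the module.
    //     let s = module.scalar_znx_alloc(1);
    // After: the layout allocates for any degree; no Module needed.
    let s: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
    // Any other degree works too, e.g. a smaller test ring:
    let s_small: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(64, 1);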
@@ -1,10 +1,10 @@
 use backend::{
     hal::{
         api::{
-            ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare,
-            VecZnxAddNormal, VecZnxAlloc, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize,
-            VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace, VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx,
-            VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
+            ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
+            VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallBInplace,
+            VecZnxDecodeVeci64, VecZnxDftAlloc, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigTmpA, VecZnxEncodeVeci64,
+            VecZnxFillUniform, VecZnxNormalizeInplace, ZnxInfos,
         },
         layouts::{Module, ScalarZnx, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft},
     },
@@ -27,17 +27,18 @@ fn main() {
     let mut source: Source = Source::new(seed);

     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-    let mut s: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
+    let mut s: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
     s.fill_ternary_prob(0, 0.5, &mut source);

     // Buffer to store s in the DFT domain
-    let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(s.cols());
+    let mut s_dft: SvpPPol<Vec<u8>, FFT64> = module.svp_ppol_alloc(n, s.cols());

     // s_dft <- DFT(s)
     module.svp_prepare(&mut s_dft, 0, &s, 0);

     // Allocates a VecZnx with two columns: ct=(0, 0)
-    let mut ct: VecZnx<Vec<u8>> = module.vec_znx_alloc(
+    let mut ct: VecZnx<Vec<u8>> = VecZnx::alloc(
+        module.n(),
         2,       // Number of columns
         ct_size, // Number of small poly per column
     );
@@ -45,7 +46,7 @@ fn main() {
     // Fill the second column with random values: ct = (0, a)
     module.vec_znx_fill_uniform(basek, &mut ct, 1, ct_size * basek, &mut source);

-    let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(1, ct_size);
+    let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.vec_znx_dft_alloc(n, 1, ct_size);

     module.vec_znx_dft_from_vec_znx(1, 0, &mut buf_dft, 0, &ct, 1);

@@ -60,11 +61,12 @@ fn main() {
     // Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)

     // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
-    let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(1, ct_size);
+    let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_big_alloc(n, 1, ct_size);
     module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);

     // Creates a plaintext: VecZnx with 1 column
-    let mut m = module.vec_znx_alloc(
+    let mut m = VecZnx::alloc(
+        module.n(),
         1,        // Number of columns
         msg_size, // Number of small polynomials
     );
@@ -125,7 +127,7 @@ fn main() {
     module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);

     // m + e <- BIG(ct[1] * s + ct[0])
-    let mut res = module.vec_znx_alloc(1, ct_size);
+    let mut res = VecZnx::alloc(module.n(), 1, ct_size);
     module.vec_znx_big_normalize(basek, &mut res, 0, &buf_big, 0, scratch.borrow());

     // have = m * 2^{log_scale} + e
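For orientation, the example implements the usual RLWE round trip: with ct = (b, a) and b = -a*s + m + e, recomputing a*s + b recovers m + e, which vec_znx_big_normalize then carries back to the base-2^basek representation. A toy check of that identity over plain i64 coefficients (illustrative only, not the library API):

    // Negacyclic product in Z[X]/(X^n + 1), schoolbook, i64 coefficients.
    fn negacyclic_mul(a: &[i64], s: &[i64]) -> Vec<i64> {
        let n = a.len();
        let mut r = vec![0i64; n];
        for i in 0..n {
            for j in 0..n {
                // X^(i+j) wraps to -X^(i+j-n) modulo X^n + 1.
                let (k, sign) = if i + j < n { (i + j, 1i64) } else { (i + j - n, -1i64) };
                r[k] += sign * a[i] * s[j];
            }
        }
        r
    }
    // With b = m + e - a*s (coefficient-wise), a*s + b == m + e.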
@@ -1,17 +0,0 @@
-use crate::hal::layouts::MatZnxOwned;
-
-/// Allocates a [crate::hal::layouts::MatZnx].
-pub trait MatZnxAlloc {
-    fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
-}
-
-/// Returns the size in bytes to allocate a [crate::hal::layouts::MatZnx].
-pub trait MatZnxAllocBytes {
-    fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
-}
-
-/// Consume a vector of bytes into a [crate::hal::layouts::MatZnx].
-/// User must ensure that bytes is memory aligned and that its length is equal to [MatZnxAllocBytes].
-pub trait MatZnxFromBytes {
-    fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned;
-}
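The three traits deleted above do not lose functionality: MatZnx allocation moves onto the layout itself (see the mat_znx.rs layout hunks further down). A hedged sketch of the replacement calls, using the signatures this commit introduces:

    // Allocation no longer goes through a Module:
    let bytes_needed = MatZnx::<Vec<u8>>::alloc_bytes(n, rows, cols_in, cols_out, size);
    let mat = MatZnx::<Vec<u8>>::alloc(n, rows, cols_in, cols_out, size);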
@@ -1,6 +1,4 @@
-mod mat_znx;
 mod module;
-mod scalar_znx;
 mod scratch;
 mod svp_ppol;
 mod vec_znx;
@@ -9,9 +7,7 @@ mod vec_znx_dft;
 mod vmp_pmat;
 mod znx_base;

-pub use mat_znx::*;
 pub use module::*;
-pub use scalar_znx::*;
 pub use scratch::*;
 pub use svp_ppol::*;
 pub use vec_znx::*;
@@ -1,17 +0,0 @@
-use crate::hal::layouts::ScalarZnxOwned;
-
-/// Allocates a [crate::hal::layouts::ScalarZnx].
-pub trait ScalarZnxAlloc {
-    fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned;
-}
-
-/// Returns the size in bytes to allocate a [crate::hal::layouts::ScalarZnx].
-pub trait ScalarZnxAllocBytes {
-    fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize;
-}
-
-/// Consume a vector of bytes into a [crate::hal::layouts::ScalarZnx].
-/// User must ensure that bytes is memory aligned and that its length is equal to [ScalarZnxAllocBytes].
-pub trait ScalarZnxFromBytes {
-    fn scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
-}
@@ -1,4 +1,4 @@
-use crate::hal::layouts::{Backend, MatZnx, Module, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};
+use crate::hal::layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat};

 /// Allocates a new [crate::hal::layouts::ScratchOwned] of `size` aligned bytes.
 pub trait ScratchOwnedAlloc<B: Backend> {
@@ -27,44 +27,38 @@ pub trait TakeSlice {

 /// Take a slice of bytes from a [Scratch], wraps it into a [ScalarZnx] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
-pub trait TakeScalarZnx<B: Backend> {
-    fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
+pub trait TakeScalarZnx {
+    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], wraps it into a [SvpPPol] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
 pub trait TakeSvpPPol<B: Backend> {
-    fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
+    fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnx] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
-pub trait TakeVecZnx<B: Backend> {
-    fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
+pub trait TakeVecZnx {
+    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnx] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
-pub trait TakeVecZnxSlice<B: Backend> {
-    fn take_vec_znx_slice(
-        &mut self,
-        len: usize,
-        module: &Module<B>,
-        cols: usize,
-        size: usize,
-    ) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
+pub trait TakeVecZnxSlice {
+    fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxBig] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
 pub trait TakeVecZnxBig<B: Backend> {
-    fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
+    fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], wraps it into a [VecZnxDft] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
 pub trait TakeVecZnxDft<B: Backend> {
-    fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
+    fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self);
 }

 /// Take a slice of bytes from a [Scratch], slices it into a vector of [VecZnxDft] and returns it
@@ -73,7 +67,7 @@ pub trait TakeVecZnxDftSlice<B: Backend> {
     fn take_vec_znx_dft_slice(
         &mut self,
         len: usize,
-        module: &Module<B>,
+        n: usize,
         cols: usize,
         size: usize,
     ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self);
@@ -84,7 +78,7 @@ pub trait TakeVecZnxDftSlice<B: Backend> {
 pub trait TakeVmpPMat<B: Backend> {
     fn take_vmp_pmat(
         &mut self,
-        module: &Module<B>,
+        n: usize,
         rows: usize,
         cols_in: usize,
         cols_out: usize,
@@ -94,10 +88,10 @@ pub trait TakeVmpPMat<B: Backend> {

 /// Take a slice of bytes from a [Scratch], wraps it into a [MatZnx] and returns it
 /// as well as a new [Scratch] minus the taken array of bytes.
-pub trait TakeMatZnx<B: Backend> {
+pub trait TakeMatZnx {
     fn take_mat_znx(
         &mut self,
-        module: &Module<B>,
+        n: usize,
         rows: usize,
         cols_in: usize,
         cols_out: usize,
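Dropping the `module: &Module<B>` parameter also drops the `B: Backend` bound from the traits that never touched backend data (TakeScalarZnx, TakeVecZnx, TakeVecZnxSlice, TakeMatZnx). A sketch of scratch usage under the new signatures (surrounding setup assumed):

    // Carve two temporaries of possibly different degrees out of one Scratch:
    let (_tmp_a, scratch) = scratch.take_vec_znx(n, cols, size);
    let (_tmp_b, _scratch) = scratch.take_vec_znx(n / 2, cols, size);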
@@ -2,18 +2,18 @@ use crate::hal::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, S
 /// Allocates a [crate::hal::layouts::SvpPPol].
 pub trait SvpPPolAlloc<B: Backend> {
-    fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
+    fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B>;
 }

 /// Returns the size in bytes to allocate a [crate::hal::layouts::SvpPPol].
 pub trait SvpPPolAllocBytes {
-    fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
+    fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize;
 }

 /// Consume a vector of bytes into a [crate::hal::layouts::SvpPPol].
 /// User must ensure that bytes is memory aligned and that its length is equal to [SvpPPolAllocBytes].
 pub trait SvpPPolFromBytes<B: Backend> {
-    fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
+    fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
 }

 /// Prepare a [crate::hal::layouts::ScalarZnx] into an [crate::hal::layouts::SvpPPol].
@@ -2,33 +2,7 @@ use rand_distr::Distribution;
 use rug::Float;
 use sampling::source::Source;

-use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
-
-pub trait VecZnxAlloc {
-    /// Allocates a new [crate::hal::layouts::VecZnx].
-    ///
-    /// # Arguments
-    ///
-    /// * `cols`: the number of polynomials.
-    /// * `size`: the number of small polynomials per column.
-    fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned;
-}
-
-pub trait VecZnxFromBytes {
-    /// Instantiates a new [crate::hal::layouts::VecZnx] from a slice of bytes.
-    /// The returned [crate::hal::layouts::VecZnx] takes ownership of the slice of bytes.
-    ///
-    /// # Arguments
-    ///
-    /// * `cols`: the number of polynomials.
-    /// * `size`: the number of small polynomials per column.
-    fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
-}
-
-pub trait VecZnxAllocBytes {
-    /// Returns the number of bytes necessary to allocate a new [crate::hal::layouts::VecZnx].
-    fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize;
-}
+use crate::hal::layouts::{Backend, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};

 pub trait VecZnxNormalizeTmpBytes {
     /// Returns the minimum number of bytes necessary for normalization.
@@ -5,18 +5,18 @@ use crate::hal::layouts::{Backend, Scratch, VecZnxBigOwned, VecZnxBigToMut, VecZ
 /// Allocates a [crate::hal::layouts::VecZnxBig].
 pub trait VecZnxBigAlloc<B: Backend> {
-    fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
+    fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
 }

 /// Returns the size in bytes to allocate a [crate::hal::layouts::VecZnxBig].
 pub trait VecZnxBigAllocBytes {
-    fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
+    fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
 }

 /// Consume a vector of bytes into a [crate::hal::layouts::VecZnxBig].
 /// User must ensure that bytes is memory aligned and that its length is equal to [VecZnxBigAllocBytes].
 pub trait VecZnxBigFromBytes<B: Backend> {
-    fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
+    fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
 }

 /// Add a discrete normal distribution on res.
@@ -3,19 +3,19 @@ use crate::hal::layouts::{
 };

 pub trait VecZnxDftAlloc<B: Backend> {
-    fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
+    fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
 }

 pub trait VecZnxDftFromBytes<B: Backend> {
-    fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
+    fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
 }

 pub trait VecZnxDftAllocBytes {
-    fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
+    fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
 }

 pub trait VecZnxDftToVecZnxBigTmpBytes {
-    fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
+    fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize;
 }

 pub trait VecZnxDftToVecZnxBig<B: Backend> {
@@ -3,19 +3,27 @@ use crate::hal::layouts::{
 };

 pub trait VmpPMatAlloc<B: Backend> {
-    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
+    fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
 }

 pub trait VmpPMatAllocBytes {
-    fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+    fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
 }

 pub trait VmpPMatFromBytes<B: Backend> {
-    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
+    fn vmp_pmat_from_bytes(
+        &self,
+        n: usize,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> VmpPMatOwned<B>;
 }

 pub trait VmpPrepareTmpBytes {
-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+    fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
 }

 pub trait VmpPMatPrepare<B: Backend> {
@@ -28,6 +36,7 @@ pub trait VmpPMatPrepare<B: Backend> {
 pub trait VmpApplyTmpBytes {
     fn vmp_apply_tmp_bytes(
         &self,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -72,6 +81,7 @@ pub trait VmpApply<B: Backend> {
 pub trait VmpApplyAddTmpBytes {
     fn vmp_apply_add_tmp_bytes(
         &self,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -113,3 +113,7 @@
 pub trait FillUniform {
     fn fill_uniform(&mut self, source: &mut Source);
 }
+
+pub trait Reset {
+    fn reset(&mut self);
+}
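Reset is deliberately stronger than zeroing: it wipes the data and the shape metadata (n, cols, size, ...), so the serialization tests further down can prove that deserialization restores the shape rather than inheriting it from the receiver. The contract the tests rely on, sketched (the reader call is schematic; the actual ReaderFrom method is defined elsewhere in the crate):

    // receiver starts as a clone, then loses both data and shape:
    receiver.reset();
    assert_ne!(original, receiver);
    // after reading the serialized bytes back, everything must match again:
    assert_eq!(original, receiver);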
@@ -1,32 +0,0 @@
-use crate::hal::{
-    api::{MatZnxAlloc, MatZnxAllocBytes, MatZnxFromBytes},
-    layouts::{Backend, MatZnxOwned, Module},
-    oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
-};
-
-impl<B> MatZnxAlloc for Module<B>
-where
-    B: Backend + MatZnxAllocImpl<B>,
-{
-    fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
-        B::mat_znx_alloc_impl(self, rows, cols_in, cols_out, size)
-    }
-}
-
-impl<B> MatZnxAllocBytes for Module<B>
-where
-    B: Backend + MatZnxAllocBytesImpl<B>,
-{
-    fn mat_znx_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        B::mat_znx_alloc_bytes_impl(self, rows, cols_in, cols_out, size)
-    }
-}
-
-impl<B> MatZnxFromBytes for Module<B>
-where
-    B: Backend + MatZnxFromBytesImpl<B>,
-{
-    fn mat_znx_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> MatZnxOwned {
-        B::mat_znx_from_bytes_impl(self, rows, cols_in, cols_out, size, bytes)
-    }
-}
@@ -1,6 +1,4 @@
-mod mat_znx;
 mod module;
-mod scalar_znx;
 mod scratch;
 mod svp_ppol;
 mod vec_znx;
@@ -1,23 +0,0 @@
-use crate::hal::{
-    api::{ScalarZnxAlloc, ScalarZnxAllocBytes},
-    layouts::{Backend, Module, ScalarZnxOwned},
-    oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl},
-};
-
-impl<B> ScalarZnxAllocBytes for Module<B>
-where
-    B: Backend + ScalarZnxAllocBytesImpl<B>,
-{
-    fn scalar_znx_alloc_bytes(&self, cols: usize) -> usize {
-        B::scalar_znx_alloc_bytes_impl(self.n(), cols)
-    }
-}
-
-impl<B> ScalarZnxAlloc for Module<B>
-where
-    B: Backend + ScalarZnxAllocImpl<B>,
-{
-    fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnxOwned {
-        B::scalar_znx_alloc_impl(self.n(), cols)
-    }
-}
@@ -3,9 +3,7 @@ use crate::hal::{
         ScratchAvailable, ScratchFromBytes, ScratchOwnedAlloc, ScratchOwnedBorrow, TakeLike, TakeMatZnx, TakeScalarZnx,
         TakeSlice, TakeSvpPPol, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice, TakeVecZnxSlice, TakeVmpPMat,
     },
-    layouts::{
-        Backend, DataRef, MatZnx, Module, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat,
-    },
+    layouts::{Backend, DataRef, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
     oep::{
         ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeLikeImpl, TakeMatZnxImpl,
         TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl,
@@ -58,12 +56,12 @@ where
     }
 }

-impl<B> TakeScalarZnx<B> for Scratch<B>
+impl<B> TakeScalarZnx for Scratch<B>
 where
     B: Backend + TakeScalarZnxImpl<B>,
 {
-    fn take_scalar_znx(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
-        B::take_scalar_znx_impl(self, module.n(), cols)
+    fn take_scalar_znx(&mut self, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
+        B::take_scalar_znx_impl(self, n, cols)
     }
 }

@@ -71,32 +69,26 @@ impl<B> TakeSvpPPol<B> for Scratch<B>
 where
     B: Backend + TakeSvpPPolImpl<B>,
 {
-    fn take_svp_ppol(&mut self, module: &Module<B>, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
-        B::take_svp_ppol_impl(self, module.n(), cols)
+    fn take_svp_ppol(&mut self, n: usize, cols: usize) -> (SvpPPol<&mut [u8], B>, &mut Self) {
+        B::take_svp_ppol_impl(self, n, cols)
     }
 }

-impl<B> TakeVecZnx<B> for Scratch<B>
+impl<B> TakeVecZnx for Scratch<B>
 where
     B: Backend + TakeVecZnxImpl<B>,
 {
-    fn take_vec_znx(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
-        B::take_vec_znx_impl(self, module.n(), cols, size)
+    fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
+        B::take_vec_znx_impl(self, n, cols, size)
     }
 }

-impl<B> TakeVecZnxSlice<B> for Scratch<B>
+impl<B> TakeVecZnxSlice for Scratch<B>
 where
     B: Backend + TakeVecZnxSliceImpl<B>,
 {
-    fn take_vec_znx_slice(
-        &mut self,
-        len: usize,
-        module: &Module<B>,
-        cols: usize,
-        size: usize,
-    ) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
-        B::take_vec_znx_slice_impl(self, len, module.n(), cols, size)
+    fn take_vec_znx_slice(&mut self, len: usize, n: usize, cols: usize, size: usize) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
+        B::take_vec_znx_slice_impl(self, len, n, cols, size)
     }
 }

@@ -104,8 +96,8 @@ impl<B> TakeVecZnxBig<B> for Scratch<B>
 where
     B: Backend + TakeVecZnxBigImpl<B>,
 {
-    fn take_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
-        B::take_vec_znx_big_impl(self, module.n(), cols, size)
+    fn take_vec_znx_big(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
+        B::take_vec_znx_big_impl(self, n, cols, size)
     }
 }

@@ -113,8 +105,8 @@ impl<B> TakeVecZnxDft<B> for Scratch<B>
 where
     B: Backend + TakeVecZnxDftImpl<B>,
 {
-    fn take_vec_znx_dft(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
-        B::take_vec_znx_dft_impl(self, module.n(), cols, size)
+    fn take_vec_znx_dft(&mut self, n: usize, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
+        B::take_vec_znx_dft_impl(self, n, cols, size)
     }
 }

@@ -125,11 +117,11 @@ where
     fn take_vec_znx_dft_slice(
         &mut self,
         len: usize,
-        module: &Module<B>,
+        n: usize,
         cols: usize,
         size: usize,
     ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
-        B::take_vec_znx_dft_slice_impl(self, len, module.n(), cols, size)
+        B::take_vec_znx_dft_slice_impl(self, len, n, cols, size)
     }
 }

@@ -139,29 +131,29 @@ where
 {
     fn take_vmp_pmat(
         &mut self,
-        module: &Module<B>,
+        n: usize,
         rows: usize,
         cols_in: usize,
         cols_out: usize,
         size: usize,
     ) -> (VmpPMat<&mut [u8], B>, &mut Self) {
-        B::take_vmp_pmat_impl(self, module.n(), rows, cols_in, cols_out, size)
+        B::take_vmp_pmat_impl(self, n, rows, cols_in, cols_out, size)
     }
 }

-impl<B> TakeMatZnx<B> for Scratch<B>
+impl<B> TakeMatZnx for Scratch<B>
 where
     B: Backend + TakeMatZnxImpl<B>,
 {
     fn take_mat_znx(
         &mut self,
-        module: &Module<B>,
+        n: usize,
         rows: usize,
         cols_in: usize,
         cols_out: usize,
         size: usize,
     ) -> (MatZnx<&mut [u8]>, &mut Self) {
-        B::take_mat_znx_impl(self, module.n(), rows, cols_in, cols_out, size)
+        B::take_mat_znx_impl(self, n, rows, cols_in, cols_out, size)
     }
 }
@@ -8,8 +8,8 @@ impl<B> SvpPPolFromBytes<B> for Module<B>
 where
     B: Backend + SvpPPolFromBytesImpl<B>,
 {
-    fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
-        B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
+    fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
+        B::svp_ppol_from_bytes_impl(n, cols, bytes)
     }
 }

@@ -17,8 +17,8 @@ impl<B> SvpPPolAlloc<B> for Module<B>
 where
     B: Backend + SvpPPolAllocImpl<B>,
 {
-    fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
-        B::svp_ppol_alloc_impl(self.n(), cols)
+    fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
+        B::svp_ppol_alloc_impl(n, cols)
     }
 }

@@ -26,8 +26,8 @@ impl<B> SvpPPolAllocBytes for Module<B>
 where
     B: Backend + SvpPPolAllocBytesImpl<B>,
 {
-    fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
-        B::svp_ppol_alloc_bytes_impl(self.n(), cols)
+    fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
+        B::svp_ppol_alloc_bytes_impl(n, cols)
     }
 }
@@ -2,54 +2,26 @@ use sampling::source::Source;

 use crate::hal::{
     api::{
-        VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes,
-        VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat,
-        VecZnxDecodeVeci64, VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform,
-        VecZnxFromBytes, VecZnxLshInplace, VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate,
-        VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace,
-        VecZnxRshInplace, VecZnxSplit, VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace,
-        VecZnxSwithcDegree,
+        VecZnxAdd, VecZnxAddDistF64, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism,
+        VecZnxAutomorphismInplace, VecZnxCopy, VecZnxDecodeCoeffsi64, VecZnxDecodeVecFloat, VecZnxDecodeVeci64,
+        VecZnxEncodeCoeffsi64, VecZnxEncodeVeci64, VecZnxFillDistF64, VecZnxFillNormal, VecZnxFillUniform, VecZnxLshInplace,
+        VecZnxMerge, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
+        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSplit,
+        VecZnxStd, VecZnxSub, VecZnxSubABInplace, VecZnxSubBAInplace, VecZnxSubScalarInplace, VecZnxSwithcDegree,
     },
-    layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef},
+    layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
     oep::{
         VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
-        VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
-        VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
-        VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
-        VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
-        VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
-        VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
-        VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
+        VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
+        VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
+        VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
+        VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl, VecZnxNormalizeImpl,
+        VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl,
+        VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl,
+        VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
     },
 };

-impl<B> VecZnxAlloc for Module<B>
-where
-    B: Backend + VecZnxAllocImpl<B>,
-{
-    fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnxOwned {
-        B::vec_znx_alloc_impl(self.n(), cols, size)
-    }
-}
-
-impl<B> VecZnxFromBytes for Module<B>
-where
-    B: Backend + VecZnxFromBytesImpl<B>,
-{
-    fn vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
-        B::vec_znx_from_bytes_impl(self.n(), cols, size, bytes)
-    }
-}
-
-impl<B> VecZnxAllocBytes for Module<B>
-where
-    B: Backend + VecZnxAllocBytesImpl<B>,
-{
-    fn vec_znx_alloc_bytes(&self, cols: usize, size: usize) -> usize {
-        B::vec_znx_alloc_bytes_impl(self.n(), cols, size)
-    }
-}
-
 impl<B> VecZnxNormalizeTmpBytes for Module<B>
 where
     B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
@@ -24,8 +24,8 @@ impl<B> VecZnxBigAlloc<B> for Module<B>
 where
     B: Backend + VecZnxBigAllocImpl<B>,
 {
-    fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
-        B::vec_znx_big_alloc_impl(self.n(), cols, size)
+    fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
+        B::vec_znx_big_alloc_impl(n, cols, size)
     }
 }

@@ -33,8 +33,8 @@ impl<B> VecZnxBigFromBytes<B> for Module<B>
 where
     B: Backend + VecZnxBigFromBytesImpl<B>,
 {
-    fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
-        B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
+    fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
+        B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
     }
 }

@@ -42,8 +42,8 @@ impl<B> VecZnxBigAllocBytes for Module<B>
 where
     B: Backend + VecZnxBigAllocBytesImpl<B>,
 {
-    fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
-        B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
+    fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
+        B::vec_znx_big_alloc_bytes_impl(n, cols, size)
     }
 }
@@ -20,8 +20,8 @@ impl<B> VecZnxDftFromBytes<B> for Module<B>
 where
     B: Backend + VecZnxDftFromBytesImpl<B>,
 {
-    fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
-        B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
+    fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
+        B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
     }
 }

@@ -29,8 +29,8 @@ impl<B> VecZnxDftAllocBytes for Module<B>
 where
     B: Backend + VecZnxDftAllocBytesImpl<B>,
 {
-    fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
-        B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
+    fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
+        B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
     }
 }

@@ -38,8 +38,8 @@ impl<B> VecZnxDftAlloc<B> for Module<B>
 where
     B: Backend + VecZnxDftAllocImpl<B>,
 {
-    fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
-        B::vec_znx_dft_alloc_impl(self.n(), cols, size)
+    fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
+        B::vec_znx_dft_alloc_impl(n, cols, size)
     }
 }

@@ -47,8 +47,8 @@ impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
 where
     B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
 {
-    fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
-        B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
+    fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
+        B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
    }
 }
@@ -14,8 +14,8 @@ impl<B> VmpPMatAlloc<B> for Module<B>
 where
     B: Backend + VmpPMatAllocImpl<B>,
 {
-    fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
-        B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
+    fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
+        B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
     }
 }

@@ -23,8 +23,8 @@ impl<B> VmpPMatAllocBytes for Module<B>
 where
     B: Backend + VmpPMatAllocBytesImpl<B>,
 {
-    fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
+    fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
     }
 }

@@ -32,8 +32,16 @@ impl<B> VmpPMatFromBytes<B> for Module<B>
 where
     B: Backend + VmpPMatFromBytesImpl<B>,
 {
-    fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
-        B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
+    fn vmp_pmat_from_bytes(
+        &self,
+        n: usize,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+        bytes: Vec<u8>,
+    ) -> VmpPMatOwned<B> {
+        B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
     }
 }

@@ -41,8 +49,8 @@ impl<B> VmpPrepareTmpBytes for Module<B>
 where
     B: Backend + VmpPrepareTmpBytesImpl<B>,
 {
-    fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
+    fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
     }
 }

@@ -65,6 +73,7 @@ where
 {
     fn vmp_apply_tmp_bytes(
         &self,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -73,7 +82,7 @@ where
         b_size: usize,
     ) -> usize {
         B::vmp_apply_tmp_bytes_impl(
-            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
+            self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
         )
     }
 }
@@ -98,6 +107,7 @@ where
 {
     fn vmp_apply_add_tmp_bytes(
         &self,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -106,7 +116,7 @@ where
         b_size: usize,
     ) -> usize {
         B::vmp_apply_add_tmp_bytes_impl(
-            self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
+            self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
         )
     }
 }
@@ -1,7 +1,7 @@
 use crate::{
     alloc_aligned,
     hal::{
-        api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
+        api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
         layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
     },
 };
@@ -78,15 +78,13 @@ impl<D: Data> MatZnx<D> {
     }
 }

-impl<D: DataRef> MatZnx<D> {
-    pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
-        rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size)
+impl MatZnx<Vec<u8>> {
+    pub fn alloc_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
+        rows * cols_in * VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size)
     }
 }

 impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
-    pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
-        let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
+    pub fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
+        let data: Vec<u8> = alloc_aligned(Self::alloc_bytes(n, rows, cols_in, cols_out, size));
         Self {
             data: data.into(),
             n,
@@ -97,16 +95,9 @@ impl<D: DataRef + From<Vec<u8>>> MatZnx<D> {
         }
     }

-    pub(crate) fn from_bytes(
-        n: usize,
-        rows: usize,
-        cols_in: usize,
-        cols_out: usize,
-        size: usize,
-        bytes: impl Into<Vec<u8>>,
-    ) -> Self {
+    pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
+        assert!(data.len() == Self::alloc_bytes(n, rows, cols_in, cols_out, size));
         Self {
             data: data.into(),
             n,
@@ -127,7 +118,7 @@ impl<D: DataRef> MatZnx<D> {
         }

         let self_ref: MatZnx<&[u8]> = self.to_ref();
-        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(self.n, self.cols_out, self.size);
+        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(self.n, self.cols_out, self.size);
         let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
         let end: usize = start + nb_bytes;

@@ -155,7 +146,7 @@ impl<D: DataMut> MatZnx<D> {
         let size: usize = self.size();

         let self_ref: MatZnx<&mut [u8]> = self.to_mut();
-        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes::<i64>(n, cols_out, size);
+        let nb_bytes: usize = VecZnx::<Vec<u8>>::alloc_bytes(n, cols_out, size);
         let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
         let end: usize = start + nb_bytes;

@@ -175,6 +166,17 @@ impl<D: DataMut> FillUniform for MatZnx<D> {
     }
 }

+impl<D: DataMut> Reset for MatZnx<D> {
+    fn reset(&mut self) {
+        self.zero();
+        self.n = 0;
+        self.size = 0;
+        self.rows = 0;
+        self.cols_in = 0;
+        self.cols_out = 0;
+    }
+}
+
 pub type MatZnxOwned = MatZnx<Vec<u8>>;
 pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
 pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
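The offset arithmetic above pins down the MatZnx memory layout: rows x cols_in sub-VecZnx blobs stored contiguously, each sized for (n, cols_out, size). A standalone restatement of that math, distilled from the diff rather than taken from the library:

    // Byte range of the blob at (row, col); an assumption from the offsets above.
    fn mat_znx_blob_range(n: usize, cols_in: usize, cols_out: usize, size: usize, row: usize, col: usize) -> (usize, usize) {
        let nb_bytes = n * cols_out * size * std::mem::size_of::<i64>(); // = VecZnx::alloc_bytes
        let start = nb_bytes * cols_in * row + col * nb_bytes;
        (start, start + nb_bytes)
    }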
@@ -6,7 +6,7 @@ use sampling::source::Source;
 use crate::{
     alloc_aligned,
     hal::{
-        api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
+        api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
         layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo},
     },
 };
@@ -107,15 +107,13 @@ impl<D: DataMut> ScalarZnx<D> {
     }
 }

-impl<D: DataRef> ScalarZnx<D> {
-    pub fn bytes_of(n: usize, cols: usize) -> usize {
+impl ScalarZnx<Vec<u8>> {
+    pub fn alloc_bytes(n: usize, cols: usize) -> usize {
         n * cols * size_of::<i64>()
     }
 }

 impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
     pub fn alloc(n: usize, cols: usize) -> Self {
-        let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols));
+        let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols));
         Self {
             data: data.into(),
             n,
@@ -123,9 +121,9 @@ impl<D: DataRef + From<Vec<u8>>> ScalarZnx<D> {
     }
 }

-    pub(crate) fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
+    pub fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(n, cols));
+        assert!(data.len() == Self::alloc_bytes(n, cols));
         Self {
             data: data.into(),
             n,
@@ -149,6 +147,14 @@ impl<D: DataMut> FillUniform for ScalarZnx<D> {
     }
 }

+impl<D: DataMut> Reset for ScalarZnx<D> {
+    fn reset(&mut self) {
+        self.zero();
+        self.n = 0;
+        self.cols = 0;
+    }
+}
+
 pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;

 impl<D: Data> ScalarZnx<D> {
@@ -3,7 +3,7 @@ use std::fmt;
 use crate::{
     alloc_aligned,
     hal::{
-        api::{DataView, DataViewMut, FillUniform, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
+        api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
         layouts::{Data, DataMut, DataRef, ReaderFrom, WriterTo},
     },
 };
@@ -79,15 +79,13 @@ impl<D: DataMut> ZnxZero for VecZnx<D> {
     }
 }

-impl<D: DataRef> VecZnx<D> {
-    pub fn alloc_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize) -> usize {
-        n * cols * size * size_of::<Scalar>()
+impl VecZnx<Vec<u8>> {
+    pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize {
+        n * cols * size * size_of::<i64>()
     }
 }

 impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {
-    pub fn alloc<Scalar: Sized>(n: usize, cols: usize, size: usize) -> Self {
-        let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes::<Scalar>(n, cols, size));
+    pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
+        let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols, size));
         Self {
             data: data.into(),
             n,
@@ -99,7 +97,7 @@ impl<D: DataRef + From<Vec<u8>>> VecZnx<D> {

     pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::alloc_bytes::<Scalar>(n, cols, size));
+        assert!(data.len() == Self::alloc_bytes(n, cols, size));
         Self {
             data: data.into(),
             n,
@@ -163,6 +161,16 @@ impl<D: DataMut> FillUniform for VecZnx<D> {
     }
 }

+impl<D: DataMut> Reset for VecZnx<D> {
+    fn reset(&mut self) {
+        self.zero();
+        self.n = 0;
+        self.cols = 0;
+        self.size = 0;
+        self.max_size = 0;
+    }
+}
+
 pub type VecZnxOwned = VecZnx<Vec<u8>>;
 pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
 pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
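With the Scalar type parameter gone, every VecZnx is byte-sized for i64 coefficients, so the size formula is simply n * cols * size * 8. A worked instance, matching the serialization test's parameters:

    // 1024 coefficients, 3 columns, 4 small polys per column:
    assert_eq!(VecZnx::<Vec<u8>>::alloc_bytes(1024, 3, 4), 1024 * 3 * 4 * 8); // 98_304 bytes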
@@ -1,20 +0,0 @@
-use crate::hal::layouts::{Backend, MatZnxOwned, Module};
-
-pub unsafe trait MatZnxAllocImpl<B: Backend> {
-    fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned;
-}
-
-pub unsafe trait MatZnxAllocBytesImpl<B: Backend> {
-    fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
-}
-
-pub unsafe trait MatZnxFromBytesImpl<B: Backend> {
-    fn mat_znx_from_bytes_impl(
-        module: &Module<B>,
-        rows: usize,
-        cols_in: usize,
-        cols_out: usize,
-        size: usize,
-        bytes: Vec<u8>,
-    ) -> MatZnxOwned;
-}
@@ -1,6 +1,4 @@
-mod mat_znx;
 mod module;
-mod scalar_znx;
 mod scratch;
 mod svp_ppol;
 mod vec_znx;
@@ -8,9 +6,7 @@ mod vec_znx_big;
 mod vec_znx_dft;
 mod vmp_pmat;

-pub use mat_znx::*;
 pub use module::*;
-pub use scalar_znx::*;
 pub use scratch::*;
 pub use svp_ppol::*;
 pub use vec_znx::*;
@@ -1,13 +0,0 @@
-use crate::hal::layouts::{Backend, ScalarZnxOwned};
-
-pub unsafe trait ScalarZnxFromBytesImpl<B: Backend> {
-    fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
-}
-
-pub unsafe trait ScalarZnxAllocBytesImpl<B: Backend> {
-    fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize;
-}
-
-pub unsafe trait ScalarZnxAllocImpl<B: Backend> {
-    fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned;
-}
@@ -2,32 +2,7 @@ use rand_distr::Distribution;
 use rug::Float;
 use sampling::source::Source;

-use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxOwned, VecZnxToMut, VecZnxToRef};
-
-/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
-/// * See [crate::hal::layouts::VecZnx::new] for reference code.
-/// * See [crate::hal::api::VecZnxAlloc] for corresponding public API.
-/// * See [crate::doc::backend_safety] for safety contract.
-/// * See test \[TODO\]
-pub unsafe trait VecZnxAllocImpl<B: Backend> {
-    fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned;
-}
-
-/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
-/// * See [crate::hal::layouts::VecZnx::from_bytes] for reference code.
-/// * See [crate::hal::api::VecZnxFromBytes] for corresponding public API.
-/// * See [crate::doc::backend_safety] for safety contract.
-pub unsafe trait VecZnxFromBytesImpl<B: Backend> {
-    fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
-}
-
-/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
-/// * See [crate::hal::layouts::VecZnx::alloc_bytes] for reference code.
-/// * See [crate::hal::api::VecZnxAllocBytes] for corresponding public API.
-/// * See [crate::doc::backend_safety] for safety contract.
-pub unsafe trait VecZnxAllocBytesImpl<B: Backend> {
-    fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize;
-}
+use crate::hal::layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef};

 /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
 /// * See [vec_znx_normalize_base2k_tmp_bytes_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/vec_znx.c#L245C17-L245C55) for reference code.
@@ -16,7 +16,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
 }

 pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
-    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
+    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
 }

 pub unsafe trait VecZnxDftToVecZnxBigImpl<B: Backend> {
@@ -22,7 +22,14 @@ pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
 }

 pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
-    fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
+    fn vmp_prepare_tmp_bytes_impl(
+        module: &Module<B>,
+        n: usize,
+        rows: usize,
+        cols_in: usize,
+        cols_out: usize,
+        size: usize,
+    ) -> usize;
 }

 pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
@@ -35,6 +42,7 @@ pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
 pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
     fn vmp_apply_tmp_bytes_impl(
         module: &Module<B>,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -55,6 +63,7 @@ pub unsafe trait VmpApplyImpl<B: Backend> {
 pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
     fn vmp_apply_add_tmp_bytes_impl(
         module: &Module<B>,
+        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -3,7 +3,7 @@ use std::fmt::Debug;
 use sampling::source::Source;

 use crate::hal::{
-    api::{FillUniform, ZnxZero},
+    api::{FillUniform, Reset},
     layouts::{ReaderFrom, WriterTo},
 };

@@ -12,7 +12,7 @@ use crate::hal::{
 /// - `T` must implement I/O traits, zeroing, cloning, and random filling.
 pub fn test_reader_writer_interface<T>(mut original: T)
 where
-    T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + ZnxZero + FillUniform,
+    T: WriterTo + ReaderFrom + PartialEq + Eq + Debug + Clone + Reset + FillUniform,
 {
     // Fill original with uniform random data
     let mut source = Source::new([0u8; 32]);
@@ -24,7 +24,7 @@ where

     // Prepare receiver: same shape, but zeroed
     let mut receiver = original.clone();
-    receiver.zero();
+    receiver.reset();

     // Deserialize from buffer
     let mut reader: &[u8] = &buffer;
@@ -45,7 +45,7 @@ fn scalar_znx_serialize() {

 #[test]
 fn vec_znx_serialize() {
-    let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc::<i64>(1024, 3, 4);
+    let original: crate::hal::layouts::VecZnx<Vec<u8>> = crate::hal::layouts::VecZnx::alloc(1024, 3, 4);
     test_reader_writer_interface(original);
 }
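Registering a new layout with the harness is one line per test, since the bounds above are blanket-implemented for every Znx layout. A sketch for MatZnx, with arbitrary shape arguments and using the alloc signature this commit introduces (hypothetical test body, mirroring the two tests shown):

    #[test]
    fn mat_znx_serialize() {
        // rows=2, cols_in=2, cols_out=1, size=3; shape chosen arbitrarily.
        let original = crate::hal::layouts::MatZnx::<Vec<u8>>::alloc(1024, 2, 2, 1, 3);
        test_reader_writer_interface(original);
    }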
@@ -2,25 +2,23 @@ use itertools::izip;
 use sampling::source::Source;

 use crate::hal::{
-    api::{
-        VecZnxAddNormal, VecZnxAlloc, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView,
-        ZnxViewMut,
-    },
+    api::{VecZnxAddNormal, VecZnxDecodeVeci64, VecZnxEncodeVeci64, VecZnxFillUniform, VecZnxStd, ZnxInfos, ZnxView, ZnxViewMut},
     layouts::{Backend, Module, VecZnx},
 };

 pub fn test_vec_znx_fill_uniform<B: Backend>(module: &Module<B>)
 where
-    Module<B>: VecZnxFillUniform + VecZnxStd + VecZnxAlloc,
+    Module<B>: VecZnxFillUniform + VecZnxStd,
 {
+    let n: usize = module.n();
     let basek: usize = 17;
     let size: usize = 5;
     let mut source: Source = Source::new([0u8; 32]);
     let cols: usize = 2;
-    let zero: Vec<i64> = vec![0; module.n()];
+    let zero: Vec<i64> = vec![0; n];
     let one_12_sqrt: f64 = 0.28867513459481287;
     (0..cols).for_each(|col_i| {
-        let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
+        let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
         module.vec_znx_fill_uniform(basek, &mut a, col_i, size * basek, &mut source);
         (0..cols).for_each(|col_j| {
             if col_j != col_i {
@@ -42,8 +40,9 @@ where

 pub fn test_vec_znx_add_normal<B: Backend>(module: &Module<B>)
 where
-    Module<B>: VecZnxAddNormal + VecZnxStd + VecZnxAlloc,
+    Module<B>: VecZnxAddNormal + VecZnxStd,
 {
+    let n: usize = module.n();
     let basek: usize = 17;
     let k: usize = 2 * 17;
     let size: usize = 5;
@@ -51,10 +50,10 @@ where
     let bound: f64 = 6.0 * sigma;
     let mut source: Source = Source::new([0u8; 32]);
     let cols: usize = 2;
-    let zero: Vec<i64> = vec![0; module.n()];
+    let zero: Vec<i64> = vec![0; n];
     let k_f64: f64 = (1u64 << k as u64) as f64;
     (0..cols).for_each(|col_i| {
-        let mut a: VecZnx<_> = module.vec_znx_alloc(cols, size);
+        let mut a: VecZnx<_> = VecZnx::alloc(n, cols, size);
         module.vec_znx_add_normal(basek, &mut a, col_i, k, &mut source, sigma, bound);
         (0..cols).for_each(|col_j| {
             if col_j != col_i {
@@ -71,21 +70,22 @@ where

 pub fn test_vec_znx_encode_vec_i64_lo_norm<B: Backend>(module: &Module<B>)
 where
-    Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
+    Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
 {
+    let n: usize = module.n();
     let basek: usize = 17;
     let size: usize = 5;
     let k: usize = size * basek - 5;
-    let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
+    let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
     let mut source: Source = Source::new([0u8; 32]);
     let raw: &mut [i64] = a.raw_mut();
     raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
     (0..a.cols()).for_each(|col_i| {
-        let mut have: Vec<i64> = vec![i64::default(); module.n()];
+        let mut have: Vec<i64> = vec![i64::default(); n];
         have.iter_mut()
             .for_each(|x| *x = (source.next_i64() << 56) >> 56);
         module.encode_vec_i64(basek, &mut a, col_i, k, &have, 10);
-        let mut want: Vec<i64> = vec![i64::default(); module.n()];
+        let mut want: Vec<i64> = vec![i64::default(); n];
         module.decode_vec_i64(basek, &a, col_i, k, &mut want);
         izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
     });
@@ -93,17 +93,18 @@ where

 pub fn test_vec_znx_encode_vec_i64_hi_norm<B: Backend>(module: &Module<B>)
 where
-    Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64 + VecZnxAlloc,
+    Module<B>: VecZnxEncodeVeci64 + VecZnxDecodeVeci64,
 {
+    let n: usize = module.n();
     let basek: usize = 17;
     let size: usize = 5;
     for k in [1, basek / 2, size * basek - 5] {
-        let mut a: VecZnx<_> = module.vec_znx_alloc(2, size);
+        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 2, size);
         let mut source = Source::new([0u8; 32]);
         let raw: &mut [i64] = a.raw_mut();
         raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
         (0..a.cols()).for_each(|col_i| {
-            let mut have: Vec<i64> = vec![i64::default(); module.n()];
+            let mut have: Vec<i64> = vec![i64::default(); n];
             have.iter_mut().for_each(|x| {
                 if k < 64 {
                     *x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
@@ -112,7 +113,7 @@ where
                 }
             });
             module.encode_vec_i64(basek, &mut a, col_i, k, &have, 63);
-            let mut want: Vec<i64> = vec![i64::default(); module.n()];
+            let mut want: Vec<i64> = vec![i64::default(); n];
             module.decode_vec_i64(basek, &a, col_i, k, &mut want);
             izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
         })
@@ -56,7 +56,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_idft_tmp_a(
|
||||
@@ -67,19 +67,3 @@ unsafe extern "C" {
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism(
|
||||
module: *const MODULE,
|
||||
d: i64,
|
||||
res_dft: *mut VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
a_dft: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
tmp: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_dft_automorphism_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||

@@ -86,6 +86,7 @@ unsafe extern "C" {
unsafe extern "C" {
    pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
        module: *const MODULE,
        nn: u64,
        res_size: u64,
        a_size: u64,
        nrows: u64,
@@ -109,5 +110,5 @@ unsafe extern "C" {
}

unsafe extern "C" {
    pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
    pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
}
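The pattern behind these FFI signature changes: scratch-size queries that used to be implied by the module now take the ring degree explicitly, so one module can serve rings of different degree. A minimal sketch of that shape (the trait, type, and byte formula here are illustrative assumptions, not poulpy's API):

// Hypothetical degree-parameterized size query.
trait TmpBytes {
    fn tmp_bytes(&self, n: usize) -> usize;
}

struct Fft64Module;

impl TmpBytes for Fft64Module {
    fn tmp_bytes(&self, n: usize) -> usize {
        // e.g. one f64 scratch lane per coefficient (made-up formula)
        n * std::mem::size_of::<f64>()
    }
}

fn main() {
    let m = Fft64Module;
    // The same module instance now answers for several degrees.
    for n in [512, 1024, 2048] {
        println!("n = {n}: {} tmp bytes", m.tmp_bytes(n));
    }
}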

@@ -1,41 +0,0 @@
use crate::{
    hal::{
        layouts::{Backend, MatZnxOwned, Module},
        oep::{MatZnxAllocBytesImpl, MatZnxAllocImpl, MatZnxFromBytesImpl},
    },
    implementation::cpu_spqlios::CPUAVX,
};

unsafe impl<B: Backend> MatZnxAllocImpl<B> for B
where
    B: CPUAVX,
{
    fn mat_znx_alloc_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxOwned {
        MatZnxOwned::alloc(module.n(), rows, cols_in, cols_out, size)
    }
}

unsafe impl<B: Backend> MatZnxAllocBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn mat_znx_alloc_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
        MatZnxOwned::bytes_of(module.n(), rows, cols_in, cols_out, size)
    }
}

unsafe impl<B: Backend> MatZnxFromBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn mat_znx_from_bytes_impl(
        module: &Module<B>,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
        bytes: Vec<u8>,
    ) -> MatZnxOwned {
        MatZnxOwned::from_bytes(module.n(), rows, cols_in, cols_out, size, bytes)
    }
}
@@ -1,8 +1,6 @@
mod ffi;
mod mat_znx;
mod module_fft64;
mod module_ntt120;
mod scalar_znx;
mod scratch;
mod svp_ppol_fft64;
mod svp_ppol_ntt120;

@@ -1,34 +0,0 @@
use crate::{
    hal::{
        layouts::{Backend, ScalarZnxOwned},
        oep::{ScalarZnxAllocBytesImpl, ScalarZnxAllocImpl, ScalarZnxFromBytesImpl},
    },
    implementation::cpu_spqlios::CPUAVX,
};

unsafe impl<B: Backend> ScalarZnxAllocBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn scalar_znx_alloc_bytes_impl(n: usize, cols: usize) -> usize {
        ScalarZnxOwned::bytes_of(n, cols)
    }
}

unsafe impl<B: Backend> ScalarZnxAllocImpl<B> for B
where
    B: CPUAVX,
{
    fn scalar_znx_alloc_impl(n: usize, cols: usize) -> ScalarZnxOwned {
        ScalarZnxOwned::alloc(n, cols)
    }
}

unsafe impl<B: Backend> ScalarZnxFromBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn scalar_znx_from_bytes_impl(n: usize, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
        ScalarZnxOwned::from_bytes(n, cols, bytes)
    }
}
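Both deleted files fall out of the same generalization: once a layout allocates from an explicit degree n, a single inherent constructor serves every backend and the per-backend alloc impls become redundant. A toy illustration of that design choice (types and sizes are stand-ins, not poulpy's layouts):

// Toy degree-parameterized layout with inherent constructors.
struct ScalarPoly {
    n: usize,
    cols: usize,
    data: Vec<i64>,
}

impl ScalarPoly {
    fn bytes_of(n: usize, cols: usize) -> usize {
        n * cols * std::mem::size_of::<i64>()
    }
    fn alloc(n: usize, cols: usize) -> Self {
        Self { n, cols, data: vec![0i64; n * cols] }
    }
}

fn main() {
    // No module (and no backend trait) is needed to allocate.
    let s = ScalarPoly::alloc(1024, 2);
    assert_eq!(ScalarPoly::bytes_of(s.n, s.cols), s.data.len() * 8);
}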
@@ -6,10 +6,10 @@ use crate::{
        api::ScratchFromBytes,
        layouts::{Backend, MatZnx, ScalarZnx, Scratch, ScratchOwned, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
        oep::{
            ScalarZnxAllocBytesImpl, ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl,
            SvpPPolAllocBytesImpl, TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl,
            TakeVecZnxDftImpl, TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl,
            VecZnxAllocBytesImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
            ScratchAvailableImpl, ScratchFromBytesImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, SvpPPolAllocBytesImpl,
            TakeMatZnxImpl, TakeScalarZnxImpl, TakeSliceImpl, TakeSvpPPolImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
            TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, TakeVmpPMatImpl, VecZnxBigAllocBytesImpl,
            VecZnxDftAllocBytesImpl, VmpPMatAllocBytesImpl,
        },
    },
    implementation::cpu_spqlios::CPUAVX,
@@ -76,10 +76,10 @@ where

unsafe impl<B: Backend> TakeScalarZnxImpl<B> for B
where
    B: CPUAVX + ScalarZnxAllocBytesImpl<B>,
    B: CPUAVX,
{
    fn take_scalar_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, B::scalar_znx_alloc_bytes_impl(n, cols));
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, ScalarZnx::alloc_bytes(n, cols));
        (
            ScalarZnx::from_data(take_slice, n, cols),
            Scratch::from_bytes(rem_slice),
@@ -102,13 +102,10 @@ where

unsafe impl<B: Backend> TakeVecZnxImpl<B> for B
where
    B: CPUAVX + VecZnxAllocBytesImpl<B>,
    B: CPUAVX,
{
    fn take_vec_znx_impl(scratch: &mut Scratch<B>, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            B::vec_znx_alloc_bytes_impl(n, cols, size),
        );
        let (take_slice, rem_slice) = take_slice_aligned(&mut scratch.data, VecZnx::alloc_bytes(n, cols, size));
        (
            VecZnx::from_data(take_slice, n, cols, size),
            Scratch::from_bytes(rem_slice),
@@ -240,7 +237,7 @@ where
    ) -> (MatZnx<&mut [u8]>, &mut Scratch<B>) {
        let (take_slice, rem_slice) = take_slice_aligned(
            &mut scratch.data,
            MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size),
            MatZnx::alloc_bytes(n, rows, cols_in, cols_out, size),
        );
        (
            MatZnx::from_data(take_slice, n, rows, cols_in, cols_out, size),

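The take_* family carves typed views out of a single scratch buffer. A minimal sketch of the aligned-split helper they all delegate to; the alignment constant and behavior are assumptions for illustration, not the real take_slice_aligned:

// Split a scratch buffer into an aligned front slice of `want` bytes and the rest.
fn take_slice_aligned(data: &mut [u8], want: usize) -> (&mut [u8], &mut [u8]) {
    const ALIGN: usize = 64; // assumed alignment
    let addr = data.as_ptr() as usize;
    let pad = (ALIGN - (addr % ALIGN)) % ALIGN;
    data[pad..].split_at_mut(want)
}

fn main() {
    let mut scratch = vec![0u8; 4096];
    let (a, rest) = take_slice_aligned(&mut scratch, 256);
    assert_eq!(a.len(), 256);
    assert_eq!(a.as_ptr() as usize % 64, 0);
    assert!(rest.len() >= 4096 - 256 - 64);
}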
Submodule backend/src/implementation/cpu_spqlios/spqlios-arithmetic updated: 7160f588da...de62af3507
@@ -14,16 +14,16 @@ use crate::{
        VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxSliceSize, ZnxView,
        ZnxViewMut, ZnxZero,
    },
    layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef},
    layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, VecZnxToMut, VecZnxToRef},
    oep::{
        VecZnxAddDistF64Impl, VecZnxAddImpl, VecZnxAddInplaceImpl, VecZnxAddNormalImpl, VecZnxAddScalarInplaceImpl,
        VecZnxAllocBytesImpl, VecZnxAllocImpl, VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl,
        VecZnxDecodeCoeffsi64Impl, VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl,
        VecZnxEncodeVeci64Impl, VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxFromBytesImpl,
        VecZnxLshInplaceImpl, VecZnxMergeImpl, VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl,
        VecZnxNegateInplaceImpl, VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl,
        VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl,
        VecZnxSubABInplaceImpl, VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
        VecZnxAutomorphismImpl, VecZnxAutomorphismInplaceImpl, VecZnxCopyImpl, VecZnxDecodeCoeffsi64Impl,
        VecZnxDecodeVecFloatImpl, VecZnxDecodeVeci64Impl, VecZnxEncodeCoeffsi64Impl, VecZnxEncodeVeci64Impl,
        VecZnxFillDistF64Impl, VecZnxFillNormalImpl, VecZnxFillUniformImpl, VecZnxLshInplaceImpl, VecZnxMergeImpl,
        VecZnxMulXpMinusOneImpl, VecZnxMulXpMinusOneInplaceImpl, VecZnxNegateImpl, VecZnxNegateInplaceImpl,
        VecZnxNormalizeImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl,
        VecZnxRotateInplaceImpl, VecZnxRshInplaceImpl, VecZnxSplitImpl, VecZnxStdImpl, VecZnxSubABInplaceImpl,
        VecZnxSubBAInplaceImpl, VecZnxSubImpl, VecZnxSubScalarInplaceImpl, VecZnxSwithcDegreeImpl,
    },
},
implementation::cpu_spqlios::{
@@ -32,33 +32,6 @@ use crate::{
    },
};

unsafe impl<B: Backend> VecZnxAllocImpl<B> for B
where
    B: CPUAVX,
{
    fn vec_znx_alloc_impl(n: usize, cols: usize, size: usize) -> VecZnxOwned {
        VecZnxOwned::alloc::<i64>(n, cols, size)
    }
}

unsafe impl<B: Backend> VecZnxFromBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn vec_znx_from_bytes_impl(n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
        VecZnxOwned::from_bytes::<i64>(n, cols, size, bytes)
    }
}

unsafe impl<B: Backend> VecZnxAllocBytesImpl<B> for B
where
    B: CPUAVX,
{
    fn vec_znx_alloc_bytes_impl(n: usize, cols: usize, size: usize) -> usize {
        VecZnxOwned::alloc_bytes::<i64>(n, cols, size)
    }
}

unsafe impl<B: Backend> VecZnxNormalizeTmpBytesImpl<B> for B
where
    B: CPUAVX,
@@ -156,9 +129,8 @@ where

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
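The recurring change in the hunks that follow: debug checks no longer pin every operand to module.n(); they only require the operands to agree with each other, so one module can process vectors of any supported degree. A toy illustration of relative shape checking (names and the add itself are illustrative):

// Shape checks are relative, not tied to a fixed ring degree.
fn checked_add(res: &mut [i64], a: &[i64], b: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(a.len(), res.len());
        assert_eq!(b.len(), res.len());
    }
    for (r, (x, y)) in res.iter_mut().zip(a.iter().zip(b)) {
        *r = x.wrapping_add(*y);
    }
}

fn main() {
    // The same routine now serves n = 4 and n = 8 alike.
    for n in [4usize, 8] {
        let (a, b) = (vec![1i64; n], vec![2i64; n]);
        let mut r = vec![0i64; n];
        checked_add(&mut r, &a, &b);
        assert!(r.iter().all(|&v| v == 3));
    }
}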
@@ -192,8 +164,7 @@ where

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_add(
@@ -232,8 +203,7 @@ where

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }

        unsafe {
@@ -269,9 +239,8 @@ where

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -304,8 +273,7 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -337,8 +305,7 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -377,8 +344,7 @@ where

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }

        unsafe {
@@ -411,8 +377,7 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_negate(
@@ -437,10 +402,6 @@ where
        A: VecZnxToMut,
    {
        let mut a: VecZnx<&mut [u8]> = a.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
        }
        unsafe {
            vec_znx::vec_znx_negate(
                module.ptr() as *const module_info_t,
@@ -604,8 +565,7 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_automorphism(
@@ -633,7 +593,6 @@ where
        let mut a: VecZnx<&mut [u8]> = a.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert!(
                k & 1 != 0,
                "invalid galois element: must be odd but is {}",
@@ -668,8 +627,8 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(res.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_mul_xp_minus_one(
@@ -697,7 +656,7 @@ where
        let mut res: VecZnx<&mut [u8]> = res.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(res.n(), module.n());
            assert_eq!(res.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_mul_xp_minus_one(
@@ -749,7 +708,7 @@ pub fn vec_znx_split_ref<R, A, B: Backend>(

    let (n_in, n_out) = (a.n(), res[0].to_mut().n());

    let (mut buf, _) = scratch.take_vec_znx(module, 1, a.size());
    let (mut buf, _) = scratch.take_vec_znx(n_in.max(n_out), 1, a.size());

    debug_assert!(
        n_out < n_in,

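On why the automorphism check insists on an odd galois element k: the map X -> X^k on Z[X]/(X^n + 1) only permutes coefficient indices when gcd(k, 2n) = 1, which for power-of-two n means k odd. A quick empirical check of that fact (pure index arithmetic, not poulpy code):

// Track which coefficient slots X^(i*k) lands in, folding exponents modulo
// X^(2n) = 1 and ignoring the sign flip from X^n = -1.
fn is_permutation(n: usize, k: usize) -> bool {
    let mut seen = vec![false; n];
    for i in 0..n {
        let j = (i * k) % (2 * n);
        seen[j % n] = true;
    }
    seen.iter().all(|&s| s)
}

fn main() {
    let n = 8;
    assert!(is_permutation(n, 3));
    assert!(is_permutation(n, 5));
    assert!(!is_permutation(n, 2)); // even k collapses indices
}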
@@ -210,9 +210,8 @@ unsafe impl VecZnxBigAddImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -244,8 +243,7 @@ unsafe impl VecZnxBigAddInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_add(
@@ -285,9 +283,8 @@ unsafe impl VecZnxBigAddSmallImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -319,8 +316,7 @@ unsafe impl VecZnxBigAddSmallInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_add(
@@ -360,9 +356,8 @@ unsafe impl VecZnxBigSubImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -394,8 +389,7 @@ unsafe impl VecZnxBigSubABInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -426,8 +420,7 @@ unsafe impl VecZnxBigSubBAInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -467,9 +460,8 @@ unsafe impl VecZnxBigSubSmallAImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -501,8 +493,7 @@ unsafe impl VecZnxBigSubSmallAInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -542,9 +533,8 @@ unsafe impl VecZnxBigSubSmallBImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(b.n(), res.n());
            assert_ne!(a.as_ptr(), b.as_ptr());
        }
        unsafe {
@@ -576,8 +566,7 @@ unsafe impl VecZnxBigSubSmallBInplaceImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_sub(
@@ -602,10 +591,6 @@ unsafe impl VecZnxBigNegateInplaceImpl<FFT64> for FFT64 {
        A: VecZnxBigToMut<FFT64>,
    {
        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
        }
        unsafe {
            vec_znx::vec_znx_negate(
                module.ptr(),
@@ -677,8 +662,7 @@ unsafe impl VecZnxBigAutomorphismImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), res.n());
        }
        unsafe {
            vec_znx::vec_znx_automorphism(
@@ -702,11 +686,6 @@ unsafe impl VecZnxBigAutomorphismInplaceImpl<FFT64> for FFT64 {
        A: VecZnxBigToMut<FFT64>,
    {
        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();

        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), module.n());
        }
        unsafe {
            vec_znx::vec_znx_automorphism(
                module.ptr(),

@@ -57,8 +57,8 @@ unsafe impl VecZnxDftAllocImpl<FFT64> for FFT64 {
}

unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<FFT64> for FFT64 {
    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>) -> usize {
        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<FFT64>, n: usize) -> usize {
        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
    }
}

@@ -74,26 +74,31 @@ unsafe impl VecZnxDftToVecZnxBigImpl<FFT64> for FFT64 {
        R: VecZnxBigToMut<FFT64>,
        A: VecZnxDftToRef<FFT64>,
    {
        let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
        let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
        let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
        let a: VecZnxDft<&[u8], FFT64> = a.to_ref();

        let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes());
        #[cfg(debug_assertions)]
        {
            assert_eq!(res.n(), a.n())
        }

        let min_size: usize = res_mut.size().min(a_ref.size());
        let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n()));

        let min_size: usize = res.size().min(a.size());

        unsafe {
            (0..min_size).for_each(|j| {
                vec_znx_dft::vec_znx_idft(
                    module.ptr(),
                    res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
                    res.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
                    1 as u64,
                    a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
                    a.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
                    1 as u64,
                    tmp_bytes.as_mut_ptr(),
                )
            });
            (min_size..res_mut.size()).for_each(|j| {
                res_mut.zero_at(res_col, j);
            (min_size..res.size()).for_each(|j| {
                res.zero_at(res_col, j);
            });
        }
    }

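The idft loop above converts only the limbs both sides share and zeroes the destination's surplus limbs. A minimal sketch of that size-reconciling pattern, with a plain copy standing in for the per-limb IDFT:

// Convert the shared limbs, then clear whatever the destination has left over.
fn convert_limbs(res: &mut [Vec<i64>], a: &[Vec<i64>]) {
    let min_size = res.len().min(a.len());
    for j in 0..min_size {
        res[j].copy_from_slice(&a[j]); // stands in for the per-limb IDFT
    }
    for j in min_size..res.len() {
        res[j].iter_mut().for_each(|x| *x = 0); // zero the surplus limbs
    }
}

fn main() {
    let a = vec![vec![1i64; 4]; 2]; // source with 2 limbs
    let mut res = vec![vec![9i64; 4]; 3]; // destination with 3 limbs
    convert_limbs(&mut res, &a);
    assert_eq!(res[1], vec![1; 4]);
    assert_eq!(res[2], vec![0; 4]); // surplus limb cleared
}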
@@ -57,10 +57,18 @@ unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
}

unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
    fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
    fn vmp_prepare_tmp_bytes_impl(
        module: &Module<FFT64>,
        n: usize,
        rows: usize,
        cols_in: usize,
        cols_out: usize,
        size: usize,
    ) -> usize {
        unsafe {
            vmp::vmp_prepare_tmp_bytes(
                module.ptr(),
                n as u64,
                (rows * cols_in) as u64,
                (cols_out * size) as u64,
            ) as usize
@@ -79,8 +87,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(res.n(), module.n());
            assert_eq!(a.n(), module.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(
                res.cols_in(),
                a.cols_in(),
@@ -111,7 +118,8 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
            );
        }

        let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));
        let (tmp_bytes, _) =
            scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));

        unsafe {
            vmp::vmp_prepare_contiguous(
@@ -129,6 +137,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
    fn vmp_apply_tmp_bytes_impl(
        module: &Module<FFT64>,
        n: usize,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
@@ -139,6 +148,7 @@ unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
        unsafe {
            vmp::vmp_apply_dft_to_dft_tmp_bytes(
                module.ptr(),
                n as u64,
                (res_size * b_cols_out) as u64,
                (a_size * b_cols_in) as u64,
                (b_rows * b_cols_in) as u64,
@@ -161,9 +171,8 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {

        #[cfg(debug_assertions)]
        {
            assert_eq!(res.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), res.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(
                res.cols(),
                b.cols_out(),
@@ -181,6 +190,7 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
        }

        let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
            res.n(),
            res.size(),
            a.size(),
            b.rows(),
@@ -207,6 +217,7 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
    fn vmp_apply_add_tmp_bytes_impl(
        module: &Module<FFT64>,
        n: usize,
        res_size: usize,
        a_size: usize,
        b_rows: usize,
@@ -217,6 +228,7 @@ unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
        unsafe {
            vmp::vmp_apply_dft_to_dft_tmp_bytes(
                module.ptr(),
                n as u64,
                (res_size * b_cols_out) as u64,
                (a_size * b_cols_in) as u64,
                (b_rows * b_cols_in) as u64,
@@ -241,9 +253,8 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
        {
            use crate::hal::api::ZnxInfos;

            assert_eq!(res.n(), module.n());
            assert_eq!(b.n(), module.n());
            assert_eq!(a.n(), module.n());
            assert_eq!(b.n(), res.n());
            assert_eq!(a.n(), res.n());
            assert_eq!(
                res.cols(),
                b.cols_out(),
@@ -261,6 +272,7 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
        }

        let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
            res.n(),
            res.size(),
            a.size(),
            b.rows(),

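At its core, vmp_apply contracts a limb-decomposed input (one entry per matrix row) against a prepared matrix. A toy integer version of that contraction, ignoring the DFT domain and the packed pmat layout entirely (names and layout are illustrative):

// Row-weighted sum: res[c] = sum_r a[r] * mat[r][c].
fn vmp_apply(a: &[i64], mat: &[Vec<i64>]) -> Vec<i64> {
    let cols = mat[0].len();
    let mut res = vec![0i64; cols];
    for (coef, row) in a.iter().zip(mat) {
        for (r, m) in res.iter_mut().zip(row) {
            *r += coef * m;
        }
    }
    res
}

fn main() {
    let mat = vec![vec![1i64, 0], vec![0, 1], vec![2, 3]];
    assert_eq!(vmp_apply(&[5, 7, 1], &mat), vec![7, 10]);
}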