mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add Zn type
This commit is contained in:
@@ -5,6 +5,7 @@ mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod zn;
|
||||
mod znx_base;
|
||||
|
||||
pub use module::*;
|
||||
@@ -14,4 +15,5 @@ pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
pub use zn::*;
|
||||
pub use znx_base::*;
|
||||
|
||||
@@ -2,18 +2,18 @@ use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPo
|
||||
|
||||
/// Allocates as [crate::layouts::SvpPPol].
|
||||
pub trait SvpPPolAlloc<B: Backend> {
|
||||
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B>;
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::layouts::SvpPPol].
|
||||
pub trait SvpPPolAllocBytes {
|
||||
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize;
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::layouts::MatZnx].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes].
|
||||
pub trait SvpPPolFromBytes<B: Backend> {
|
||||
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
|
||||
}
|
||||
|
||||
/// Prepare a [crate::layouts::ScalarZnx] into an [crate::layouts::SvpPPol].
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::{
|
||||
|
||||
pub trait VecZnxNormalizeTmpBytes {
|
||||
/// Returns the minimum number of bytes necessary for normalization.
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxNormalize<B: Backend> {
|
||||
|
||||
@@ -7,18 +7,18 @@ use crate::{
|
||||
|
||||
/// Allocates as [crate::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
/// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig].
|
||||
pub trait VecZnxBigAllocBytes {
|
||||
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// Consume a vector of bytes into a [crate::layouts::VecZnxBig].
|
||||
/// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes].
|
||||
pub trait VecZnxBigFromBytes<B: Backend> {
|
||||
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -187,7 +187,7 @@ pub trait VecZnxBigNegateInplace<B: Backend> {
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalizeTmpBytes {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigNormalize<B: Backend> {
|
||||
|
||||
@@ -3,19 +3,19 @@ use crate::layouts::{
|
||||
};
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftFromBytes<B: Backend> {
|
||||
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftAllocBytes {
|
||||
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBigTmpBytes {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize;
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToVecZnxBig<B: Backend> {
|
||||
|
||||
@@ -1,27 +1,19 @@
|
||||
use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef};
|
||||
|
||||
pub trait VmpPMatAlloc<B: Backend> {
|
||||
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPMatAllocBytes {
|
||||
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPMatFromBytes<B: Backend> {
|
||||
fn vmp_pmat_from_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<B>;
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
|
||||
}
|
||||
|
||||
pub trait VmpPrepareTmpBytes {
|
||||
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VmpPrepare<B: Backend> {
|
||||
@@ -35,7 +27,6 @@ pub trait VmpPrepare<B: Backend> {
|
||||
pub trait VmpApplyTmpBytes {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -81,7 +72,6 @@ pub trait VmpApply<B: Backend> {
|
||||
pub trait VmpApplyAddTmpBytes {
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
|
||||
86
poulpy-hal/src/api/zn.rs
Normal file
86
poulpy-hal/src/api/zn.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
layouts::{Backend, Scratch, ZnToMut},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
pub trait ZnNormalizeInplace<B: Backend> {
|
||||
/// Normalizes the selected column of `a`.
|
||||
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: ZnToMut;
|
||||
}
|
||||
|
||||
pub trait ZnFillUniform {
|
||||
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
|
||||
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait ZnFillDistF64 {
|
||||
fn zn_fill_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait ZnAddDistF64 {
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
|
||||
fn zn_add_dist_f64<R, D: Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait ZnFillNormal {
|
||||
fn zn_fill_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub trait ZnAddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn zn_add_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
@@ -5,3 +5,4 @@ mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod zn;
|
||||
|
||||
@@ -8,8 +8,8 @@ impl<B> SvpPPolFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolFromBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(n, cols, bytes)
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ impl<B> SvpPPolAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(n, cols)
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ impl<B> SvpPPolAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(n, cols)
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ impl<B> VecZnxNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ impl<B> VecZnxBigAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_alloc_impl(n, cols, size)
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ impl<B> VecZnxBigFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,8 +42,8 @@ impl<B> VecZnxBigAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size)
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,8 +283,8 @@ impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_big_normalize_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ impl<B> VecZnxDftFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ impl<B> VecZnxDftAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ impl<B> VecZnxDftAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_alloc_impl(n, cols, size)
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ impl<B> VmpPMatAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@ impl<B> VmpPMatAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,16 +32,8 @@ impl<B> VmpPMatFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatFromBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_from_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,8 +41,8 @@ impl<B> VmpPrepareTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPrepareTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +65,6 @@ where
|
||||
{
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -82,7 +73,7 @@ where
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_tmp_bytes_impl(
|
||||
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -107,7 +98,6 @@ where
|
||||
{
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -116,7 +106,7 @@ where
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_add_tmp_bytes_impl(
|
||||
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
114
poulpy-hal/src/delegates/zn.rs
Normal file
114
poulpy-hal/src/delegates/zn.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use crate::{
|
||||
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
|
||||
layouts::{Backend, Module, Scratch, ZnToMut},
|
||||
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
impl<B> ZnNormalizeInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + ZnNormalizeInplaceImpl<B>,
|
||||
{
|
||||
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: ZnToMut,
|
||||
{
|
||||
B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillUniformImpl<B>,
|
||||
{
|
||||
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillDistF64Impl<B>,
|
||||
{
|
||||
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnAddDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnAddDistF64Impl<B>,
|
||||
{
|
||||
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillNormal for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillNormalImpl<B>,
|
||||
{
|
||||
fn zn_fill_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnAddNormal for Module<B>
|
||||
where
|
||||
B: Backend + ZnAddNormalImpl<B>,
|
||||
{
|
||||
fn zn_add_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ use rug::{Assign, Float};
|
||||
|
||||
use crate::{
|
||||
api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef},
|
||||
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef},
|
||||
};
|
||||
|
||||
impl<D: DataMut> VecZnx<D> {
|
||||
@@ -202,3 +202,90 @@ impl<D: DataRef> VecZnx<D> {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> Zn<D> {
|
||||
pub fn encode_i64(&mut self, basek: usize, k: usize, data: i64, log_max: usize) {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a: Zn<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
size <= a.size(),
|
||||
"invalid argument k.div_ceil(basek)={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
let k_rem: usize = basek - (k % basek);
|
||||
let mut a: Zn<&mut [u8]> = self.to_mut();
|
||||
(0..a.size()).for_each(|j| a.at_mut(0, j)[0] = 0);
|
||||
|
||||
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + k_rem < 63 || k_rem == basek {
|
||||
a.at_mut(0, size - 1)[0] = data;
|
||||
} else {
|
||||
let mask: i64 = (1 << basek) - 1;
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_mut(0, j_rev)[0] = (data >> (j * basek)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if k_rem != basek {
|
||||
let steps: usize = size.min(log_max.div_ceil(basek));
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(0, j)[0] <<= k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> Zn<D> {
|
||||
pub fn decode_i64(&self, basek: usize, k: usize) -> i64 {
|
||||
let a: Zn<&[u8]> = self.to_ref();
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let mut res: i64 = 0;
|
||||
let rem: usize = basek - (k % basek);
|
||||
(0..size).for_each(|j| {
|
||||
let x: i64 = a.at(0, j)[0];
|
||||
if j == size - 1 && rem != basek {
|
||||
let k_rem: usize = basek - rem;
|
||||
res = (res << k_rem) + (x >> rem);
|
||||
} else {
|
||||
res = (res << basek) + x;
|
||||
}
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
pub fn decode_float(&self, basek: usize) -> Float {
|
||||
let a: Zn<&[u8]> = self.to_ref();
|
||||
let size: usize = a.size();
|
||||
let prec: u32 = (basek * size) as u32;
|
||||
|
||||
// 2^{basek}
|
||||
let base: Float = Float::with_val(prec, (1 << basek) as f64);
|
||||
let mut res: Float = Float::with_val(prec, (1 << basek) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-basek*j}
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
res.assign(a.at(0, size - i - 1)[0]);
|
||||
res /= &base;
|
||||
} else {
|
||||
res += Float::with_val(prec, a.at(0, size - i - 1)[0]);
|
||||
res /= &base;
|
||||
}
|
||||
});
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod zn;
|
||||
|
||||
pub use mat_znx::*;
|
||||
pub use module::*;
|
||||
@@ -21,6 +22,7 @@ pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
pub use zn::*;
|
||||
|
||||
pub trait Data = PartialEq + Eq + Sized;
|
||||
pub trait DataRef = Data + AsRef<[u8]>;
|
||||
|
||||
255
poulpy-hal/src/layouts/zn.rs
Normal file
255
poulpy-hal/src/layouts/zn.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
use std::fmt;
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use rand::RngCore;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub struct Zn<D: Data> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
pub size: usize,
|
||||
pub max_size: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for Zn<D> {
|
||||
type Owned = Zn<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
Zn {
|
||||
data: self.data.as_ref().to_vec(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Debug for Zn<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxInfos for Zn<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for Zn<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataView for Zn<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataViewMut for Zn<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for Zn<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl Zn<Vec<u8>> {
|
||||
pub fn rsh_scratch_space(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnxZero for Zn<D> {
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(0)
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Zn<Vec<u8>> {
|
||||
pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize {
|
||||
n * cols * size * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols, size));
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::alloc_bytes(n, cols, size));
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> Zn<D> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
max_size: size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> fmt::Display for Zn<D> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"Zn(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for Zn<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> Reset for Zn<D> {
|
||||
fn reset(&mut self) {
|
||||
self.zero();
|
||||
self.n = 0;
|
||||
self.cols = 0;
|
||||
self.size = 0;
|
||||
self.max_size = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pub type ZnOwned = Zn<Vec<u8>>;
|
||||
pub type ZnMut<'a> = Zn<&'a mut [u8]>;
|
||||
pub type ZnRef<'a> = Zn<&'a [u8]>;
|
||||
|
||||
pub trait ZnToRef {
|
||||
fn to_ref(&self) -> Zn<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnToRef for Zn<D> {
|
||||
fn to_ref(&self) -> Zn<&[u8]> {
|
||||
Zn {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnToMut {
|
||||
fn to_mut(&mut self) -> Zn<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnToMut for Zn<D> {
|
||||
fn to_mut(&mut self) -> Zn<&mut [u8]> {
|
||||
Zn {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ReaderFrom for Zn<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for Zn<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.size as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod zn;
|
||||
|
||||
pub use module::*;
|
||||
pub use scratch::*;
|
||||
@@ -13,3 +14,4 @@ pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp_pmat::*;
|
||||
pub use zn::*;
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::{
|
||||
/// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
|
||||
@@ -263,7 +263,7 @@ pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
|
||||
@@ -32,7 +32,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
|
||||
@@ -38,14 +38,7 @@ pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
|
||||
/// * See TODO for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
|
||||
fn vmp_prepare_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
) -> usize;
|
||||
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
@@ -67,7 +60,6 @@ pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
|
||||
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -97,7 +89,6 @@ pub unsafe trait VmpApplyImpl<B: Backend> {
|
||||
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
|
||||
fn vmp_apply_add_tmp_bytes_impl(
|
||||
module: &Module<B>,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
|
||||
97
poulpy-hal/src/oep/zn.rs
Normal file
97
poulpy-hal/src/oep/zn.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use rand_distr::Distribution;
|
||||
|
||||
use crate::{
|
||||
layouts::{Backend, Scratch, ZnToMut},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code.
|
||||
/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
|
||||
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: ZnToMut;
|
||||
}
|
||||
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillUniform] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnFillUniformImpl<B: Backend> {
|
||||
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnFillDistF64Impl<B: Backend> {
|
||||
fn zn_fill_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnAddDistF64] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnAddDistF64Impl<B: Backend> {
|
||||
fn zn_add_dist_f64_impl<R, D: Distribution<f64>>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnFillNormal] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnFillNormalImpl<B: Backend> {
|
||||
fn zn_fill_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||
/// * See [crate::api::ZnAddNormal] for corresponding public API.
|
||||
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||
pub unsafe trait ZnAddNormalImpl<B: Backend> {
|
||||
fn zn_add_normal_impl<R>(
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut;
|
||||
}
|
||||
@@ -51,14 +51,13 @@ where
|
||||
|
||||
let mut scratch = ScratchOwned::alloc(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
n,
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(n),
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, a_cols, a_size);
|
||||
@@ -67,10 +66,10 @@ where
|
||||
a.at_mut(i, a_size - 1)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut vmp: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
let mut vmp: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(n, mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(n, mat_cols_out, mat_size);
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(mat_cols_out, mat_size);
|
||||
|
||||
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
@@ -86,7 +85,7 @@ where
|
||||
|
||||
module.vmp_prepare(&mut vmp, &mat, scratch.borrow());
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(n, a_cols, a_size);
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user