mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add Zn type
This commit is contained in:
@@ -5,3 +5,4 @@ mod vec_znx;
|
||||
mod vec_znx_big;
|
||||
mod vec_znx_dft;
|
||||
mod vmp_pmat;
|
||||
mod zn;
|
||||
|
||||
@@ -8,8 +8,8 @@ impl<B> SvpPPolFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolFromBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(n, cols, bytes)
|
||||
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ impl<B> SvpPPolAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(n, cols)
|
||||
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
|
||||
B::svp_ppol_alloc_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ impl<B> SvpPPolAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + SvpPPolAllocBytesImpl<B>,
|
||||
{
|
||||
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(n, cols)
|
||||
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
|
||||
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ impl<B> VecZnxNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_normalize_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ impl<B> VecZnxBigAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_alloc_impl(n, cols, size)
|
||||
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ impl<B> VecZnxBigFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
|
||||
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,8 +42,8 @@ impl<B> VecZnxBigAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_alloc_bytes_impl(n, cols, size)
|
||||
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,8 +283,8 @@ impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_big_normalize_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ impl<B> VecZnxDftFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftFromBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
|
||||
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ impl<B> VecZnxDftAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
|
||||
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
|
||||
B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ impl<B> VecZnxDftAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftAllocImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_alloc_impl(n, cols, size)
|
||||
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
B::vec_znx_dft_alloc_impl(self.n(), cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
|
||||
{
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
|
||||
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
|
||||
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ impl<B> VmpPMatAlloc<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
|
||||
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@ impl<B> VmpPMatAllocBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatAllocBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
|
||||
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,16 +32,8 @@ impl<B> VmpPMatFromBytes<B> for Module<B>
|
||||
where
|
||||
B: Backend + VmpPMatFromBytesImpl<B>,
|
||||
{
|
||||
fn vmp_pmat_from_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
|
||||
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
|
||||
B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,8 +41,8 @@ impl<B> VmpPrepareTmpBytes for Module<B>
|
||||
where
|
||||
B: Backend + VmpPrepareTmpBytesImpl<B>,
|
||||
{
|
||||
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +65,6 @@ where
|
||||
{
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -82,7 +73,7 @@ where
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_tmp_bytes_impl(
|
||||
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -107,7 +98,6 @@ where
|
||||
{
|
||||
fn vmp_apply_add_tmp_bytes(
|
||||
&self,
|
||||
n: usize,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
@@ -116,7 +106,7 @@ where
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
B::vmp_apply_add_tmp_bytes_impl(
|
||||
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
114
poulpy-hal/src/delegates/zn.rs
Normal file
114
poulpy-hal/src/delegates/zn.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use crate::{
|
||||
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
|
||||
layouts::{Backend, Module, Scratch, ZnToMut},
|
||||
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
impl<B> ZnNormalizeInplace<B> for Module<B>
|
||||
where
|
||||
B: Backend + ZnNormalizeInplaceImpl<B>,
|
||||
{
|
||||
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
|
||||
where
|
||||
A: ZnToMut,
|
||||
{
|
||||
B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillUniform for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillUniformImpl<B>,
|
||||
{
|
||||
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
|
||||
where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillDistF64Impl<B>,
|
||||
{
|
||||
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnAddDistF64 for Module<B>
|
||||
where
|
||||
B: Backend + ZnAddDistF64Impl<B>,
|
||||
{
|
||||
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnFillNormal for Module<B>
|
||||
where
|
||||
B: Backend + ZnFillNormalImpl<B>,
|
||||
{
|
||||
fn zn_fill_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ZnAddNormal for Module<B>
|
||||
where
|
||||
B: Backend + ZnAddNormalImpl<B>,
|
||||
{
|
||||
fn zn_add_normal<R>(
|
||||
&self,
|
||||
n: usize,
|
||||
basek: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
k: usize,
|
||||
source: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
R: ZnToMut,
|
||||
{
|
||||
B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user