mirror of https://github.com/arnaucube/poulpy.git

Add Zn type
@@ -1,15 +1,17 @@
 #[allow(non_camel_case_types)]
 pub mod module;
 #[allow(non_camel_case_types)]
 pub mod svp;
 #[allow(non_camel_case_types)]
 pub mod vec_znx;
 #[allow(dead_code)]
 #[allow(non_camel_case_types)]
 pub mod vec_znx_big;
 #[allow(non_camel_case_types)]
 pub mod vec_znx_dft;
 #[allow(non_camel_case_types)]
 pub mod vmp;
+#[allow(non_camel_case_types)]
+pub mod zn64;
 #[allow(non_camel_case_types)]
 pub mod znx;
@@ -103,7 +103,6 @@ unsafe extern "C" {
 unsafe extern "C" {
     pub unsafe fn vec_znx_normalize_base2k(
         module: *const MODULE,
-        n: u64,
         base2k: u64,
         res: *mut i64,
         res_size: u64,
@@ -114,6 +113,7 @@ unsafe extern "C" {
         tmp_space: *mut u8,
     );
 }

 unsafe extern "C" {
-    pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
+    pub unsafe fn vec_znx_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
 }
@@ -93,13 +93,12 @@ unsafe extern "C" {
 }

 unsafe extern "C" {
-    pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
+    pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
 }

 unsafe extern "C" {
     pub unsafe fn vec_znx_big_normalize_base2k(
         module: *const MODULE,
-        n: u64,
         log2_base2k: u64,
         res: *mut i64,
         res_size: u64,
@@ -113,7 +112,6 @@ unsafe extern "C" {
 unsafe extern "C" {
     pub unsafe fn vec_znx_big_range_normalize_base2k(
         module: *const MODULE,
-        n: u64,
         log2_base2k: u64,
         res: *mut i64,
         res_size: u64,
@@ -127,7 +125,7 @@ unsafe extern "C" {
 }

 unsafe extern "C" {
-    pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE, n: u64) -> u64;
+    pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
 }

 unsafe extern "C" {
@@ -43,7 +43,7 @@ unsafe extern "C" {
     );
 }
 unsafe extern "C" {
-    pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE, n: u64) -> u64;
+    pub unsafe fn vec_znx_idft_tmp_bytes(module: *const MODULE) -> u64;
 }
 unsafe extern "C" {
     pub unsafe fn vec_znx_idft_tmp_a(
@@ -79,7 +79,6 @@ unsafe extern "C" {
 unsafe extern "C" {
     pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
         module: *const MODULE,
-        nn: u64,
         res_size: u64,
         a_size: u64,
         nrows: u64,
@@ -99,5 +98,5 @@ unsafe extern "C" {
 }

 unsafe extern "C" {
-    pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nn: u64, nrows: u64, ncols: u64) -> u64;
+    pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
 }
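The pattern across the FFI hunks above is consistent: every `*_tmp_bytes` binding loses its explicit ring-degree argument (`n`/`nn`), since the `MODULE` handle passed as the first argument already fixes the degree at creation time. A minimal sketch of the idea, using hypothetical names rather than poulpy's actual API:

// Hypothetical illustration: a module handle that owns the ring degree,
// so call sites no longer need to thread `n` through every signature.
struct Module {
    n: usize, // ring degree, fixed at module creation
}

impl Module {
    fn new(n: usize) -> Self {
        Self { n }
    }

    // Before this commit: callers passed `n` redundantly alongside the module.
    // After: the scratch-size query derives everything from `self`.
    fn normalize_tmp_bytes(&self) -> usize {
        self.n * core::mem::size_of::<i64>()
    }
}

fn main() {
    let module = Module::new(1 << 10);
    // One fewer parameter that could drift out of sync with the module.
    println!("tmp bytes: {}", module.normalize_tmp_bytes());
}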
13 poulpy-backend/src/cpu_spqlios/ffi/zn64.rs Normal file
@@ -0,0 +1,13 @@
+unsafe extern "C" {
+    pub unsafe fn zn64_normalize_base2k_ref(
+        n: u64,
+        base2k: u64,
+        res: *mut i64,
+        res_size: u64,
+        res_sl: u64,
+        a: *const i64,
+        a_size: u64,
+        a_sl: u64,
+        tmp_space: *mut u8,
+    );
+}
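For context, a reading of the new binding: `zn64_normalize_base2k_ref` renormalizes a base-2^k limb decomposition, reducing each limb to the centered digit range and carrying the excess into the next limb. A self-contained sketch of that operation under stated assumptions; the limb order and exact convention here are assumptions, not the spqlios implementation:

// Hypothetical sketch of centered base-2^k normalization over i64 limbs.
// Assumption: limbs[0] is the least-significant digit (the real layout in
// spqlios/poulpy may differ); each digit is reduced to [-2^(k-1), 2^(k-1))
// and the excess is carried into the next limb.
fn normalize_base2k(limbs: &mut [i64], basek: u32) {
    let half = 1i64 << (basek - 1);
    let full = 1i64 << basek;
    let mut carry = 0i64;
    for limb in limbs.iter_mut() {
        let v = *limb + carry;
        // Centered remainder: digit in [-2^(k-1), 2^(k-1)).
        let mut digit = v.rem_euclid(full);
        if digit >= half {
            digit -= full;
        }
        carry = (v - digit) >> basek;
        *limb = digit;
    }
    // Any final carry is dropped, i.e. the value is reduced mod 2^(k * len).
}

fn main() {
    // 200 + 3*2^8 = 968 in base 2^8, least-significant limb first.
    let mut limbs = [200i64, 3, 0];
    normalize_base2k(&mut limbs, 8);
    assert_eq!(limbs, [-56, 4, 0]); // still 968: -56 + 4*256
}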
@@ -5,6 +5,7 @@ mod vec_znx;
 mod vec_znx_big;
 mod vec_znx_dft;
 mod vmp_pmat;
+mod zn;

 pub use module::FFT64;

@@ -25,8 +25,8 @@ use crate::cpu_spqlios::{
 };

 unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64 {
-    fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
-        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t, n as u64) as usize }
+    fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
+        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr() as *const module_info_t) as usize }
     }
 }
@@ -54,12 +54,11 @@ where
         assert_eq!(res.n(), a.n());
     }

-    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
+    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes());

     unsafe {
         vec_znx::vec_znx_normalize_base2k(
             module.ptr() as *const module_info_t,
-            a.n() as u64,
             basek as u64,
             res.at_mut_ptr(res_col, 0),
             res.size() as u64,
@@ -88,12 +87,11 @@ where
 {
     let mut a: VecZnx<&mut [u8]> = a.to_mut();

-    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes(a.n()));
+    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_normalize_tmp_bytes());

     unsafe {
         vec_znx::vec_znx_normalize_base2k(
             module.ptr() as *const module_info_t,
-            a.n() as u64,
             basek as u64,
             a.at_mut_ptr(a_col, 0),
             a.size() as u64,
@@ -569,8 +569,8 @@ unsafe impl VecZnxBigNegateInplaceImpl<Self> for FFT64 {
 }

 unsafe impl VecZnxBigNormalizeTmpBytesImpl<Self> for FFT64 {
-    fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
-        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr(), n as u64) as usize }
+    fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
+        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(module.ptr()) as usize }
     }
 }
@@ -598,11 +598,10 @@ where
         assert_eq!(res.n(), a.n());
     }

-    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes(a.n()));
+    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_big_normalize_tmp_bytes());
     unsafe {
         vec_znx::vec_znx_normalize_base2k(
             module.ptr(),
-            a.n() as u64,
             basek as u64,
             res.at_mut_ptr(res_col, 0),
             res.size() as u64,
@@ -36,8 +36,8 @@ unsafe impl VecZnxDftAllocImpl<Self> for FFT64 {
 }

 unsafe impl VecZnxDftToVecZnxBigTmpBytesImpl<Self> for FFT64 {
-    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<Self>, n: usize) -> usize {
-        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr(), n as u64) as usize }
+    fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<Self>) -> usize {
+        unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(module.ptr()) as usize }
     }
 }
@@ -61,7 +61,7 @@ unsafe impl VecZnxDftToVecZnxBigImpl<Self> for FFT64 {
         assert_eq!(res.n(), a.n())
     }

-    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes(a.n()));
+    let (tmp_bytes, _) = scratch.take_slice(module.vec_znx_dft_to_vec_znx_big_tmp_bytes());

     let min_size: usize = res.size().min(a.size());

@@ -41,18 +41,10 @@ unsafe impl VmpPMatAllocImpl<FFT64> for FFT64 {
 }

 unsafe impl VmpPrepareTmpBytesImpl<FFT64> for FFT64 {
-    fn vmp_prepare_tmp_bytes_impl(
-        module: &Module<FFT64>,
-        n: usize,
-        rows: usize,
-        cols_in: usize,
-        cols_out: usize,
-        size: usize,
-    ) -> usize {
+    fn vmp_prepare_tmp_bytes_impl(module: &Module<FFT64>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
         unsafe {
             vmp::vmp_prepare_tmp_bytes(
                 module.ptr(),
-                n as u64,
                 (rows * cols_in) as u64,
                 (cols_out * size) as u64,
             ) as usize
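Worth noting in the hunk above: the four logical dimensions of a prepared vector-matrix-product matrix collapse into the two the C call expects, nrows = rows * cols_in and ncols = cols_out * size. A tiny sketch of that mapping, with hypothetical names:

// Hypothetical illustration of how the 4-D VMP shape collapses into the
// (nrows, ncols) pair handed to vmp_prepare_tmp_bytes.
fn vmp_ffi_dims(rows: usize, cols_in: usize, cols_out: usize, size: usize) -> (u64, u64) {
    ((rows * cols_in) as u64, (cols_out * size) as u64)
}

fn main() {
    // e.g. a 4-row gadget matrix with 2 input columns, 2 output columns,
    // and 3 limbs per entry:
    assert_eq!(vmp_ffi_dims(4, 2, 2, 3), (8, 6));
}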
@@ -102,8 +94,7 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
         );
     }

-    let (tmp_bytes, _) =
-        scratch.take_slice(module.vmp_prepare_tmp_bytes(res.n(), a.rows(), a.cols_in(), a.cols_out(), a.size()));
+    let (tmp_bytes, _) = scratch.take_slice(module.vmp_prepare_tmp_bytes(a.rows(), a.cols_in(), a.cols_out(), a.size()));

     unsafe {
         vmp::vmp_prepare_contiguous(
@@ -121,7 +112,6 @@ unsafe impl VmpPMatPrepareImpl<FFT64> for FFT64 {
 unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
     fn vmp_apply_tmp_bytes_impl(
         module: &Module<FFT64>,
-        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -132,7 +122,6 @@ unsafe impl VmpApplyTmpBytesImpl<FFT64> for FFT64 {
         unsafe {
             vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 module.ptr(),
-                n as u64,
                 (res_size * b_cols_out) as u64,
                 (a_size * b_cols_in) as u64,
                 (b_rows * b_cols_in) as u64,
@@ -174,7 +163,6 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
     }

     let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
-        res.n(),
         res.size(),
         a.size(),
         b.rows(),
@@ -201,7 +189,6 @@ unsafe impl VmpApplyImpl<FFT64> for FFT64 {
 unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
     fn vmp_apply_add_tmp_bytes_impl(
         module: &Module<FFT64>,
-        n: usize,
         res_size: usize,
         a_size: usize,
         b_rows: usize,
@@ -212,7 +199,6 @@ unsafe impl VmpApplyAddTmpBytesImpl<FFT64> for FFT64 {
         unsafe {
             vmp::vmp_apply_dft_to_dft_tmp_bytes(
                 module.ptr(),
-                n as u64,
                 (res_size * b_cols_out) as u64,
                 (a_size * b_cols_in) as u64,
                 (b_rows * b_cols_in) as u64,
@@ -254,7 +240,6 @@ unsafe impl VmpApplyAddImpl<FFT64> for FFT64 {
     }

     let (tmp_bytes, _) = scratch.take_slice(module.vmp_apply_tmp_bytes(
-        res.n(),
         res.size(),
         a.size(),
         b.rows(),
201 poulpy-backend/src/cpu_spqlios/fft64/zn.rs Normal file
@@ -0,0 +1,201 @@
+use poulpy_hal::{
+    api::{TakeSlice, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
+    layouts::{Scratch, Zn, ZnToMut},
+    oep::{
+        TakeSliceImpl, ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl,
+        ZnNormalizeInplaceImpl,
+    },
+    source::Source,
+};
+use rand_distr::Normal;
+
+use crate::cpu_spqlios::{FFT64, ffi::zn64};
+
+unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64
+where
+    Self: TakeSliceImpl<Self>,
+{
+    fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
+    where
+        A: ZnToMut,
+    {
+        let mut a: Zn<&mut [u8]> = a.to_mut();
+
+        let (tmp_bytes, _) = scratch.take_slice(n * size_of::<i64>());
+
+        unsafe {
+            zn64::zn64_normalize_base2k_ref(
+                n as u64,
+                basek as u64,
+                a.at_mut_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(a_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                tmp_bytes.as_mut_ptr(),
+            );
+        }
+    }
+}
+
+unsafe impl ZnFillUniformImpl<Self> for FFT64 {
+    fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
+    where
+        R: ZnToMut,
+    {
+        let mut a: Zn<&mut [u8]> = res.to_mut();
+        let base2k: u64 = 1 << basek;
+        let mask: u64 = base2k - 1;
+        let base2k_half: i64 = (base2k >> 1) as i64;
+        (0..k.div_ceil(basek)).for_each(|j| {
+            a.at_mut(res_col, j)[..n]
+                .iter_mut()
+                .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
+        })
+    }
+}
+
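A quick standalone check of the sampling above: each coefficient is drawn uniformly from [0, 2^basek) and shifted down by 2^(basek-1), so the stored digit is centered in [-2^(basek-1), 2^(basek-1)). A minimal sketch, with a plain function in place of poulpy's Source:

// Hypothetical recomputation of the centered-uniform mapping used by
// zn_fill_uniform_impl.
fn center(sample: u64, basek: u32) -> i64 {
    let base2k = 1u64 << basek;
    (sample as i64) - (base2k >> 1) as i64
}

fn main() {
    let basek = 8;
    // Extremes of the uniform range map to the ends of the centered range:
    assert_eq!(center(0, basek), -128);
    assert_eq!(center(255, basek), 127);
}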
+unsafe impl ZnFillDistF64Impl<Self> for FFT64 {
+    fn zn_fill_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
+        n: usize,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: ZnToMut,
+    {
+        let mut a: Zn<&mut [u8]> = res.to_mut();
+        assert!(
+            (bound.log2().ceil() as i64) < 64,
+            "invalid bound: ceil(log2(bound))={} > 63",
+            (bound.log2().ceil() as i64)
+        );
+
+        let limb: usize = k.div_ceil(basek) - 1;
+        let basek_rem: usize = (limb + 1) * basek - k;
+
+        if basek_rem != 0 {
+            a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
+                let mut dist_f64: f64 = dist.sample(source);
+                while dist_f64.abs() > bound {
+                    dist_f64 = dist.sample(source)
+                }
+                *a = (dist_f64.round() as i64) << basek_rem;
+            });
+        } else {
+            a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
+                let mut dist_f64: f64 = dist.sample(source);
+                while dist_f64.abs() > bound {
+                    dist_f64 = dist.sample(source)
+                }
+                *a = dist_f64.round() as i64
+            });
+        }
+    }
+}
+
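A short worked example of the limb arithmetic above: with k = 20 and basek = 8, the noise lands in limb ceil(20/8) - 1 = 2 and basek_rem = 3*8 - 20 = 4, so each rounded sample is shifted left by 4 bits to sit at precision k rather than at the limb boundary:

// Hypothetical recomputation of the limb index and shift used by
// zn_fill_dist_f64_impl when k is not a multiple of basek.
fn noise_position(k: usize, basek: usize) -> (usize, usize) {
    let limb = k.div_ceil(basek) - 1; // limb receiving the noise
    let basek_rem = (limb + 1) * basek - k; // left-shift inside that limb
    (limb, basek_rem)
}

fn main() {
    assert_eq!(noise_position(20, 8), (2, 4)); // k not a limb multiple
    assert_eq!(noise_position(24, 8), (2, 0)); // exact multiple: no shift
}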
+unsafe impl ZnAddDistF64Impl<Self> for FFT64 {
+    fn zn_add_dist_f64_impl<R, D: rand::prelude::Distribution<f64>>(
+        n: usize,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        dist: D,
+        bound: f64,
+    ) where
+        R: ZnToMut,
+    {
+        let mut a: Zn<&mut [u8]> = res.to_mut();
+        assert!(
+            (bound.log2().ceil() as i64) < 64,
+            "invalid bound: ceil(log2(bound))={} > 63",
+            (bound.log2().ceil() as i64)
+        );
+
+        let limb: usize = k.div_ceil(basek) - 1;
+        let basek_rem: usize = (limb + 1) * basek - k;
+
+        if basek_rem != 0 {
+            a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
+                let mut dist_f64: f64 = dist.sample(source);
+                while dist_f64.abs() > bound {
+                    dist_f64 = dist.sample(source)
+                }
+                *a += (dist_f64.round() as i64) << basek_rem;
+            });
+        } else {
+            a.at_mut(res_col, limb)[..n].iter_mut().for_each(|a| {
+                let mut dist_f64: f64 = dist.sample(source);
+                while dist_f64.abs() > bound {
+                    dist_f64 = dist.sample(source)
+                }
+                *a += dist_f64.round() as i64
+            });
+        }
+    }
+}
+
+unsafe impl ZnFillNormalImpl<Self> for FFT64
+where
+    Self: ZnFillDistF64Impl<Self>,
+{
+    fn zn_fill_normal_impl<R>(
+        n: usize,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: ZnToMut,
+    {
+        Self::zn_fill_dist_f64_impl(
+            n,
+            basek,
+            res,
+            res_col,
+            k,
+            source,
+            Normal::new(0.0, sigma).unwrap(),
+            bound,
+        );
+    }
+}
+
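For completeness, a minimal standalone version of the bounded Gaussian draw that ZnFillNormalImpl and ZnAddNormalImpl delegate to, assuming the rand and rand_distr crates, with poulpy's Source replaced by a generic Rng:

use rand::Rng;
use rand_distr::{Distribution, Normal};

// Hypothetical sketch: sample from N(0, sigma) and reject anything whose
// magnitude exceeds `bound`, mirroring the loop in zn_fill_dist_f64_impl.
fn sample_bounded_gaussian<R: Rng>(rng: &mut R, sigma: f64, bound: f64) -> i64 {
    let normal = Normal::new(0.0, sigma).unwrap();
    loop {
        let x = normal.sample(rng);
        if x.abs() <= bound {
            return x.round() as i64;
        }
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    let s = sample_bounded_gaussian(&mut rng, 3.2, 19.2);
    assert!(s.abs() <= 20); // rounded samples stay within the bound
}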
+unsafe impl ZnAddNormalImpl<Self> for FFT64
+where
+    Self: ZnAddDistF64Impl<Self>,
+{
+    fn zn_add_normal_impl<R>(
+        n: usize,
+        basek: usize,
+        res: &mut R,
+        res_col: usize,
+        k: usize,
+        source: &mut Source,
+        sigma: f64,
+        bound: f64,
+    ) where
+        R: ZnToMut,
+    {
+        Self::zn_add_dist_f64_impl(
+            n,
+            basek,
+            res,
+            res_col,
+            k,
+            source,
+            Normal::new(0.0, sigma).unwrap(),
+            bound,
+        );
+    }
+}