Add Zn type

This commit is contained in:
Pro7ech
2025-08-21 12:16:53 +02:00
parent ccd94e36cc
commit bf513dc555
129 changed files with 1400 additions and 686 deletions

View File

@@ -5,6 +5,7 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
mod znx_base;
pub use module::*;
@@ -14,4 +15,5 @@ pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
pub use zn::*;
pub use znx_base::*;

View File

@@ -2,18 +2,18 @@ use crate::layouts::{Backend, ScalarZnxToRef, SvpPPolOwned, SvpPPolToMut, SvpPPo
/// Allocates as [crate::layouts::SvpPPol].
pub trait SvpPPolAlloc<B: Backend> {
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B>;
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::layouts::SvpPPol].
pub trait SvpPPolAllocBytes {
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize;
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::layouts::MatZnx].
/// User must ensure that bytes is memory aligned and that it length is equal to [SvpPPolAllocBytes].
pub trait SvpPPolFromBytes<B: Backend> {
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B>;
}
/// Prepare a [crate::layouts::ScalarZnx] into an [crate::layouts::SvpPPol].

View File

@@ -7,7 +7,7 @@ use crate::{
pub trait VecZnxNormalizeTmpBytes {
/// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize;
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
}
pub trait VecZnxNormalize<B: Backend> {

View File

@@ -7,18 +7,18 @@ use crate::{
/// Allocates as [crate::layouts::VecZnxBig].
pub trait VecZnxBigAlloc<B: Backend> {
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B>;
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
}
/// Returns the size in bytes to allocate a [crate::layouts::VecZnxBig].
pub trait VecZnxBigAllocBytes {
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize;
}
/// Consume a vector of bytes into a [crate::layouts::VecZnxBig].
/// User must ensure that bytes is memory aligned and that it length is equal to [VecZnxBigAllocBytes].
pub trait VecZnxBigFromBytes<B: Backend> {
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
}
#[allow(clippy::too_many_arguments)]
@@ -187,7 +187,7 @@ pub trait VecZnxBigNegateInplace<B: Backend> {
}
pub trait VecZnxBigNormalizeTmpBytes {
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize;
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
}
pub trait VecZnxBigNormalize<B: Backend> {

View File

@@ -3,19 +3,19 @@ use crate::layouts::{
};
pub trait VecZnxDftAlloc<B: Backend> {
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B>;
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftFromBytes<B: Backend> {
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
}
pub trait VecZnxDftAllocBytes {
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize;
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftToVecZnxBigTmpBytes {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize;
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize;
}
pub trait VecZnxDftToVecZnxBig<B: Backend> {

View File

@@ -1,27 +1,19 @@
use crate::layouts::{Backend, MatZnxToRef, Scratch, VecZnxDftToMut, VecZnxDftToRef, VmpPMatOwned, VmpPMatToMut, VmpPMatToRef};
pub trait VmpPMatAlloc<B: Backend> {
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B>;
}
pub trait VmpPMatAllocBytes {
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPMatFromBytes<B: Backend> {
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B>;
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B>;
}
pub trait VmpPrepareTmpBytes {
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
pub trait VmpPrepare<B: Backend> {
@@ -35,7 +27,6 @@ pub trait VmpPrepare<B: Backend> {
pub trait VmpApplyTmpBytes {
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -81,7 +72,6 @@ pub trait VmpApply<B: Backend> {
pub trait VmpApplyAddTmpBytes {
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,

86
poulpy-hal/src/api/zn.rs Normal file
View File

@@ -0,0 +1,86 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Scratch, ZnToMut},
source::Source,
};
pub trait ZnNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut;
}
pub trait ZnFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{basek-1}, 2^{basek-1}\]
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnFillDistF64 {
fn zn_fill_dist_f64<R, D: Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnAddDistF64 {
/// Adds vector sampled according to the provided distribution, scaled by 2^{-k} and bounded to \[-bound, bound\].
fn zn_add_dist_f64<R, D: Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnFillNormal {
fn zn_fill_normal<R>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnAddNormal {
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn zn_add_normal<R>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}

View File

@@ -5,3 +5,4 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;

View File

@@ -8,8 +8,8 @@ impl<B> SvpPPolFromBytes<B> for Module<B>
where
B: Backend + SvpPPolFromBytesImpl<B>,
{
fn svp_ppol_from_bytes(&self, n: usize, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
B::svp_ppol_from_bytes_impl(n, cols, bytes)
fn svp_ppol_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> SvpPPolOwned<B> {
B::svp_ppol_from_bytes_impl(self.n(), cols, bytes)
}
}
@@ -17,8 +17,8 @@ impl<B> SvpPPolAlloc<B> for Module<B>
where
B: Backend + SvpPPolAllocImpl<B>,
{
fn svp_ppol_alloc(&self, n: usize, cols: usize) -> SvpPPolOwned<B> {
B::svp_ppol_alloc_impl(n, cols)
fn svp_ppol_alloc(&self, cols: usize) -> SvpPPolOwned<B> {
B::svp_ppol_alloc_impl(self.n(), cols)
}
}
@@ -26,8 +26,8 @@ impl<B> SvpPPolAllocBytes for Module<B>
where
B: Backend + SvpPPolAllocBytesImpl<B>,
{
fn svp_ppol_alloc_bytes(&self, n: usize, cols: usize) -> usize {
B::svp_ppol_alloc_bytes_impl(n, cols)
fn svp_ppol_alloc_bytes(&self, cols: usize) -> usize {
B::svp_ppol_alloc_bytes_impl(self.n(), cols)
}
}

View File

@@ -22,8 +22,8 @@ impl<B> VecZnxNormalizeTmpBytes for Module<B>
where
B: Backend + VecZnxNormalizeTmpBytesImpl<B>,
{
fn vec_znx_normalize_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_normalize_tmp_bytes_impl(self, n)
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
B::vec_znx_normalize_tmp_bytes_impl(self)
}
}

View File

@@ -24,8 +24,8 @@ impl<B> VecZnxBigAlloc<B> for Module<B>
where
B: Backend + VecZnxBigAllocImpl<B>,
{
fn vec_znx_big_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxBigOwned<B> {
B::vec_znx_big_alloc_impl(n, cols, size)
fn vec_znx_big_alloc(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
B::vec_znx_big_alloc_impl(self.n(), cols, size)
}
}
@@ -33,8 +33,8 @@ impl<B> VecZnxBigFromBytes<B> for Module<B>
where
B: Backend + VecZnxBigFromBytesImpl<B>,
{
fn vec_znx_big_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
B::vec_znx_big_from_bytes_impl(n, cols, size, bytes)
fn vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
B::vec_znx_big_from_bytes_impl(self.n(), cols, size, bytes)
}
}
@@ -42,8 +42,8 @@ impl<B> VecZnxBigAllocBytes for Module<B>
where
B: Backend + VecZnxBigAllocBytesImpl<B>,
{
fn vec_znx_big_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_big_alloc_bytes_impl(n, cols, size)
fn vec_znx_big_alloc_bytes(&self, cols: usize, size: usize) -> usize {
B::vec_znx_big_alloc_bytes_impl(self.n(), cols, size)
}
}
@@ -283,8 +283,8 @@ impl<B> VecZnxBigNormalizeTmpBytes for Module<B>
where
B: Backend + VecZnxBigNormalizeTmpBytesImpl<B>,
{
fn vec_znx_big_normalize_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_big_normalize_tmp_bytes_impl(self, n)
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
B::vec_znx_big_normalize_tmp_bytes_impl(self)
}
}

View File

@@ -20,8 +20,8 @@ impl<B> VecZnxDftFromBytes<B> for Module<B>
where
B: Backend + VecZnxDftFromBytesImpl<B>,
{
fn vec_znx_dft_from_bytes(&self, n: usize, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
B::vec_znx_dft_from_bytes_impl(n, cols, size, bytes)
fn vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
B::vec_znx_dft_from_bytes_impl(self.n(), cols, size, bytes)
}
}
@@ -29,8 +29,8 @@ impl<B> VecZnxDftAllocBytes for Module<B>
where
B: Backend + VecZnxDftAllocBytesImpl<B>,
{
fn vec_znx_dft_alloc_bytes(&self, n: usize, cols: usize, size: usize) -> usize {
B::vec_znx_dft_alloc_bytes_impl(n, cols, size)
fn vec_znx_dft_alloc_bytes(&self, cols: usize, size: usize) -> usize {
B::vec_znx_dft_alloc_bytes_impl(self.n(), cols, size)
}
}
@@ -38,8 +38,8 @@ impl<B> VecZnxDftAlloc<B> for Module<B>
where
B: Backend + VecZnxDftAllocImpl<B>,
{
fn vec_znx_dft_alloc(&self, n: usize, cols: usize, size: usize) -> VecZnxDftOwned<B> {
B::vec_znx_dft_alloc_impl(n, cols, size)
fn vec_znx_dft_alloc(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
B::vec_znx_dft_alloc_impl(self.n(), cols, size)
}
}
@@ -47,8 +47,8 @@ impl<B> VecZnxDftToVecZnxBigTmpBytes for Module<B>
where
B: Backend + VecZnxDftToVecZnxBigTmpBytesImpl<B>,
{
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self, n: usize) -> usize {
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self, n)
fn vec_znx_dft_to_vec_znx_big_tmp_bytes(&self) -> usize {
B::vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(self)
}
}

View File

@@ -14,8 +14,8 @@ impl<B> VmpPMatAlloc<B> for Module<B>
where
B: Backend + VmpPMatAllocImpl<B>,
{
fn vmp_pmat_alloc(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
B::vmp_pmat_alloc_impl(n, rows, cols_in, cols_out, size)
fn vmp_pmat_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> VmpPMatOwned<B> {
B::vmp_pmat_alloc_impl(self.n(), rows, cols_in, cols_out, size)
}
}
@@ -23,8 +23,8 @@ impl<B> VmpPMatAllocBytes for Module<B>
where
B: Backend + VmpPMatAllocBytesImpl<B>,
{
fn vmp_pmat_alloc_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_pmat_alloc_bytes_impl(n, rows, cols_in, cols_out, size)
fn vmp_pmat_alloc_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_pmat_alloc_bytes_impl(self.n(), rows, cols_in, cols_out, size)
}
}
@@ -32,16 +32,8 @@ impl<B> VmpPMatFromBytes<B> for Module<B>
where
B: Backend + VmpPMatFromBytesImpl<B>,
{
fn vmp_pmat_from_bytes(
&self,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> VmpPMatOwned<B> {
B::vmp_pmat_from_bytes_impl(n, rows, cols_in, cols_out, size, bytes)
fn vmp_pmat_from_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: Vec<u8>) -> VmpPMatOwned<B> {
B::vmp_pmat_from_bytes_impl(self.n(), rows, cols_in, cols_out, size, bytes)
}
}
@@ -49,8 +41,8 @@ impl<B> VmpPrepareTmpBytes for Module<B>
where
B: Backend + VmpPrepareTmpBytesImpl<B>,
{
fn vmp_prepare_tmp_bytes(&self, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_prepare_tmp_bytes_impl(self, n, rows, cols_in, cols_out, size)
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
B::vmp_prepare_tmp_bytes_impl(self, rows, cols_in, cols_out, size)
}
}
@@ -73,7 +65,6 @@ where
{
fn vmp_apply_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -82,7 +73,7 @@ where
b_size: usize,
) -> usize {
B::vmp_apply_tmp_bytes_impl(
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}
@@ -107,7 +98,6 @@ where
{
fn vmp_apply_add_tmp_bytes(
&self,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -116,7 +106,7 @@ where
b_size: usize,
) -> usize {
B::vmp_apply_add_tmp_bytes_impl(
self, n, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
self, res_size, a_size, b_rows, b_cols_in, b_cols_out, b_size,
)
}
}

View File

@@ -0,0 +1,114 @@
use crate::{
api::{ZnAddDistF64, ZnAddNormal, ZnFillDistF64, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace},
layouts::{Backend, Module, Scratch, ZnToMut},
oep::{ZnAddDistF64Impl, ZnAddNormalImpl, ZnFillDistF64Impl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
source::Source,
};
impl<B> ZnNormalizeInplace<B> for Module<B>
where
B: Backend + ZnNormalizeInplaceImpl<B>,
{
fn zn_normalize_inplace<A>(&self, n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut,
{
B::zn_normalize_inplace_impl(n, basek, a, a_col, scratch)
}
}
impl<B> ZnFillUniform for Module<B>
where
B: Backend + ZnFillUniformImpl<B>,
{
fn zn_fill_uniform<R>(&self, n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: ZnToMut,
{
B::zn_fill_uniform_impl(n, basek, res, res_col, k, source);
}
}
impl<B> ZnFillDistF64 for Module<B>
where
B: Backend + ZnFillDistF64Impl<B>,
{
fn zn_fill_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_fill_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> ZnAddDistF64 for Module<B>
where
B: Backend + ZnAddDistF64Impl<B>,
{
fn zn_add_dist_f64<R, D: rand::prelude::Distribution<f64>>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut,
{
B::zn_add_dist_f64_impl(n, basek, res, res_col, k, source, dist, bound);
}
}
impl<B> ZnFillNormal for Module<B>
where
B: Backend + ZnFillNormalImpl<B>,
{
fn zn_fill_normal<R>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
B::zn_fill_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
}
}
impl<B> ZnAddNormal for Module<B>
where
B: Backend + ZnAddNormalImpl<B>,
{
fn zn_add_normal<R>(
&self,
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
B::zn_add_normal_impl(n, basek, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -3,7 +3,7 @@ use rug::{Assign, Float};
use crate::{
api::{ZnxInfos, ZnxView, ZnxViewMut, ZnxZero},
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef},
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef},
};
impl<D: DataMut> VecZnx<D> {
@@ -202,3 +202,90 @@ impl<D: DataRef> VecZnx<D> {
});
}
}
impl<D: DataMut> Zn<D> {
pub fn encode_i64(&mut self, basek: usize, k: usize, data: i64, log_max: usize) {
let size: usize = k.div_ceil(basek);
#[cfg(debug_assertions)]
{
let a: Zn<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(basek)={} > a.size()={}",
size,
a.size()
);
}
let k_rem: usize = basek - (k % basek);
let mut a: Zn<&mut [u8]> = self.to_mut();
(0..a.size()).for_each(|j| a.at_mut(0, j)[0] = 0);
// If 2^{basek} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + k_rem < 63 || k_rem == basek {
a.at_mut(0, size - 1)[0] = data;
} else {
let mask: i64 = (1 << basek) - 1;
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_mut(0, j_rev)[0] = (data >> (j * basek)) & mask;
})
}
// Case where prec % k != 0.
if k_rem != basek {
let steps: usize = size.min(log_max.div_ceil(basek));
(size - steps..size).rev().for_each(|j| {
a.at_mut(0, j)[0] <<= k_rem;
})
}
}
}
impl<D: DataRef> Zn<D> {
pub fn decode_i64(&self, basek: usize, k: usize) -> i64 {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(basek);
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
(0..size).for_each(|j| {
let x: i64 = a.at(0, j)[0];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
res = (res << basek) + x;
}
});
res
}
pub fn decode_float(&self, basek: usize) -> Float {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (basek * size) as u32;
// 2^{basek}
let base: Float = Float::with_val(prec, (1 << basek) as f64);
let mut res: Float = Float::with_val(prec, (1 << basek) as f64);
// y[i] = sum x[j][i] * 2^{-basek*j}
(0..size).for_each(|i| {
if i == 0 {
res.assign(a.at(0, size - i - 1)[0]);
res /= &base;
} else {
res += Float::with_val(prec, a.at(0, size - i - 1)[0]);
res /= &base;
}
});
res
}
}

View File

@@ -10,6 +10,7 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
pub use mat_znx::*;
pub use module::*;
@@ -21,6 +22,7 @@ pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
pub use zn::*;
pub trait Data = PartialEq + Eq + Sized;
pub trait DataRef = Data + AsRef<[u8]>;

View File

@@ -0,0 +1,255 @@
use std::fmt;
use crate::{
alloc_aligned,
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, WriterTo},
source::Source,
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct Zn<D: Data> {
pub data: D,
pub n: usize,
pub cols: usize,
pub size: usize,
pub max_size: usize,
}
impl<D: DataRef> ToOwnedDeep for Zn<D> {
type Owned = Zn<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
Zn {
data: self.data.as_ref().to_vec(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataRef> fmt::Debug for Zn<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl<D: Data> ZnxInfos for Zn<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data> ZnxSliceSize for Zn<D> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D: Data> DataView for Zn<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data> DataViewMut for Zn<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef> ZnxView for Zn<D> {
type Scalar = i64;
}
impl Zn<Vec<u8>> {
pub fn rsh_scratch_space(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
}
impl<D: DataMut> ZnxZero for Zn<D> {
fn zero(&mut self) {
self.raw_mut().fill(0)
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(0);
}
}
impl Zn<Vec<u8>> {
pub fn alloc_bytes(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<i64>()
}
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::alloc_bytes(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: Data> Zn<D> {
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: DataRef> fmt::Display for Zn<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Zn(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}
impl<D: DataMut> FillUniform for Zn<D> {
fn fill_uniform(&mut self, source: &mut Source) {
source.fill_bytes(self.data.as_mut());
}
}
impl<D: DataMut> Reset for Zn<D> {
fn reset(&mut self) {
self.zero();
self.n = 0;
self.cols = 0;
self.size = 0;
self.max_size = 0;
}
}
pub type ZnOwned = Zn<Vec<u8>>;
pub type ZnMut<'a> = Zn<&'a mut [u8]>;
pub type ZnRef<'a> = Zn<&'a [u8]>;
pub trait ZnToRef {
fn to_ref(&self) -> Zn<&[u8]>;
}
impl<D: DataRef> ZnToRef for Zn<D> {
fn to_ref(&self) -> Zn<&[u8]> {
Zn {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
pub trait ZnToMut {
fn to_mut(&mut self) -> Zn<&mut [u8]>;
}
impl<D: DataMut> ZnToMut for Zn<D> {
fn to_mut(&mut self) -> Zn<&mut [u8]> {
Zn {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataMut> ReaderFrom for Zn<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.cols = reader.read_u64::<LittleEndian>()? as usize;
self.size = reader.read_u64::<LittleEndian>()? as usize;
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef> WriterTo for Zn<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.cols as u64)?;
writer.write_u64::<LittleEndian>(self.size as u64)?;
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}

View File

@@ -5,6 +5,7 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
pub use module::*;
pub use scratch::*;
@@ -13,3 +14,4 @@ pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_dft::*;
pub use vmp_pmat::*;
pub use zn::*;

View File

@@ -10,7 +10,7 @@ use crate::{
/// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxNormalizeTmpBytesImpl<B: Backend> {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)

View File

@@ -263,7 +263,7 @@ pub unsafe trait VecZnxBigNegateInplaceImpl<B: Backend> {
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxBigNormalizeTmpBytesImpl<B: Backend> {
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
fn vec_znx_big_normalize_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)

View File

@@ -32,7 +32,7 @@ pub unsafe trait VecZnxDftAllocBytesImpl<B: Backend> {
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftToVecZnxBigTmpBytesImpl<B: Backend> {
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>, n: usize) -> usize;
fn vec_znx_dft_to_vec_znx_big_tmp_bytes_impl(module: &Module<B>) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)

View File

@@ -38,14 +38,7 @@ pub unsafe trait VmpPMatFromBytesImpl<B: Backend> {
/// * See TODO for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VmpPrepareTmpBytesImpl<B: Backend> {
fn vmp_prepare_tmp_bytes_impl(
module: &Module<B>,
n: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
) -> usize;
fn vmp_prepare_tmp_bytes_impl(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
@@ -67,7 +60,6 @@ pub unsafe trait VmpPMatPrepareImpl<B: Backend> {
pub unsafe trait VmpApplyTmpBytesImpl<B: Backend> {
fn vmp_apply_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,
@@ -97,7 +89,6 @@ pub unsafe trait VmpApplyImpl<B: Backend> {
pub unsafe trait VmpApplyAddTmpBytesImpl<B: Backend> {
fn vmp_apply_add_tmp_bytes_impl(
module: &Module<B>,
n: usize,
res_size: usize,
a_size: usize,
b_rows: usize,

97
poulpy-hal/src/oep/zn.rs Normal file
View File

@@ -0,0 +1,97 @@
use rand_distr::Distribution;
use crate::{
layouts::{Backend, Scratch, ZnToMut},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [zn_normalize_base2k_ref](https://github.com/phantomzone-org/spqlios-arithmetic/blob/32a3f5fcce9863b58e949f2dfd5abc1bfbaa09b4/spqlios/arithmetic/zn64.c#L9) for reference code.
/// * See [crate::api::ZnxNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
fn zn_normalize_inplace_impl<A>(n: usize, basek: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillUniformImpl<B: Backend> {
fn zn_fill_uniform_impl<R>(n: usize, basek: usize, res: &mut R, res_col: usize, k: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillDistF64Impl<B: Backend> {
fn zn_fill_dist_f64_impl<R, D: Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnAddDistF64] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnAddDistF64Impl<B: Backend> {
fn zn_add_dist_f64_impl<R, D: Distribution<f64>>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
dist: D,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnFillNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillNormalImpl<B: Backend> {
fn zn_fill_normal_impl<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [crate::api::ZnAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnAddNormalImpl<B: Backend> {
fn zn_add_normal_impl<R>(
n: usize,
basek: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}

View File

@@ -51,14 +51,13 @@ where
let mut scratch = ScratchOwned::alloc(
module.vmp_apply_tmp_bytes(
n,
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(n),
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, a_cols, a_size);
@@ -67,10 +66,10 @@ where
a.at_mut(i, a_size - 1)[i + 1] = 1;
});
let mut vmp: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut vmp: VmpPMat<Vec<u8>, B> = module.vmp_pmat_alloc(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(n, mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(n, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(mat_cols_out, mat_size);
let mut mat: MatZnx<Vec<u8>> = MatZnx::alloc(n, mat_rows, mat_cols_in, mat_cols_out, mat_size);
@@ -86,7 +85,7 @@ where
module.vmp_prepare(&mut vmp, &mat, scratch.borrow());
let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(n, a_cols, a_size);
let mut a_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft_from_vec_znx(1, 0, &mut a_dft, i, &a, i);
});