Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 13:16:44 +01:00)
multiple fixes to base2k, including updating svp to take column interleaving into account
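In summary, this commit: renames the ScalarZnxAlloc methods (bytes_of_scalar -> bytes_of_scalar_znx, new_scalar -> new_scalar_znx, new_scalar_from_bytes -> new_scalar_znx_from_bytes) and Scratch::tmp_scalar_slice -> Scratch::tmp_slice; adds Scratch::tmp_scalar and Scratch::tmp_scalar_dft, backed by new pub(crate) byte-size helpers and from_data constructors; extends the svp apply FFI declaration and its call sites with explicit column counts (res_cols, a_cols) so the backend can handle column-interleaved data; and adds a fmt::Display implementation for VecZnxBig<D, FFT64>.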
@@ -20,7 +20,7 @@ fn main() {
     let mut source: Source = Source::new(seed);

     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-    let mut s: ScalarZnx<Vec<u8>> = module.new_scalar(1);
+    let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
     s.fill_ternary_prob(0, 0.5, &mut source);

     // Buffer to store s in the DFT domain
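For reference, a minimal sketch of the sampling step after the rename; `module` and `source` are assumed to come from the example's surrounding setup and are not reproduced here.

    // Sketch only: `module` and `source` are assumed from the example's setup above.
    let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1); // was: module.new_scalar(1)
    s.fill_ternary_prob(0, 0.5, &mut source); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)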
@@ -39,8 +39,10 @@ unsafe extern "C" {
         module: *const MODULE,
         res: *const VEC_ZNX_DFT,
         res_size: u64,
+        res_cols: u64,
         ppol: *const SVP_PPOL,
         a: *const VEC_ZNX_DFT,
         a_size: u64,
+        a_cols: u64,
     );
 }
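The new res_cols and a_cols arguments follow from the column interleaving mentioned in the commit message: when the limbs of several columns share one buffer, the backend needs the column count, not just the limb count, to compute strides. The sketch below is an assumption-laden illustration of what such a layout implies; the helper name and the exact ordering are not taken from poulpy.

    // Illustration only (hypothetical helper, assumed layout): with the limbs of all
    // columns interleaved as [limb 0 of col 0..C-1, limb 1 of col 0..C-1, ...],
    // locating (col, limb) in a flat coefficient buffer requires the column count.
    fn interleaved_offset(n: usize, cols: usize, col: usize, limb: usize) -> usize {
        (limb * cols + col) * n
        // e.g. interleaved_offset(1024, 2, 1, 3) == (3 * 2 + 1) * 1024 == 7168
    }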
@@ -177,7 +177,7 @@ impl Scratch {
         }
     }

-    pub fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
+    pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
         let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());

         unsafe {
@@ -188,6 +188,24 @@ impl Scratch {
         }
     }

+    pub fn tmp_scalar<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
+
+        (
+            ScalarZnx::from_data(take_slice, module.n(), cols),
+            Self::new(rem_slice),
+        )
+    }
+
+    pub fn tmp_scalar_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
+
+        (
+            ScalarZnxDft::from_data(take_slice, module.n(), cols),
+            Self::new(rem_slice),
+        )
+    }
+
     pub fn tmp_vec_znx_dft<B: Backend>(
         &mut self,
         module: &Module<B>,
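The two new helpers follow the same take-and-return pattern as tmp_vec_znx_dft: each carves a typed temporary out of the scratch buffer and hands back the remaining scratch. A minimal usage sketch, with import paths assumed; only the method signatures are taken from the diff.

    // Sketch only: crate paths and the surrounding setup are assumptions.
    use poulpy::{FFT64, Module, Scratch};

    fn with_temporaries(module: &Module<FFT64>, scratch: &mut Scratch) {
        // Untyped scratch bytes (renamed from tmp_scalar_slice to tmp_slice).
        let (_bytes, scratch) = scratch.tmp_slice::<u8>(256);
        // Typed temporaries carved from the remaining scratch space.
        let (_tmp_s, scratch) = scratch.tmp_scalar(module, 1);
        let (_tmp_s_dft, _rest) = scratch.tmp_scalar_dft(module, 1);
        // ... use the temporaries for DFT-domain work here ...
    }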
@@ -279,7 +279,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
             );
         }

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_tmp_bytes(
+        let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
             res.size(),
             a.size(),
             b.rows(),
@@ -98,24 +98,34 @@ impl<D: From<Vec<u8>>> ScalarZnx<D> {

 pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;

+pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
+    ScalarZnxOwned::bytes_of::<i64>(module.n(), cols)
+}
+
 pub trait ScalarZnxAlloc {
-    fn bytes_of_scalar(&self, cols: usize) -> usize;
-    fn new_scalar(&self, cols: usize) -> ScalarZnxOwned;
-    fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
+    fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
+    fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
+    fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
 }

 impl<B: Backend> ScalarZnxAlloc for Module<B> {
-    fn bytes_of_scalar(&self, cols: usize) -> usize {
+    fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
         ScalarZnxOwned::bytes_of::<i64>(self.n(), cols)
     }
-    fn new_scalar(&self, cols: usize) -> ScalarZnxOwned {
+    fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
         ScalarZnxOwned::new::<i64>(self.n(), cols)
     }
-    fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
+    fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
         ScalarZnxOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
     }
 }

+impl<D> ScalarZnx<D> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
+        Self { data, n, cols }
+    }
+}
+
 pub trait ScalarZnxToRef {
     fn to_ref(&self) -> ScalarZnx<&[u8]>;
 }
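The renames bring the scalar allocators in line with the *_znx naming used elsewhere, and the new pub(crate) from_data constructor is what Scratch::tmp_scalar uses to wrap borrowed scratch bytes. A short sketch of the renamed ScalarZnxAlloc methods in use; crate paths are assumed and sizes are illustrative.

    // Sketch only: import paths are assumptions; method names follow the diff.
    use poulpy::{Backend, Module, ScalarZnxAlloc};

    fn allocate_scalars<B: Backend>(module: &Module<B>) {
        let cols = 2;
        let _size = module.bytes_of_scalar_znx(cols); // was: bytes_of_scalar
        let _owned = module.new_scalar_znx(cols); // was: new_scalar
        let backing = vec![0u8; module.bytes_of_scalar_znx(cols)];
        let _wrapped = module.new_scalar_znx_from_bytes(cols, backing); // was: new_scalar_from_bytes
    }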
@@ -52,6 +52,10 @@ impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
     type Scalar = f64;
 }

+pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
+    ScalarZnxDftOwned::bytes_of(module, cols)
+}
+
 impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
     pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
         unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
@@ -79,6 +83,17 @@ impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
     }
 }

+impl<D, B: Backend> ScalarZnxDft<D, B> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
+        Self {
+            data,
+            n,
+            cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
 pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;

 pub trait ScalarZnxDftToRef<B: Backend> {
@@ -71,9 +71,11 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
                 self.ptr,
                 res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                 res.size() as u64,
+                res.cols() as u64,
                 a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                 b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
                 b.size() as u64,
+                b.cols() as u64,
             )
         }
     }
@@ -90,9 +92,11 @@ impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
                 self.ptr,
                 res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                 res.size() as u64,
+                res.cols() as u64,
                 a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                 res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
                 res.size() as u64,
+                res.cols() as u64,
             )
         }
     }
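Both svp apply call sites now forward the column counts of the output and of the vector operand alongside their sizes, matching the widened FFI declaration above; together with the per-column base pointers (at_mut_ptr(res_col, 0), at_ptr(b_col, 0) or at_ptr(res_col, 0)) this gives the backend what it needs to step across column-interleaved limbs.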
@@ -2,6 +2,7 @@ use crate::ffi::vec_znx_big;
 use crate::znx_base::{ZnxInfos, ZnxView};
 use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned};
 use std::marker::PhantomData;
+use std::{cmp::min, fmt};

 pub struct VecZnxBig<D, B: Backend> {
     data: D,
@@ -162,3 +163,38 @@ impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
         }
     }
 }
+
+impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(
+            f,
+            "VecZnx(n={}, cols={}, size={})",
+            self.n, self.cols, self.size
+        )?;
+
+        for col in 0..self.cols {
+            writeln!(f, "Column {}:", col)?;
+            for size in 0..self.size {
+                let coeffs = self.at(col, size);
+                write!(f, " Size {}: [", size)?;
+
+                let max_show = 100;
+                let show_count = coeffs.len().min(max_show);
+
+                for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
+                    if i > 0 {
+                        write!(f, ", ")?;
+                    }
+                    write!(f, "{}", coeff)?;
+                }
+
+                if coeffs.len() > max_show {
+                    write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
+                }
+
+                writeln!(f, "]")?;
+            }
+        }
+        Ok(())
+    }
+}
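With fmt::Display implemented, an FFT64-backed VecZnxBig can be printed directly for debugging. A minimal sketch relying only on the impl above; the import path is assumed.

    // Sketch only: the import path is an assumption; printing relies on the Display
    // impl added above (header line, then up to 100 coefficients per column/limb pair).
    use poulpy::{FFT64, VecZnxBig};

    fn dump_big<D: AsRef<[u8]>>(v: &VecZnxBig<D, FFT64>) {
        println!("{}", v);
    }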
@@ -528,7 +528,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
             // assert_alignement(tmp_bytes.as_ptr());
         }

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
+        let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
             &self,
         ));
         unsafe {
@@ -141,7 +141,7 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
         let mut res_mut = res.to_mut();
         let a_ref = a.to_ref();

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_idft_tmp_bytes());
+        let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());

         let min_size: usize = min(res_mut.size(), a_ref.size());

@@ -175,7 +175,7 @@ impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
             assert_eq!(res.n(), self.n());
         }

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes());
+        let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());

         unsafe {
             vec_znx::vec_znx_normalize_base2k(
@@ -203,7 +203,7 @@ impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
             assert_eq!(a.n(), self.n());
         }

-        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vec_znx_normalize_tmp_bytes());
+        let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());

         unsafe {
             vec_znx::vec_znx_normalize_base2k(
@@ -171,7 +171,7 @@ where
         let k_rem: usize = k % log_base2k;

         if k_rem != 0 {
-            let (carry, _) = scratch.tmp_scalar_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
+            let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));

             unsafe {
                 std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());