mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
rework as discussed
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned};
|
||||
use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
use sampling::source::Source;
|
||||
|
||||
// pub const SCALAR_ZNX_ROWS: usize = 1;
|
||||
// pub const SCALAR_ZNX_SIZE: usize = 1;
|
||||
|
||||
pub struct Scalar<D> {
|
||||
data: D,
|
||||
n: usize,
|
||||
@@ -30,7 +27,9 @@ impl<D> ZnxInfos for Scalar<D> {
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for Scalar<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
@@ -70,19 +69,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> Scalar<D> {
|
||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||
self.at_mut(col, 0).shuffle(source);
|
||||
}
|
||||
|
||||
// pub fn alias_as_vec_znx(&self) -> VecZnx {
|
||||
// VecZnx {
|
||||
// inner: ZnxBase {
|
||||
// n: self.n(),
|
||||
// rows: 1,
|
||||
// cols: 1,
|
||||
// size: 1,
|
||||
// data: Vec::new(),
|
||||
// ptr: self.ptr() as *mut u8,
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> Scalar<D> {
|
||||
@@ -116,7 +102,6 @@ pub trait ScalarAlloc {
|
||||
fn bytes_of_scalar(&self, cols: usize) -> usize;
|
||||
fn new_scalar(&self, cols: usize) -> ScalarOwned;
|
||||
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned;
|
||||
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarAlloc for Module<B> {
|
||||
@@ -129,31 +114,62 @@ impl<B: Backend> ScalarAlloc for Module<B> {
|
||||
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned {
|
||||
ScalarOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
|
||||
}
|
||||
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar {
|
||||
// Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
|
||||
// }
|
||||
}
|
||||
|
||||
// impl<B: Backend> ZnxAlloc<B> for Scalar {
|
||||
// type Scalar = i64;
|
||||
pub trait ScalarToRef {
|
||||
fn to_ref(&self) -> Scalar<&[u8]>;
|
||||
}
|
||||
|
||||
// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
|
||||
// Self {
|
||||
// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes),
|
||||
// }
|
||||
// }
|
||||
pub trait ScalarToMut {
|
||||
fn to_mut(&mut self) -> Scalar<&mut [u8]>;
|
||||
}
|
||||
|
||||
// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
|
||||
// debug_assert_eq!(
|
||||
// _rows, SCALAR_ZNX_ROWS,
|
||||
// "rows != {} not supported for Scalar",
|
||||
// SCALAR_ZNX_ROWS
|
||||
// );
|
||||
// debug_assert_eq!(
|
||||
// _size, SCALAR_ZNX_SIZE,
|
||||
// "rows != {} not supported for Scalar",
|
||||
// SCALAR_ZNX_SIZE
|
||||
// );
|
||||
// module.n() * cols * std::mem::size_of::<self::Scalar>()
|
||||
// }
|
||||
// }
|
||||
impl ScalarToMut for Scalar<Vec<u8>> {
|
||||
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
|
||||
Scalar {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarToRef for Scalar<Vec<u8>> {
|
||||
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||
Scalar {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarToMut for Scalar<&mut [u8]> {
|
||||
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
|
||||
Scalar {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarToRef for Scalar<&mut [u8]> {
|
||||
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||
Scalar {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarToRef for Scalar<&[u8]> {
|
||||
fn to_ref(&self) -> Scalar<&[u8]> {
|
||||
Scalar {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user