rework as discussed

This commit is contained in:
Jean-Philippe Bossuat
2025-05-05 17:35:35 +02:00
parent bd105497fd
commit ffa363804b
16 changed files with 1154 additions and 1153 deletions

View File

@@ -1,13 +1,10 @@
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned};
use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
// pub const SCALAR_ZNX_ROWS: usize = 1;
// pub const SCALAR_ZNX_SIZE: usize = 1;
pub struct Scalar<D> {
data: D,
n: usize,
@@ -30,7 +27,9 @@ impl<D> ZnxInfos for Scalar<D> {
fn size(&self) -> usize {
1
}
}
impl<D> ZnxSliceSize for Scalar<D> {
fn sl(&self) -> usize {
self.n()
}
@@ -70,19 +69,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> Scalar<D> {
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.at_mut(col, 0).shuffle(source);
}
// pub fn alias_as_vec_znx(&self) -> VecZnx {
// VecZnx {
// inner: ZnxBase {
// n: self.n(),
// rows: 1,
// cols: 1,
// size: 1,
// data: Vec::new(),
// ptr: self.ptr() as *mut u8,
// },
// }
// }
}
impl<D: From<Vec<u8>>> Scalar<D> {
@@ -116,7 +102,6 @@ pub trait ScalarAlloc {
fn bytes_of_scalar(&self, cols: usize) -> usize;
fn new_scalar(&self, cols: usize) -> ScalarOwned;
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned;
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar;
}
impl<B: Backend> ScalarAlloc for Module<B> {
@@ -129,31 +114,62 @@ impl<B: Backend> ScalarAlloc for Module<B> {
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarOwned {
ScalarOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
}
// fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar {
// Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
// }
}
// impl<B: Backend> ZnxAlloc<B> for Scalar {
// type Scalar = i64;
pub trait ScalarToRef {
fn to_ref(&self) -> Scalar<&[u8]>;
}
// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
// Self {
// inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes),
// }
// }
pub trait ScalarToMut {
fn to_mut(&mut self) -> Scalar<&mut [u8]>;
}
// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
// debug_assert_eq!(
// _rows, SCALAR_ZNX_ROWS,
// "rows != {} not supported for Scalar",
// SCALAR_ZNX_ROWS
// );
// debug_assert_eq!(
// _size, SCALAR_ZNX_SIZE,
// "rows != {} not supported for Scalar",
// SCALAR_ZNX_SIZE
// );
// module.n() * cols * std::mem::size_of::<self::Scalar>()
// }
// }
impl ScalarToMut for Scalar<Vec<u8>> {
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
Scalar {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl ScalarToRef for Scalar<Vec<u8>> {
fn to_ref(&self) -> Scalar<&[u8]> {
Scalar {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl ScalarToMut for Scalar<&mut [u8]> {
fn to_mut(&mut self) -> Scalar<&mut [u8]> {
Scalar {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl ScalarToRef for Scalar<&mut [u8]> {
fn to_ref(&self) -> Scalar<&[u8]> {
Scalar {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl ScalarToRef for Scalar<&[u8]> {
fn to_ref(&self) -> Scalar<&[u8]> {
Scalar {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}