use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, Module, ZnxView, ZnxViewMut, alloc_aligned}; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; // pub const SCALAR_ZNX_ROWS: usize = 1; // pub const SCALAR_ZNX_SIZE: usize = 1; pub struct Scalar { data: D, n: usize, cols: usize, } impl ZnxInfos for Scalar { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { 1 } fn sl(&self) -> usize { self.n() } } impl DataView for Scalar { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for Scalar { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl> ZnxView for Scalar { type Scalar = i64; } impl + AsRef<[u8]>> Scalar { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); self.at_mut(col, 0) .iter_mut() .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); } pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { assert!(hw <= self.n()); self.at_mut(col, 0)[..hw] .iter_mut() .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } // pub fn alias_as_vec_znx(&self) -> VecZnx { // VecZnx { // inner: ZnxBase { // n: self.n(), // rows: 1, // cols: 1, // size: 1, // data: Vec::new(), // ptr: self.ptr() as *mut u8, // }, // } // } } impl>> Scalar { pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } pub(crate) fn new(n: usize, cols: usize) -> Self { let data = alloc_aligned::(Self::bytes_of::(n, cols)); Self { data: data.into(), n, cols, } } pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::bytes_of::(n, cols)); Self { data: data.into(), n, cols, } } } pub type ScalarOwned = Scalar>; pub trait ScalarAlloc { fn bytes_of_scalar(&self, cols: usize) -> usize; fn new_scalar(&self, cols: usize) -> ScalarOwned; fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned; // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar; } impl ScalarAlloc for Module { fn bytes_of_scalar(&self, cols: usize) -> usize { ScalarOwned::bytes_of::(self.n(), cols) } fn new_scalar(&self, cols: usize) -> ScalarOwned { ScalarOwned::new::(self.n(), cols) } fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarOwned { ScalarOwned::new_from_bytes::(self.n(), cols, bytes) } // fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar { // Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes) // } } // impl ZnxAlloc for Scalar { // type Scalar = i64; // fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self { // Self { // inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes), // } // } // fn bytes_of(module: &Module, _rows: usize, cols: usize, _size: usize) -> usize { // debug_assert_eq!( // _rows, SCALAR_ZNX_ROWS, // "rows != {} not supported for Scalar", // SCALAR_ZNX_ROWS // ); // debug_assert_eq!( // _size, SCALAR_ZNX_SIZE, // "rows != {} not supported for Scalar", // SCALAR_ZNX_SIZE // ); // module.n() * cols * std::mem::size_of::() // } // }