use crate::znx_base::ZnxInfos; use crate::{ Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned, }; use rand::seq::SliceRandom; use rand_core::RngCore; use rand_distr::{Distribution, weighted::WeightedIndex}; use sampling::source::Source; pub struct ScalarZnx { data: D, n: usize, cols: usize, } impl ZnxInfos for ScalarZnx { fn cols(&self) -> usize { self.cols } fn rows(&self) -> usize { 1 } fn n(&self) -> usize { self.n } fn size(&self) -> usize { 1 } } impl ZnxSliceSize for ScalarZnx { fn sl(&self) -> usize { self.n() } } impl DataView for ScalarZnx { type D = D; fn data(&self) -> &Self::D { &self.data } } impl DataViewMut for ScalarZnx { fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } impl> ZnxView for ScalarZnx { type Scalar = i64; } impl + AsRef<[u8]>> ScalarZnx { pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) { let choices: [i64; 3] = [-1, 0, 1]; let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0]; let dist: WeightedIndex = WeightedIndex::new(&weights).unwrap(); self.at_mut(col, 0) .iter_mut() .for_each(|x: &mut i64| *x = choices[dist.sample(source)]); } pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) { assert!(hw <= self.n()); self.at_mut(col, 0)[..hw] .iter_mut() .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1); self.at_mut(col, 0).shuffle(source); } } impl>> ScalarZnx { pub(crate) fn bytes_of(n: usize, cols: usize) -> usize { n * cols * size_of::() } pub(crate) fn new(n: usize, cols: usize) -> Self { let data = alloc_aligned::(Self::bytes_of::(n, cols)); Self { data: data.into(), n, cols, } } pub(crate) fn new_from_bytes(n: usize, cols: usize, bytes: impl Into>) -> Self { let data: Vec = bytes.into(); assert!(data.len() == Self::bytes_of::(n, cols)); Self { data: data.into(), n, cols, } } } pub type ScalarZnxOwned = ScalarZnx>; pub(crate) fn bytes_of_scalar_znx(module: &Module, 
cols: usize) -> usize { ScalarZnxOwned::bytes_of::(module.n(), cols) } pub trait ScalarZnxAlloc { fn bytes_of_scalar_znx(&self, cols: usize) -> usize; fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned; fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned; } impl ScalarZnxAlloc for Module { fn bytes_of_scalar_znx(&self, cols: usize) -> usize { ScalarZnxOwned::bytes_of::(self.n(), cols) } fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned { ScalarZnxOwned::new::(self.n(), cols) } fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec) -> ScalarZnxOwned { ScalarZnxOwned::new_from_bytes::(self.n(), cols, bytes) } } impl ScalarZnx { pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self { Self { data, n, cols } } } pub trait ScalarZnxToRef { fn to_ref(&self) -> ScalarZnx<&[u8]>; } pub trait ScalarZnxToMut { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>; } impl ScalarZnxToMut for ScalarZnx> { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { ScalarZnx { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, } } } impl VecZnxToMut for ScalarZnx> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, cols: self.cols, size: 1, } } } impl ScalarZnxToRef for ScalarZnx> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { data: self.data.as_slice(), n: self.n, cols: self.cols, } } } impl VecZnxToRef for ScalarZnx> { fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, cols: self.cols, size: 1, } } } impl ScalarZnxToMut for ScalarZnx<&mut [u8]> { fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { ScalarZnx { data: self.data, n: self.n, cols: self.cols, } } } impl VecZnxToMut for ScalarZnx<&mut [u8]> { fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data, n: self.n, cols: self.cols, size: 1, } } } impl ScalarZnxToRef for ScalarZnx<&mut [u8]> { fn to_ref(&self) -> ScalarZnx<&[u8]> { ScalarZnx { data: self.data, n: self.n, cols: self.cols, } } } 
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
    /// Borrows the underlying bytes immutably and presents them as a
    /// `VecZnx` with the same `n` and `cols` and with `size` set to 1.
    fn to_ref(&self) -> VecZnx<&[u8]> {
        let bytes: &[u8] = &*self.data;
        let (n, cols) = (self.n, self.cols);
        VecZnx {
            data: bytes,
            n,
            cols,
            size: 1,
        }
    }
}

impl ScalarZnxToRef for ScalarZnx<&[u8]> {
    /// Returns another `ScalarZnx` over the same byte slice with the same
    /// `n` and `cols`.
    fn to_ref(&self) -> ScalarZnx<&[u8]> {
        let bytes: &[u8] = self.data;
        let (n, cols) = (self.n, self.cols);
        ScalarZnx {
            data: bytes,
            n,
            cols,
        }
    }
}

impl VecZnxToRef for ScalarZnx<&[u8]> {
    /// Presents the same byte slice as a `VecZnx` with the same `n` and
    /// `cols` and with `size` set to 1.
    fn to_ref(&self) -> VecZnx<&[u8]> {
        let bytes: &[u8] = self.data;
        let (n, cols) = (self.n, self.cols);
        VecZnx {
            data: bytes,
            n,
            cols,
            size: 1,
        }
    }
}