mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
Crates io (#76)
* crates re-organisation * fixed typo in layout & added test for vmp_apply * updated dependencies
This commit is contained in:
committed by
GitHub
parent
dce4d82706
commit
a1de248567
247
poulpy-hal/src/layouts/scalar_znx.rs
Normal file
247
poulpy-hal/src/layouts/scalar_znx.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
|
||||
use crate::{
|
||||
alloc_aligned,
|
||||
api::{DataView, DataViewMut, FillUniform, Reset, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero},
|
||||
layouts::{Data, DataMut, DataRef, ReaderFrom, ToOwnedDeep, VecZnx, WriterTo},
|
||||
source::Source,
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
pub struct ScalarZnx<D: Data> {
|
||||
pub data: D,
|
||||
pub n: usize,
|
||||
pub cols: usize,
|
||||
}
|
||||
|
||||
impl<D: DataRef> ToOwnedDeep for ScalarZnx<D> {
|
||||
type Owned = ScalarZnx<Vec<u8>>;
|
||||
fn to_owned_deep(&self) -> Self::Owned {
|
||||
ScalarZnx {
|
||||
data: self.data.as_ref().to_vec(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxInfos for ScalarZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> ZnxSliceSize for ScalarZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataView for ScalarZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Data> DataViewMut for ScalarZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ZnxView for ScalarZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: DataMut> ScalarZnx<D> {
|
||||
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
|
||||
let choices: [i64; 3] = [-1, 0, 1];
|
||||
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
||||
let dist: WeightedIndex<f64> = WeightedIndex::new(weights).unwrap();
|
||||
self.at_mut(col, 0)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
|
||||
assert!(hw <= self.n());
|
||||
self.at_mut(col, 0)[..hw]
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||
self.at_mut(col, 0).shuffle(source);
|
||||
}
|
||||
|
||||
pub fn fill_binary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
|
||||
let choices: [i64; 2] = [0, 1];
|
||||
let weights: [f64; 2] = [1.0 - prob, prob];
|
||||
let dist: WeightedIndex<f64> = WeightedIndex::new(weights).unwrap();
|
||||
self.at_mut(col, 0)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
||||
}
|
||||
|
||||
pub fn fill_binary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
|
||||
assert!(hw <= self.n());
|
||||
self.at_mut(col, 0)[..hw]
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = (source.next_u32() & 1) as i64);
|
||||
self.at_mut(col, 0).shuffle(source);
|
||||
}
|
||||
|
||||
pub fn fill_binary_block(&mut self, col: usize, block_size: usize, source: &mut Source) {
|
||||
assert!(self.n().is_multiple_of(block_size));
|
||||
let max_idx: u64 = (block_size + 1) as u64;
|
||||
let mask_idx: u64 = (1 << ((u64::BITS - max_idx.leading_zeros()) as u64)) - 1;
|
||||
for block in self.at_mut(col, 0).chunks_mut(block_size) {
|
||||
let idx: usize = source.next_u64n(max_idx, mask_idx) as usize;
|
||||
if idx != block_size {
|
||||
block[idx] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnx<Vec<u8>> {
|
||||
pub fn alloc_bytes(n: usize, cols: usize) -> usize {
|
||||
n * cols * size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn alloc(n: usize, cols: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned::<u8>(Self::alloc_bytes(n, cols));
|
||||
Self { data, n, cols }
|
||||
}
|
||||
|
||||
pub fn from_bytes(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::alloc_bytes(n, cols));
|
||||
Self { data, n, cols }
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ZnxZero for ScalarZnx<D> {
|
||||
fn zero(&mut self) {
|
||||
self.raw_mut().fill(0)
|
||||
}
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
self.at_mut(i, j).fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> FillUniform for ScalarZnx<D> {
|
||||
fn fill_uniform(&mut self, source: &mut Source) {
|
||||
source.fill_bytes(self.data.as_mut());
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> Reset for ScalarZnx<D> {
|
||||
fn reset(&mut self) {
|
||||
self.zero();
|
||||
self.n = 0;
|
||||
self.cols = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
|
||||
|
||||
impl<D: Data> ScalarZnx<D> {
|
||||
pub fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self { data, n, cols }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxToRef {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataRef> ScalarZnxToRef for ScalarZnx<D> {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxToMut {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl<D: DataMut> ScalarZnxToMut for ScalarZnx<D> {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> ScalarZnx<D> {
|
||||
pub fn as_vec_znx(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_ref(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
max_size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> ScalarZnx<D> {
|
||||
pub fn as_vec_znx_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
max_size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for ScalarZnx<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
self.n = reader.read_u64::<LittleEndian>()? as usize;
|
||||
self.cols = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
let buf: &mut [u8] = self.data.as_mut();
|
||||
if buf.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
format!("self.data.len()={} != read len={}", buf.len(), len),
|
||||
));
|
||||
}
|
||||
reader.read_exact(&mut buf[..len])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for ScalarZnx<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
writer.write_u64::<LittleEndian>(self.n as u64)?;
|
||||
writer.write_u64::<LittleEndian>(self.cols as u64)?;
|
||||
let buf: &[u8] = self.data.as_ref();
|
||||
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
|
||||
writer.write_all(buf)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user