From b82a1ca1b4cfb124a65d76f1e6bd01249555ecbe Mon Sep 17 00:00:00 2001
From: Janmajaya Mall
Date: Sun, 4 May 2025 18:39:28 +0530
Subject: [PATCH] wip

---
 base2k/src/lib.rs             | 196 ++++++++++++++++++---
 base2k/src/mat_znx_dft.rs     |  48 +++---
 base2k/src/mat_znx_dft_ops.rs | 278 ++++++++++++++++++--------------
 base2k/src/vec_znx.rs         |  53 +++---
 base2k/src/vec_znx_big.rs     |  46 +++++-
 base2k/src/vec_znx_big_ops.rs |  35 ++--
 base2k/src/vec_znx_dft.rs     |  16 +-
 base2k/src/vec_znx_dft_ops.rs |  11 +-
 base2k/src/vec_znx_ops.rs     |  20 ++-
 base2k/src/znx_base.rs        | 294 +++++++++-------------------------
 10 files changed, 551 insertions(+), 446 deletions(-)

diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs
index 7ae1193..f33ce60 100644
--- a/base2k/src/lib.rs
+++ b/base2k/src/lib.rs
@@ -18,10 +18,17 @@ pub mod vec_znx_dft_ops;
 pub mod vec_znx_ops;
 pub mod znx_base;
 
+use std::{
+    any::type_name,
+    ops::{DerefMut, Sub},
+};
+
 pub use encoding::*;
 pub use mat_znx_dft::*;
 pub use mat_znx_dft_ops::*;
 pub use module::*;
+use rand_core::le;
+use rand_distr::num_traits::sign;
 pub use sampling::*;
 pub use scalar_znx::*;
 pub use scalar_znx_dft::*;
@@ -126,28 +133,177 @@ pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
     )
 }
 
-pub(crate) struct ScratchSpace {
-    // data: D,
-}
+pub struct ScratchOwned(Vec<u8>);
 
-impl ScratchSpace {
-    fn tmp_vec_znx_dft<B: Backend>(&mut self, n: usize, cols: usize, size: usize) -> VecZnxDft<&mut [u8], B> {
-        todo!()
+impl ScratchOwned {
+    pub fn new(byte_count: usize) -> Self {
+        let data: Vec<u8> = alloc_aligned(byte_count);
+        Self(data)
     }
 
-    fn tmp_vec_znx_big<B: Backend>(&mut self, n: usize, cols: usize, size: usize) -> VecZnxBig<&mut [u8], B> {
-        todo!()
-    }
-
-    fn vec_znx_big_normalize_tmp_bytes(&mut self, module: &Module<FFT64>) -> &mut [u8] {
-        todo!()
-    }
-
-    fn vmp_apply_dft_tmp_bytes(&mut self, module: &Module<FFT64>) -> &mut [u8] {
-        todo!()
-    }
-
-    fn vmp_apply_dft_to_dft_tmp_bytes(&mut self, module: &Module<FFT64>) -> &mut [u8] {
-        todo!()
+    pub fn borrow(&mut self) -> &mut ScratchBorr {
+        ScratchBorr::new(&mut self.0)
     }
 }
+
+pub struct ScratchBorr {
+    data: [u8],
+}
+
+impl ScratchBorr {
+    fn new(data: &mut [u8]) -> &mut Self {
+        unsafe { &mut *(data as *mut [u8] as *mut Self) }
+    }
+
+    fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
+        let ptr = data.as_mut_ptr();
+        let self_len = data.len();
+
+        let aligned_offset = ptr.align_offset(DEFAULTALIGN);
+        let aligned_len = self_len.saturating_sub(aligned_offset);
+
+        if let Some(rem_len) = aligned_len.checked_sub(take_len) {
+            unsafe {
+                let rem_ptr = ptr.add(aligned_offset).add(take_len);
+                let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
+
+                let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
+
+                return (take_slice, rem_slice);
+            }
+        } else {
+            panic!(
+                "Attempted to take {} from scratch with {} aligned bytes left",
+                take_len, aligned_len,
+            );
+        }
+    }
+
+    fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
+
+        unsafe {
+            (
+                &mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
+                Self::new(rem_slice),
+            )
+        }
+    }
+
+    fn tmp_vec_znx_dft<B: Backend>(
+        &mut self,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
+
+        (
+            VecZnxDft::from_data(take_slice, module.n(), cols, size),
+            Self::new(rem_slice),
+        )
+    }
+
+    fn tmp_vec_znx_big<'a, D: From<&'a mut [u8]>, B: Backend>(
+        &mut self,
+        module: &Module<B>,
+        cols: usize,
+        size: usize,
+    ) -> (VecZnxBig<D, B>, &mut Self) {
+        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
+
+        (
+            VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size),
+            Self::new(rem_slice),
+        )
+    }
+}
+
+// pub struct ScratchBorrowed<'a> {
+//     data: &'a mut [u8],
+// }
+
+// impl<'a> ScratchBorrowed<'a> {
+//     fn take_slice<T>(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) {
+//         let ptr = self.data.as_mut_ptr();
+//         let self_len = self.data.len();
+
+//         //TODO(Jay): print the offset sometimes, just to check
+//         let aligned_offset = ptr.align_offset(DEFAULTALIGN);
+//         let aligned_len = self_len.saturating_sub(aligned_offset);
+
+//         let take_len_bytes = take_len * std::mem::size_of::<T>();
+
+//         if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) {
+//             unsafe {
+//                 let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes);
+//                 let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
+
+//                 let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes);
+
+//                 return (take_slice, ScratchBorrowed { data: rem_slice });
+//             }
+//         } else {
+//             panic!(
+//                 "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left",
+//                 take_len_bytes,
+//                 take_len,
+//                 type_name::<T>(),
+//                 aligned_len
+//             );
+//         }
+//     }
+
+//     fn reborrow(&mut self) -> ScratchBorrowed<'a> {
+//         //(Jay)TODO: `data: &mut *self.data` does not work because the lifetime of &mut self is different from 'a.
+//         // But it feels that there should be a simpler impl. than the one below
+//         Self {
+//             data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) },
+//         }
+//     }
+
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) {
+//         let (data, re_scratch) = self.take_slice::<u8>(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
+//         (
+//             VecZnxDft::from_data(data, module.n(), cols, size),
+//             re_scratch,
+//         )
+//     }
+
+//     pub(crate) fn len(&self) -> usize {
+//         self.data.len()
+//     }
+// }
+
+// pub trait Scratch<D> {
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (D, &mut Self);
+// }
+
+// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr {
+//     fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) {
+//         let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
+//         (
+//             data,
+//             rem_scratch,
+//         )
+//     }
+
+//     // fn tmp_vec_znx_big(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) {
+//     //     // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size));
+//     //     // (
+//     //     //     VecZnxBig::from_data(data, module.n(), cols, size),
+//     //     //     re_scratch,
+//     //     // )
+//     // }
+
+//     // fn scalar_slice<T>(&mut self, len: usize) -> (&mut [T], Self) {
+//     //     self.take_slice::<T>(len)
+//     // }
+
+//     // fn reborrow(&mut self) -> Self {
+//     //     self.reborrow()
+//     // }
+// }
diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs
index 34c711a..7a39dd1 100644
--- a/base2k/src/mat_znx_dft.rs
+++ b/base2k/src/mat_znx_dft.rs
@@ -1,4 +1,4 @@
-use crate::znx_base::{GetZnxBase, ZnxBase, ZnxInfos};
+use crate::znx_base::ZnxInfos;
 use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};
 use std::marker::PhantomData;
@@ -111,26 +111,6 @@ impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
             _marker: PhantomData,
         }
     }
-
-    // pub fn from_bytes_borrow(
-    //     module: &Module<B>,
-    //     rows: usize,
-    //     cols_in: usize,
-    //     cols_out: usize,
-    //     size: usize,
-    //     bytes: &mut [u8],
-    // ) -> Self {
-    //     debug_assert_eq!(
-    //         bytes.len(),
-    //         Self::bytes_of(module, rows, cols_in, cols_out, size)
-    //     );
-    //     Self {
-    //         inner: ZnxBase::from_bytes_borrow(module.n(), rows, cols_out, size, bytes),
-    //         cols_in: cols_in,
-    //         cols_out: cols_out,
-    //         _marker: PhantomData,
-    //     }
-    // }
 }
 
@@ -170,3 +150,29 @@ impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
 }
 
 pub type MatZnxDftAllocOwned<B> = MatZnxDft<Vec<u8>, B>;
+
+impl<B: Backend> MatZnxDft<Vec<u8>, B> {
+    pub fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
+        MatZnxDft {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            size: self.size,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            _marker: PhantomData,
+        }
+    }
+
+    pub fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data.as_slice(),
+            n: self.n,
+            size: self.size,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            _marker: PhantomData,
+        }
+    }
+}
diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs
index 62b56a1..5ab44df 100644
--- a/base2k/src/mat_znx_dft_ops.rs
+++ b/base2k/src/mat_znx_dft_ops.rs
@@ -2,8 +2,8 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
-    VecZnxDftAlloc, VecZnxDftOps, assert_alignement, is_aligned,
+    Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
+    VecZnxDftAlloc, VecZnxDftOps,
 };
 
 pub trait MatZnxDftAlloc<B: Backend> {
@@ -36,12 +36,55 @@ pub trait MatZnxDftAlloc<B: Backend> {
     //     ) -> MatZnxDft<B>;
 }
 
-/// This trait implements methods for vector matrix product,
-/// that is, multiplying a [VecZnx] with a [MatZnxDft].
-pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
+pub trait MatZnxDftScratch {
     /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_prepare_row]
     fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
 
+    /// Returns the number of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
+
+    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_size`: size of the output [VecZnxDft].
+    /// * `a_size`: size of the input [VecZnx].
+    /// * `b_rows`: number of rows of the input [MatZnxDft].
+    /// * `b_size`: size of the input [MatZnxDft].
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+
+    /// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
+    ///
+    /// # Arguments
+    ///
+    /// * `c_size`: size of the output [VecZnxDft].
+    /// * `a_size`: size of the input [VecZnxDft].
+    /// * `b_rows`: number of rows of the input [MatZnxDft].
+    /// * `b_size`: size of the input [MatZnxDft].
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize;
+}
+
+/// This trait implements methods for vector matrix product,
+/// that is, multiplying a [VecZnx] with a [MatZnxDft].
+pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnx].
     ///
     /// # Arguments
     ///
@@ -58,12 +101,9 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         b_row: usize,
         b_col_in: usize,
         a: &VecZnx<Data>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
-    /// Returns the of bytes needed as scratch space for [MatZnxDftOps::vmp_extract_row]
-    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize;
-
     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
     ///
     /// # Arguments
     ///
@@ -78,7 +118,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         a: &MatZnxDft<Data, B>,
         b_row: usize,
         b_col_in: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
     /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
@@ -101,24 +141,6 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `row_i`: the index of the row to extract.
     fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, B>, a: &MatZnxDft<Data, B>, a_row: usize, a_col_in: usize);
 
-    /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: number of size of the output [VecZnxDft].
-    /// * `a_size`: number of size of the input [VecZnx].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: number of size of the input [MatZnxDft].
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        c_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize;
-
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     ///
     /// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
@@ -143,27 +165,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
     /// * `a`: the left operand [VecZnx] of the vector matrix product.
     /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-    fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchSpace);
-
-    /// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
-    ///
-    /// # Arguments
-    ///
-    /// * `c_size`: number of size of the output [VecZnxDft].
-    /// * `a_size`: number of size of the input [VecZnxDft].
-    /// * `rows`: number of rows of the input [MatZnxDft].
-    /// * `size`: number of size of the input [MatZnxDft].
-    fn vmp_apply_dft_to_dft_tmp_bytes(
-        &self,
-        c_cols: usize,
-        c_size: usize,
-        a_cols: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize;
+    fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchBorr);
 
     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -195,7 +197,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
         c: &mut VecZnxDft<DataMut, B>,
         a: &VecZnxDft<Data, B>,
         b: &MatZnxDft<Data, B>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 }
 
@@ -220,22 +222,70 @@ impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
     }
 }
 
-impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-    DataMut: AsMut<[u8]> + AsRef<[u8]>,
-    Data: AsRef<[u8]>,
-{
+impl MatZnxDftScratch for Module<FFT64> {
     fn vmp_prepare_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
-        <Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
+        <Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
     }
 
+    fn vmp_extract_row_tmp_bytes(&self, cols_out: usize, size: usize) -> usize {
+        <Self as VecZnxDftAlloc<FFT64>>::bytes_of_vec_znx_dft(self, cols_out, size)
+            + <Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(self)
+    }
+
+    fn vmp_apply_dft_tmp_bytes(
+        &self,
+        c_size: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
+        unsafe {
+            vmp::vmp_apply_dft_tmp_bytes(
+                self.ptr,
+                c_size as u64,
+                a_size as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
+            ) as usize
+        }
+    }
+
+    fn vmp_apply_dft_to_dft_tmp_bytes(
+        &self,
+        c_cols: usize,
+        c_size: usize,
+        a_cols: usize,
+        a_size: usize,
+        b_rows: usize,
+        b_cols_in: usize,
+        b_cols_out: usize,
+        b_size: usize,
+    ) -> usize {
+        unsafe {
+            vmp::vmp_apply_dft_to_dft_tmp_bytes(
+                self.ptr,
+                (c_size * c_cols) as u64,
+                (a_size * a_cols) as u64,
+                (b_rows * b_cols_in) as u64,
+                (b_size * b_cols_out) as u64,
+            ) as usize
+        }
+    }
+}
+
+impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
+where
+    DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>,
+    Data: AsRef<[u8]>,
+{
     fn vmp_prepare_row(
         &self,
         b: &mut MatZnxDft<DataMut, FFT64>,
         b_row: usize,
         b_col_in: usize,
         a: &VecZnx<Data>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -278,17 +328,13 @@ where
         let a_size: usize = a.size();
 
         // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
-        let mut a_dft = scratch.tmp_vec_znx_dft::<FFT64>(self.n(), cols_out, a_size);
+        let (mut a_dft, _) = scratch.tmp_vec_znx_dft(self, cols_out, a_size);
         (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
-
         Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
     }
 
     fn vmp_extract_row(
         &self,
         log_base2k: usize,
@@ -296,7 +342,7 @@ where
         a: &MatZnxDft<Data, FFT64>,
         a_row: usize,
         a_col_in: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -336,9 +382,9 @@ where
         let size: usize = b.size();
 
         // let (bytes_a_dft, tmp_bytes) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, size));
-        let mut b_dft = scratch.tmp_vec_znx_dft::<FFT64>(self.n(), cols_out, size);
+        let (mut b_dft, scratch) = scratch.tmp_vec_znx_dft(self, cols_out, size);
         Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
-        let mut b_big = scratch.tmp_vec_znx_big(self.n(), cols_out, size);
+        let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
         (0..cols_out).for_each(|i| {
             <Self as VecZnxDftOps<DataMut, Data, FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
             self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
         });
@@ -434,32 +480,12 @@ where
         }
     }
 
-    fn vmp_apply_dft_tmp_bytes(
-        &self,
-        res_size: usize,
-        a_size: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize {
-        unsafe {
-            vmp::vmp_apply_dft_tmp_bytes(
-                self.ptr,
-                res_size as u64,
-                a_size as u64,
-                (b_rows * b_cols_in) as u64,
-                (b_size * b_cols_out) as u64,
-            ) as usize
-        }
-    }
-
     fn vmp_apply_dft(
         &self,
         c: &mut VecZnxDft<DataMut, FFT64>,
         a: &VecZnx<Data>,
         b: &MatZnxDft<Data, FFT64>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -493,6 +519,16 @@ where
             // );
             // assert_alignement(tmp_bytes.as_ptr());
         }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as MatZnxDftScratch>::vmp_apply_dft_tmp_bytes(
+            self,
+            c.size(),
+            a.size(),
+            b.rows(),
+            b.cols_in(),
+            b.cols_out(),
+            b.size(),
+        ));
+
         unsafe {
             vmp::vmp_apply_dft(
                 self.ptr,
@@ -504,39 +540,17 @@ where
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 (b.rows() * b.cols_in()) as u64,
                 (b.size() * b.cols_out()) as u64,
-                scratch.vmp_apply_dft_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
-    fn vmp_apply_dft_to_dft_tmp_bytes(
-        &self,
-        res_cols: usize,
-        res_size: usize,
-        a_size: usize,
-        a_cols: usize,
-        b_rows: usize,
-        b_cols_in: usize,
-        b_cols_out: usize,
-        b_size: usize,
-    ) -> usize {
-        unsafe {
-            vmp::vmp_apply_dft_to_dft_tmp_bytes(
-                self.ptr,
-                (res_size * res_cols) as u64,
-                (a_size * a_cols) as u64,
-                (b_rows * b_cols_in) as u64,
-                (b_size * b_cols_out) as u64,
-            ) as usize
-        }
-    }
-
     fn vmp_apply_dft_to_dft(
         &self,
         c: &mut VecZnxDft<DataMut, FFT64>,
         a: &VecZnxDft<Data, FFT64>,
         b: &MatZnxDft<Data, FFT64>,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -572,6 +586,17 @@ where
             // );
             // assert_alignement(tmp_bytes.as_ptr());
         }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes(
+            c.cols(),
+            c.size(),
+            a.cols(),
+            a.size(),
+            b.rows(),
+            b.cols_in(),
+            b.cols_out(),
+            b.size(),
+        ));
         unsafe {
             vmp::vmp_apply_dft_to_dft(
                 self.ptr,
@@ -582,7 +607,7 @@ where
                 b.as_ptr() as *const vmp::vmp_pmat_t,
                 b.rows() as u64,
                 (b.size() * b.cols()) as u64,
-                scratch.vmp_apply_dft_to_dft_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             )
         }
     }
 
 #[cfg(test)]
 mod tests {
+    use crate::ScratchOwned;
     use crate::mat_znx_dft_ops::*;
     use crate::vec_znx_big_ops::*;
     use crate::vec_znx_dft_ops::*;
@@ -617,7 +643,9 @@ mod tests {
         // let mut tmp_bytes: Vec<u8> =
         //     alloc_aligned(module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) | module.vec_znx_big_normalize_tmp_bytes());
 
-        let mut scratch = ScratchSpace {};
+        let mut scratch = ScratchOwned::new(
+            2 * (module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + module.vec_znx_big_normalize_tmp_bytes()),
+        );
 
         let mut tmp_bytes: Vec<u8> =
             alloc_aligned::<u8>(<Module<FFT64> as VecZnxDftOps<Vec<u8>, Vec<u8>, _>>::vec_znx_idft_tmp_bytes(&module));
@@ -630,7 +658,9 @@ mod tests {
             module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
         });
 
-        module.vmp_prepare_row(&mut vmpmat_0, row_i, col_in, &a, &mut scratch);
+        module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow());
 
         // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
         module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
@@ -641,11 +671,25 @@ mod tests {
         assert_eq!(a_dft.raw(), b_dft.raw());
 
         // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-        module.vmp_extract_row(log_base2k, &mut b, &vmpmat_0, row_i, col_in, &mut scratch);
+        // module.vmp_extract_row(
+        //     log_base2k,
+        //     &mut b.to_mut(),
+        //     &vmpmat_0.to_ref(),
+        //     row_i,
+        //     col_in,
+        //     scratch.borrow(),
+        // );
 
         (0..mat_cols_out).for_each(|col_out| {
             module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
-            module.vec_znx_big_normalize(log_base2k, &mut a, col_out, &a_big, col_out, &mut scratch);
+            module.vec_znx_big_normalize(
+                log_base2k,
+                &mut a.to_mut(),
+                col_out,
+                &a_big.to_ref(),
+                col_out,
+                scratch.borrow(),
+            );
         });
 
         assert_eq!(a.raw(), b.raw());
diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs
index 3321f8e..b386604 100644
--- a/base2k/src/vec_znx.rs
+++ b/base2k/src/vec_znx.rs
@@ -97,11 +97,6 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
     pub fn switch_degree<DataOther: AsRef<[u8]>>(&mut self, col: usize, a: &VecZnx<DataOther>, col_a: usize) {
         switch_degree(self, col_a, a, col)
     }
-
-    // Prints the first `n` coefficients of each limb
-    // pub fn print(&self, n: usize, col: usize) {
-    //     (0..self.size()).for_each(|j| println!("{}: {:?}", j, &self.at(col, j)[..n]));
-    // }
 }
 
@@ -131,8 +126,6 @@ impl<D: From<Vec<u8>>> VecZnx<D> {
     }
 }
 
-//(Jay)TODO: Impl. truncate pow2 for Owned Vector
-
 /// Copies the coefficients of `a` on the receiver.
 /// Copy is done with the minimum size matching both backing arrays.
 /// Panics if the cols do not match.
@@ -148,12 +141,6 @@ where
     data_b[..size].copy_from_slice(&data_a[..size])
 }
 
-// if !self.borrowing() {
-//     self.inner
-//         .data
-//         .truncate(self.n() * self.cols() * (self.size() - k / log_base2k));
-// }
-
 fn normalize_tmp_bytes(n: usize) -> usize {
     n * std::mem::size_of::<i64>()
 }
@@ -190,26 +177,6 @@ fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>,
     }
 }
 
-// impl ZnxAlloc for VecZnx {
-//     type Scalar = i64;
-
-//     fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnx {
-//         debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
-//         VecZnx {
-//             inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_ROWS, cols, size, bytes),
-//         }
-//     }
-
-//     fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
-//         debug_assert_eq!(
-//             _rows, VEC_ZNX_ROWS,
-//             "rows != {} not supported for VecZnx",
-//             VEC_ZNX_ROWS
-//         );
-//         module.n() * cols * size * size_of::<i64>()
-//     }
-// }
-
 impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
@@ -248,3 +215,23 @@ impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
 pub type VecZnxOwned = VecZnx<Vec<u8>>;
 pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
 pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
+
+impl VecZnx<Vec<u8>> {
+    pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        VecZnx {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+
+    pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+}
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index 72b15d7..7442f11 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -53,13 +53,13 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
     type Scalar = i64;
 }
 
-impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
-    pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
-        unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
-    }
+pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
+    unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
+}
 
+impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
     pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
-        let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
+        let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -71,7 +71,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
 
     pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(module, cols, size));
+        assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -82,8 +82,42 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
     }
 }
 
+impl<D, B: Backend> VecZnxBig<D, B> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
+        Self {
+            data,
+            n,
+            cols,
+            size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
 pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
 
+impl<B: Backend> VecZnxBig<Vec<u8>, B> {
+    pub(crate) fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
+        VecZnxBig {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+        VecZnxBig {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: PhantomData,
+        }
+    }
+}
+
 // impl VecZnxBig {
 //     pub fn print(&self, n: usize, col: usize) {
 //         (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
index bb46802..20b4f2e 100644
--- a/base2k/src/vec_znx_big_ops.rs
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -1,6 +1,9 @@
 use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
-use crate::{Backend, DataView, FFT64, Module, ScratchSpace, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxOps, assert_alignement};
+use crate::{
+    Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement,
+    bytes_of_vec_znx_big,
+};
 
 pub trait VecZnxBigAlloc<B: Backend> {
     /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
@@ -113,9 +116,6 @@ pub trait VecZnxBigOps<DataMut, Data, B: Backend> {
     /// Subtracts `res` from `a` and stores the result on `res`.
     fn vec_znx_big_sub_small_b_inplace(&self, res: &mut VecZnxBig<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
 
-    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
-
     /// Normalizes `a` and stores the result on `b`.
     ///
     /// # Arguments
@@ -129,7 +129,7 @@ pub trait VecZnxBigOps<DataMut, Data, B: Backend> {
         res_col: usize,
         a: &VecZnxBig<Data, B>,
         a_col: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     );
 
     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
@@ -146,6 +146,11 @@ pub trait VecZnxBigOps<DataMut, Data, B: Backend> {
     fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut VecZnxBig<DataMut, B>, a_col: usize);
 }
 
+pub trait VecZnxBigScratch {
+    /// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
+}
+
 impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
     fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
         VecZnxBig::new(self, cols, size)
@@ -160,7 +165,7 @@ impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
     // }
 
     fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
-        VecZnxBigOwned::bytes_of(self, cols, size)
+        bytes_of_vec_znx_big(self, cols, size)
     }
 }
 
@@ -491,10 +496,6 @@ where
         }
     }
 
-    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
-        <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(self)
-    }
-
     fn vec_znx_big_normalize(
         &self,
         log_base2k: usize,
@@ -502,7 +503,7 @@ where
         res_col: usize,
         a: &VecZnxBig<Data, FFT64>,
         a_col: usize,
-        scratch: &mut ScratchSpace,
+        scratch: &mut ScratchBorr,
     ) {
         #[cfg(debug_assertions)]
         {
@@ -513,6 +514,10 @@ where
             // assert!(tmp_bytes.len() >= <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(&self));
            // assert_alignement(tmp_bytes.as_ptr());
         }
+
+        let (tmp_bytes, _) = scratch.tmp_scalar_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
+            &self,
+        ));
         unsafe {
             vec_znx::vec_znx_normalize_base2k(
                 self.ptr,
@@ -523,7 +528,7 @@ where
                 a.at_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
-                scratch.vec_znx_big_normalize_tmp_bytes(self).as_mut_ptr(),
+                tmp_bytes.as_mut_ptr(),
             );
         }
     }
@@ -574,3 +579,9 @@ where
         }
     }
 }
+
+impl VecZnxBigScratch for Module<FFT64> {
+    fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+        <Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
+    }
+}
diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs
index 74b559c..5d15c00 100644
--- a/base2k/src/vec_znx_dft.rs
+++ b/base2k/src/vec_znx_dft.rs
@@ -54,13 +54,13 @@ impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
     type Scalar = f64;
 }
 
-impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
-    pub(crate) fn bytes_of(module: &Module<B>, cols: usize, size: usize) -> usize {
-        unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
-    }
+pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
+    unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
+}
 
+impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
     pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
-        let data = alloc_aligned::<u8>(Self::bytes_of(module, cols, size));
+        let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -72,7 +72,7 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
 
     pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
         let data: Vec<u8> = bytes.into();
-        assert!(data.len() == Self::bytes_of(module, cols, size));
+        assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
         Self {
             data: data.into(),
             n: module.n(),
@@ -85,8 +85,8 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
 
 pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
 
-impl<'a, D: ?Sized, B> VecZnxDft<&'a mut D, B> {
-    pub(crate) fn from_mut_slice(data: &'a mut D, n: usize, cols: usize, size: usize) -> Self {
+impl<D, B> VecZnxDft<D, B> {
+    pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
         Self {
             data,
             n,
diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs
index 2c1cc97..9a1db2a 100644
--- a/base2k/src/vec_znx_dft_ops.rs
+++ b/base2k/src/vec_znx_dft_ops.rs
@@ -1,6 +1,7 @@
-use crate::VecZnxDftOwned;
 use crate::ffi::{vec_znx_big, vec_znx_dft};
+use crate::vec_znx_dft::bytes_of_vec_znx_dft;
 use crate::znx_base::ZnxInfos;
+use crate::{Backend, VecZnxDftOwned};
 use crate::{FFT64, Module, VecZnx, VecZnxBig, VecZnxDft, ZnxView, ZnxViewMut, ZnxZero, assert_alignement};
 use std::cmp::min;
 
@@ -66,12 +67,12 @@ pub trait VecZnxDftOps<DataMut, Data, B: Backend> {
     fn vec_znx_dft(&self, res: &mut VecZnxDft<DataMut, B>, res_col: usize, a: &VecZnx<Data>, a_col: usize);
 }
 
-impl VecZnxDftAlloc<FFT64> for Module<FFT64> {
-    fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<FFT64> {
+impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
+    fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
         VecZnxDftOwned::new(&self, cols, size)
     }
 
-    fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<FFT64> {
+    fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
         VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
     }
 
@@ -80,7 +81,7 @@ impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
     // }
 
     fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
-        VecZnxDftOwned::bytes_of(&self, cols, size)
+        bytes_of_vec_znx_dft(self, cols, size)
     }
 }
diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs
index 6951651..d647860 100644
--- a/base2k/src/vec_znx_ops.rs
+++ b/base2k/src/vec_znx_ops.rs
@@ -43,9 +43,6 @@ pub trait VecZnxAlloc {
 }
 
 pub trait VecZnxOps<DataMut, Data> {
-    /// Returns the minimum number of bytes necessary for normalization.
-    fn vec_znx_normalize_tmp_bytes(&self) -> usize;
-
     /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
     fn vec_znx_normalize(
         &self,
@@ -137,6 +134,11 @@ pub trait VecZnxOps<DataMut, Data> {
     fn vec_znx_merge(&self, res: &mut VecZnx<DataMut>, res_col: usize, a: &Vec<VecZnx<Data>>, a_col: usize);
 }
 
+pub trait VecZnxScratch {
+    /// Returns the minimum number of bytes necessary for normalization.
+    fn vec_znx_normalize_tmp_bytes(&self) -> usize;
+}
+
 impl<B: Backend> VecZnxAlloc for Module<B> {
     //(Jay)TODO: One must define the Scalar generic param here.
     fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
@@ -157,10 +159,6 @@ where
     Data: AsRef<[u8]>,
     DataMut: AsRef<[u8]> + AsMut<[u8]>,
 {
-    fn vec_znx_normalize_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
-    }
-
     fn vec_znx_normalize(
         &self,
         log_base2k: usize,
@@ -174,7 +172,7 @@ where
         {
             assert_eq!(a.n(), self.n());
             assert_eq!(res.n(), self.n());
-            assert!(tmp_bytes.len() >= <Self as VecZnxOps<DataMut, Data>>::vec_znx_normalize_tmp_bytes(&self));
+            assert!(tmp_bytes.len() >= <Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(&self));
             assert_alignement(tmp_bytes.as_ptr());
         }
         unsafe {
@@ -489,3 +487,9 @@ where
         <Self as VecZnxOps<DataMut, Data>>::vec_znx_rotate_inplace(self, a.len() as i64, res, res_col);
     }
 }
+
+impl<B: Backend> VecZnxScratch for Module<B> {
+    fn vec_znx_normalize_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
+    }
+}
diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs
index a7361ad..69afef8 100644
--- a/base2k/src/znx_base.rs
+++ b/base2k/src/znx_base.rs
@@ -1,59 +1,6 @@
-use crate::{Backend, Module, alloc_aligned, assert_alignement, cast_mut};
 use itertools::izip;
 use std::cmp::min;
 
-pub struct ZnxBase {
-    /// The ring degree
-    pub n: usize,
-
-    /// The number of rows (in the third dimension)
-    pub rows: usize,
-
-    /// The number of polynomials
-    pub cols: usize,
-
-    /// The number of size per polynomial (a.k.a small polynomials).
-    pub size: usize,
-
-    /// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
-    pub data: Vec<u8>,
-
-    /// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
-    pub ptr: *mut u8,
-}
-
-impl ZnxBase {
-    pub fn from_bytes(n: usize, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
-        let mut res: Self = Self::from_bytes_borrow(n, rows, cols, size, &mut bytes);
-        res.data = bytes;
-        res
-    }
-
-    pub fn from_bytes_borrow(n: usize, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-        #[cfg(debug_assertions)]
-        {
-            assert_eq!(n & (n - 1), 0, "n must be a power of two");
-            assert!(n > 0, "n must be greater than 0");
-            assert!(rows > 0, "rows must be greater than 0");
-            assert!(cols > 0, "cols must be greater than 0");
-            assert!(size > 0, "size must be greater than 0");
-        }
-        Self {
-            n: n,
-            rows: rows,
-            cols: cols,
-            size: size,
-            data: Vec::new(),
-            ptr: bytes.as_mut_ptr(),
-        }
-    }
-}
-
-pub trait GetZnxBase {
-    fn znx(&self) -> &ZnxBase;
-    fn znx_mut(&mut self) -> &mut ZnxBase;
-}
-
 pub trait ZnxInfos {
     /// Returns the ring degree of the polynomials.
     fn n(&self) -> usize;
@@ -82,30 +29,6 @@ pub trait ZnxInfos {
     fn sl(&self) -> usize;
 }
 
-// pub trait ZnxSliceSize {}
-
-//(Jay) TODO: Remove ZnxAlloc
-// pub trait ZnxAlloc
-// where
-//     Self: Sized + ZnxInfos,
-// {
-//     type Scalar;
-//     fn new(module: &Module<B>, rows: usize, cols: usize, size: usize) -> Self {
-//         let bytes: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(module, rows, cols, size));
-//         Self::from_bytes(module, rows, cols, size, bytes)
-//     }
-
-//     fn from_bytes(module: &Module<B>, rows: usize, cols: usize, size: usize, mut bytes: Vec<u8>) -> Self {
-//         let mut res: Self = Self::from_bytes_borrow(module, rows, cols, size, &mut bytes);
-//         res.znx_mut().data = bytes;
-//         res
-//     }
-
-//     fn from_bytes_borrow(module: &Module<B>, rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self;
-
-//     fn bytes_of(module: &Module<B>, rows: usize, cols: usize, size: usize) -> usize;
-// }
-
 pub trait DataView {
     type D;
     fn data(&self) -> &Self::D;
@@ -176,35 +99,6 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
 
 //(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
 impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
 
-use std::convert::TryFrom;
-use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
-pub trait Num:
-    Copy
-    + Default
-    + PartialEq
-    + PartialOrd
-    + Add<Output = Self>
-    + Sub<Output = Self>
-    + Mul<Output = Self>
-    + Div<Output = Self>
-    + Neg<Output = Self>
-    + AddAssign
-{
-    const BITS: u32;
-}
-
-impl Num for i64 {
-    const BITS: u32 = 64;
-}
-
-impl Num for i128 {
-    const BITS: u32 = 128;
-}
-
-impl Num for f64 {
-    const BITS: u32 = 64;
-}
-
 pub trait ZnxZero: ZnxViewMut
 where
     Self: Sized,
 {
@@ -261,128 +155,96 @@ pub fn switch_degree
     });
 }
 
+use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
+
+use crate::{ScratchBorr, cast_mut};
+pub trait Integer:
+    Copy
+    + Default
+    + PartialEq
+    + PartialOrd
+    + Add<Output = Self>
+    + Sub<Output = Self>
+    + Mul<Output = Self>
+    + Div<Output = Self>
+    + Neg<Output = Self>
+    + Shl<Output = Self>
+    + Shr<Output = Self>
+    + AddAssign
+{
+    const BITS: u32;
+}
+
+impl Integer for i64 {
+    const BITS: u32 = 64;
+}
+
+impl Integer for i128 {
+    const BITS: u32 = 128;
+}
+
 // (Jay)TODO: implement rsh for VecZnx, VecZnxBig
 // pub trait ZnxRsh: ZnxZero {
 //     fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
 //         rsh(k, log_base2k, self, col, carry)
 //     }
 // }
 
-// pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, tmp_bytes: &mut [u8]) {
-//     let n: usize = a.n();
-//     let size: usize = a.size();
-//     let cols: usize = a.cols();
+pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut ScratchBorr)
+where
+    V::Scalar: TryFrom<usize> + Integer,
+    <V::Scalar as TryFrom<usize>>::Error: std::fmt::Debug,
+{
+    let n: usize = a.n();
+    let size: usize = a.size();
+    let cols: usize = a.cols();
 
-//     #[cfg(debug_assertions)]
-//     {
-//         assert!(
-//             tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
-//             "invalid carry: carry.len()/size_of::<Self::Scalar>()={} < rsh_tmp_bytes({}, {})",
-//             tmp_bytes.len() / size_of::<V::Scalar>(),
-//             n,
-//             size,
-//         );
-//         assert_alignement(tmp_bytes.as_ptr());
-//     }
+    // #[cfg(debug_assertions)]
+    // {
+    //     assert!(
+    //         tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
+    //         "invalid carry: carry.len()/size_of::<Self::Scalar>()={} < rsh_tmp_bytes({}, {})",
    //         tmp_bytes.len() / size_of::<V::Scalar>(),
    //         n,
    //         size,
    //     );
    //     assert_alignement(tmp_bytes.as_ptr());
    // }
 
-//     let size: usize = a.size();
-//     let steps: usize = k / log_base2k;
+    let steps: usize = k / log_base2k;
 
-//     a.raw_mut().rotate_right(n * steps * cols);
-//     (0..cols).for_each(|i| {
-//         (0..steps).for_each(|j| {
-//             a.zero_at(i, j);
-//         })
-//     });
+    a.raw_mut().rotate_right(n * steps * cols);
+    (0..cols).for_each(|i| {
+        (0..steps).for_each(|j| {
+            a.zero_at(i, j);
+        })
+    });
 
-//     let k_rem: usize = k % log_base2k;
+    let k_rem: usize = k % log_base2k;
 
-//     if k_rem != 0 {
-//         let carry: &mut [V::Scalar] = cast_mut(tmp_bytes);
+    if k_rem != 0 {
+        // `tmp_scalar_slice` takes a length in elements, not bytes.
+        let (carry, _) = scratch.tmp_scalar_slice::<V::Scalar>(n);
 
-//         unsafe {
-//             std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
-//         }
+        unsafe {
+            // `write_bytes` counts in elements of V::Scalar, not bytes.
+            std::ptr::write_bytes(carry.as_mut_ptr(), 0, n);
+        }
 
-//         let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
-//         let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
-//         let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
+        let log_base2k_t: V::Scalar = V::Scalar::try_from(log_base2k).unwrap();
+        let shift: V::Scalar = V::Scalar::try_from(V::Scalar::BITS as usize - k_rem).unwrap();
+        let k_rem_t: V::Scalar = V::Scalar::try_from(k_rem).unwrap();
 
-//         (steps..size).for_each(|i| {
-//             izip!(carry.iter_mut(), a.at_mut(a_col, i).iter_mut()).for_each(|(ci, xi)| {
-//                 *xi += *ci << log_base2k_t;
-//                 *ci = get_base_k_carry(*xi, shift);
-//                 *xi = (*xi - *ci) >> k_rem_t;
-//             });
-//         })
-//     }
-// }
+        (0..cols).for_each(|i| {
+            (steps..size).for_each(|j| {
+                izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
+                    *xi += *ci << log_base2k_t;
+                    *ci = (*xi << shift) >> shift;
+                    *xi = (*xi - *ci) >> k_rem_t;
+                });
+            });
+            // Reset the carry before moving on to the next column.
+            carry.fill(V::Scalar::default());
+        })
    }
+}
 
-// #[inline(always)]
-// fn get_base_k_carry<T: Num>(x: T, shift: T) -> T {
-//     (x << shift) >> shift
-// }
-
-// pub fn rsh_tmp_bytes<Scalar>(n: usize) -> usize {
-//     n * std::mem::size_of::<Scalar>()
-// }
-
-// pub trait ZnxLayout: ZnxInfos {
-//     type Scalar;
-
-//     /// Returns true if the receiver is only borrowing the data.
-//     fn borrowing(&self) -> bool {
-//         self.znx().data.len() == 0
-//     }
-
-//     /// Returns a non-mutable pointer to the underlying coefficients array.
-//     fn as_ptr(&self) -> *const Self::Scalar {
-//         self.znx().ptr as *const Self::Scalar
-//     }
-
-//     /// Returns a mutable pointer to the underlying coefficients array.
-//     fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
-//         self.znx_mut().ptr as *mut Self::Scalar
-//     }
-
-//     /// Returns a non-mutable reference to the entire underlying coefficient array.
-//     fn raw(&self) -> &[Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
-//     }
-
-//     /// Returns a mutable reference to the entire underlying coefficient array.
-//     fn raw_mut(&mut self) -> &mut [Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
-//     }
-
-//     /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
-//     fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
-//         #[cfg(debug_assertions)]
-//         {
-//             assert!(i < self.cols());
-//             assert!(j < self.size());
-//         }
-//         let offset: usize = self.n() * (j * self.cols() + i);
-//         unsafe { self.as_ptr().add(offset) }
-//     }
-
-//     /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
-//     fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
-//         #[cfg(debug_assertions)]
-//         {
-//             assert!(i < self.cols());
-//             assert!(j < self.size());
-//         }
-//         let offset: usize = self.n() * (j * self.cols() + i);
-//         unsafe { self.as_mut_ptr().add(offset) }
-//     }
-
-//     /// Returns non-mutable reference to the (i, j)-th small polynomial.
-//     fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
-//     }
-
-//     /// Returns mutable reference to the (i, j)-th small polynomial.
-//     fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
-//         unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
-//     }
-// }
+
+pub fn rsh_tmp_bytes<Scalar>(n: usize) -> usize {
+    n * std::mem::size_of::<Scalar>()
+}
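
A note on the intended call pattern of the new scratch API, since the diff only exercises it inside the library: one aligned buffer is allocated once, and each take splits it into a typed temporary plus the remainder. The sketch below is illustrative only; it assumes a Module::<FFT64>::new(n) constructor and assumes the tmp_* helpers are exposed publicly, neither of which is part of this patch.

use base2k::{FFT64, MatZnxDftScratch, Module, ScratchOwned, VecZnxBigScratch};

fn main() {
    // Hypothetical constructor; this patch does not show how a Module is built.
    let module: Module<FFT64> = Module::new(1024);
    let (cols_out, size) = (2, 4);

    // One upfront allocation, sized with the *_tmp_bytes helpers from the
    // new *Scratch traits introduced above.
    let mut owned = ScratchOwned::new(
        module.vmp_prepare_row_tmp_bytes(cols_out, size) + module.vec_znx_big_normalize_tmp_bytes(),
    );

    // Each take returns (temporary, remaining scratch); threading the
    // remainder through guarantees the temporaries never alias.
    let scratch = owned.borrow();
    let (carry, scratch) = scratch.tmp_scalar_slice::<i64>(module.n());
    let (a_dft, _scratch) = scratch.tmp_vec_znx_dft(&module, cols_out, size);
    let _ = (carry, a_dft);
}

The unsized ScratchBorr (a newtype over [u8]) is what lets the remainder carry the same lifetime as the owned buffer without an explicit lifetime parameter, which is exactly the problem the commented-out ScratchBorrowed<'a> attempt ran into.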
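The carry extraction in the new rsh relies on an arithmetic-shift round trip: shifting a limb left by BITS - k_rem and back re-interprets its low k_rem bits as a signed digit in [-2^(k_rem-1), 2^(k_rem-1)), which is then subtracted out before the right shift by k_rem. A self-contained check of that identity on a single i64 limb (plain std Rust, no crate types):

fn main() {
    let k_rem: u32 = 5;
    let shift: u32 = 64 - k_rem;

    let x: i64 = 0b1101_0111; // 215
    let carry = (x << shift) >> shift; // low 5 bits 0b10111, sign-extended: -9
    let high = (x - carry) >> k_rem; // remaining digit: (215 + 9) / 32 = 7

    assert_eq!(carry, -9);
    assert_eq!(high, 7);
    assert_eq!((high << k_rem) + carry, x); // exact signed-digit decomposition
}

This is the same computation as `*ci = (*xi << shift) >> shift; *xi = (*xi - *ci) >> k_rem_t;` in rsh, with shift derived from V::Scalar::BITS.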