amend rlwe_encrypt example and minor changes at multiple places

This commit is contained in:
Janmajaya Mall
2025-05-04 19:46:22 +05:30
parent b82a1ca1b4
commit bd105497fd
11 changed files with 267 additions and 414 deletions

View File

@@ -18,17 +18,10 @@ pub mod vec_znx_dft_ops;
pub mod vec_znx_ops;
pub mod znx_base;
use std::{
any::type_name,
ops::{DerefMut, Sub},
};
pub use encoding::*;
pub use mat_znx_dft::*;
pub use mat_znx_dft_ops::*;
pub use module::*;
use rand_core::le;
use rand_distr::num_traits::sign;
pub use sampling::*;
pub use scalar_znx::*;
pub use scalar_znx_dft::*;
@@ -133,6 +126,8 @@ pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
)
}
// Scratch implementation below
pub struct ScratchOwned(Vec<u8>);
impl ScratchOwned {
@@ -141,16 +136,16 @@ impl ScratchOwned {
Self(data)
}
pub fn borrow(&mut self) -> &mut ScratchBorr {
ScratchBorr::new(&mut self.0)
pub fn borrow(&mut self) -> &mut Scratch {
Scratch::new(&mut self.0)
}
}
pub struct ScratchBorr {
pub struct Scratch {
data: [u8],
}
impl ScratchBorr {
impl Scratch {
fn new(data: &mut [u8]) -> &mut Self {
unsafe { &mut *(data as *mut [u8] as *mut Self) }
}
@@ -175,14 +170,14 @@ impl ScratchBorr {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len,
take_len,
aligned_len,
// type_name::<T>(),
// aligned_len
);
}
}
fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
pub fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
unsafe {
@@ -193,7 +188,7 @@ impl ScratchBorr {
}
}
fn tmp_vec_znx_dft<B: Backend>(
pub fn tmp_vec_znx_dft<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
@@ -207,103 +202,17 @@ impl ScratchBorr {
)
}
fn tmp_vec_znx_big<D: for<'a> From<&'a mut [u8]>, B: Backend>(
pub fn tmp_vec_znx_big<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
size: usize,
) -> (VecZnxBig<D, B>, &mut Self) {
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
(
VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size),
VecZnxBig::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
}
// pub struct ScratchBorrowed<'a> {
// data: &'a mut [u8],
// }
// impl<'a> ScratchBorrowed<'a> {
// fn take_slice<T>(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) {
// let ptr = self.data.as_mut_ptr();
// let self_len = self.data.len();
// //TODO(Jay): print the offset sometimes, just to check
// let aligned_offset = ptr.align_offset(DEFAULTALIGN);
// let aligned_len = self_len.saturating_sub(aligned_offset);
// let take_len_bytes = take_len * std::mem::size_of::<T>();
// if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) {
// unsafe {
// let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes);
// let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
// let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes);
// return (take_slice, ScratchBorrowed { data: rem_slice });
// }
// } else {
// panic!(
// "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left",
// take_len_bytes,
// take_len,
// type_name::<T>(),
// aligned_len
// );
// }
// }
// fn reborrow(&mut self) -> ScratchBorrowed<'a> {
// //(Jay)TODO: `data: &mut *self.data` does not work because liftime of &mut self is different from 'a.
// // But it feels that there should be a simpler impl. than the one below
// Self {
// data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) },
// }
// }
// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) {
// let (data, re_scratch) = self.take_slice::<u8>(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
// (
// VecZnxDft::from_data(data, module.n(), cols, size),
// re_scratch,
// )
// }
// pub(crate) fn len(&self) -> usize {
// self.data.len()
// }
// }
// pub trait Scratch<D> {
// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (D, &mut Self);
// }
// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr {
// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) {
// let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
// (
// data
// rem_scratch,
// )
// }
// // fn tmp_vec_znx_big<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) {
// // // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size));
// // // (
// // // VecZnxBig::from_data(data, module.n(), cols, size),
// // // re_scratch,
// // // )
// // }
// // fn scalar_slice<T>(&mut self, len: usize) -> (&mut [T], Self) {
// // self.take_slice::<T>(len)
// // }
// // fn reborrow(&mut self) -> Self {
// // self.reborrow()
// // }
// }