reworked scalar

This commit is contained in:
Jean-Philippe Bossuat
2025-04-30 23:11:43 +02:00
parent 6f7b93c7ca
commit 9ade995cd7
8 changed files with 311 additions and 338 deletions

View File

@@ -5,7 +5,9 @@ pub mod ffi;
pub mod mat_znx_dft;
pub mod module;
pub mod sampling;
pub mod scalar_znx;
pub mod scalar_znx_dft;
pub mod scalar_znx_dft_ops;
pub mod stats;
pub mod vec_znx;
pub mod vec_znx_big;
@@ -19,8 +21,11 @@ pub use encoding::*;
pub use mat_znx_dft::*;
pub use module::*;
pub use sampling::*;
#[allow(unused_imports)]
pub use scalar_znx::*;
pub use scalar_znx_dft::*;
#[allow(unused_imports)]
pub use scalar_znx_dft_ops::*;
pub use stats::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
@@ -50,13 +55,13 @@ pub fn assert_alignement<T>(ptr: *const T) {
pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
let ptr: *mut V = data.as_ptr() as *mut V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
@@ -70,7 +75,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
align
);
assert_eq!(
(size * std::mem::size_of::<u8>()) % align,
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
@@ -98,22 +103,25 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * std::mem::size_of::<T>()) % align,
(size * size_of::<T>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
let len: usize = vec_u8.len() / size_of::<T>();
let cap: usize = vec_u8.capacity() / size_of::<T>();
std::mem::forget(vec_u8);
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned of size equal to the smallest multiple
/// of [DEFAULTALIGN] that is equal or greater to `size`.
/// Allocates an aligned vector of size equal to the smallest multiple
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(size + (size % DEFAULTALIGN), DEFAULTALIGN)
alloc_aligned_custom::<T>(
size + (size % (DEFAULTALIGN / size_of::<T>())),
DEFAULTALIGN,
)
}