From bd105497fd097284f6e0fa7c029577eedb473014 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Sun, 4 May 2025 19:46:22 +0530 Subject: [PATCH] amend rlwe_encrypt example and minor changes at multiple places --- base2k/examples/rlwe_encrypt.rs | 52 ++++---- base2k/examples/vmp.rs | 122 +++++++++--------- base2k/src/lib.rs | 115 ++--------------- base2k/src/mat_znx_dft_ops.rs | 219 ++++++++++++++++---------------- base2k/src/vec_znx.rs | 15 ++- base2k/src/vec_znx_big.rs | 12 +- base2k/src/vec_znx_big_ops.rs | 9 +- base2k/src/vec_znx_dft.rs | 80 +++++------- base2k/src/vec_znx_dft_ops.rs | 13 -- base2k/src/vec_znx_ops.rs | 14 -- base2k/src/znx_base.rs | 30 +---- 11 files changed, 267 insertions(+), 414 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index afac2f8..742dcea 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,6 +1,6 @@ use base2k::{ - Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, - VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned, + Encoding, FFT64, Module, Sampling, ScalarAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps, + VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, }; use itertools::izip; use sampling::source::Source; @@ -13,24 +13,24 @@ fn main() { let log_scale: usize = msg_size * log_base2k - 5; let module: Module = Module::::new(n); - let mut tmp_bytes_norm: Vec = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes()); - let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size)); + let mut scratch = + ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes()); let seed: [u8; 32] = [0; 32]; let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar = module.new_scalar(1); + let mut s = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain - let mut s_dft: ScalarZnxDft = module.new_scalar_znx_dft(s.cols()); + let mut s_dft = module.new_scalar_znx_dft(s.cols()); // s_dft <- DFT(s) module.svp_prepare(&mut s_dft, 0, &s, 0); // Allocates a VecZnx with two columns: ct=(0, 0) - let mut ct: VecZnx = module.new_vec_znx( + let mut ct = module.new_vec_znx( 2, // Number of columns ct_size, // Number of small poly per column ); @@ -39,11 +39,8 @@ fn main() { module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source); // Scratch space for DFT values - let mut buf_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow( - 1, // Number of columns - ct.size(), // Number of polynomials per column - &mut tmp_bytes_dft, - ); + let scratch = scratch.borrow(); + let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size); // Applies DFT(ct[1]) * DFT(s) module.svp_apply_dft( @@ -56,13 +53,14 @@ fn main() { ); // Alias scratch space (VecZnxDft is always at least as big as VecZnxBig) - let mut buf_big: VecZnxBig = buf_dft.alias_as_vec_znx_big(); + let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size); // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) - module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + // Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data` a full qualified path seems necessary + as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); // Creates a plaintext: VecZnx with 1 column - let mut m: VecZnx = module.new_vec_znx( + let mut m = module.new_vec_znx( 1, // Number of columns msg_size, // Number of small polynomials ); @@ -70,10 +68,11 @@ fn main() { want.iter_mut() .for_each(|x| *x = source.next_u64n(16, 15) as i64); m.encode_vec_i64(0, log_base2k, log_scale, &want, 4); - m.normalize(log_base2k, 0, &mut tmp_bytes_norm); + let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::()); + m.normalize(log_base2k, 0, tmp_bytes_norm); // m - BIG(ct[1] * s) - module.vec_znx_big_sub_small_a_inplace( + module.vec_znx_big_sub_small_b_inplace( &mut buf_big, 0, // Selects the first column of the receiver &m, @@ -83,12 +82,9 @@ fn main() { // Normalizes back to VecZnx // ct[0] <- m - BIG(c1 * s) module.vec_znx_big_normalize( - log_base2k, - &mut ct, - 0, // Selects the first column of ct (ct[0]) - &buf_big, - 0, // Selects the first column of buf_big - &mut tmp_bytes_norm, + log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0]) + &buf_big, 0, // Selects the first column of buf_big + scratch, ); // Add noise to ct[0] @@ -118,14 +114,14 @@ fn main() { ); // BIG(c1 * s) = IDFT(DFT(c1 * s)) - module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0); + as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0); // BIG(c1 * s) + ct[0] module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0); // m + e <- BIG(ct[1] * s + ct[0]) - let mut res: VecZnx = module.new_vec_znx(1, ct_size); - module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm); + let mut res = module.new_vec_znx(1, ct_size); + module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch); // have = m * 2^{log_scale} + e let mut have: Vec = vec![i64::default(); n]; @@ -136,5 +132,7 @@ fn main() { .enumerate() .for_each(|(i, (a, b))| { println!("{}: {} {}", i, a, (*b as f64) / scale); - }) + }); + + module.free(); } diff --git a/base2k/examples/vmp.rs b/base2k/examples/vmp.rs index 710744e..36943f7 100644 --- a/base2k/examples/vmp.rs +++ b/base2k/examples/vmp.rs @@ -1,78 +1,78 @@ -use base2k::{ - Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, - ZnxInfos, ZnxLayout, alloc_aligned, -}; +// use base2k::{ +// Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, +// ZnxInfos, ZnxLayout, alloc_aligned, +// }; fn main() { - let log_n: i32 = 5; - let n: usize = 1 << log_n; + // let log_n: i32 = 5; + // let n: usize = 1 << log_n; - let module: Module = Module::::new(n); - let log_base2k: usize = 15; + // let module: Module = Module::::new(n); + // let log_base2k: usize = 15; - let a_cols: usize = 2; - let a_size: usize = 5; + // let a_cols: usize = 2; + // let a_size: usize = 5; - let log_k: usize = log_base2k * a_size - 5; + // let log_k: usize = log_base2k * a_size - 5; - let mat_rows: usize = a_size; - let mat_cols_in: usize = a_cols; - let mat_cols_out: usize = 2; - let mat_size: usize = a_size + 1; + // let mat_rows: usize = a_size; + // let mat_cols_in: usize = a_cols; + // let mat_cols_out: usize = 2; + // let mat_size: usize = a_size + 1; - let mut tmp_bytes_vmp: Vec = alloc_aligned( - module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) - | module.vmp_apply_dft_tmp_bytes( - a_size, - a_size, - mat_rows, - mat_cols_in, - mat_cols_out, - mat_size, - ), - ); + // let mut tmp_bytes_vmp: Vec = alloc_aligned( + // module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size) + // | module.vmp_apply_dft_tmp_bytes( + // a_size, + // a_size, + // mat_rows, + // mat_cols_in, + // mat_cols_out, + // mat_size, + // ), + // ); - let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); + // let mut tmp_bytes_dft: Vec = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size)); - let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); + // let mut a: VecZnx = module.new_vec_znx(a_cols, a_size); - (0..a_cols).for_each(|i| { - let mut values: Vec = vec![i64::default(); n]; - values[1 + i] = (1 << log_base2k) + 1; - a.encode_vec_i64(i, log_base2k, log_k, &values, 32); - a.normalize(log_base2k, i, &mut tmp_bytes_vmp); - a.print(n, i); - println!(); - }); + // (0..a_cols).for_each(|i| { + // let mut values: Vec = vec![i64::default(); n]; + // values[1 + i] = (1 << log_base2k) + 1; + // a.encode_vec_i64(i, log_base2k, log_k, &values, 32); + // a.normalize(log_base2k, i, &mut tmp_bytes_vmp); + // a.print(n, i); + // println!(); + // }); - let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + // let mut mat_znx_dft: MatZnxDft = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); - (0..a.size()).for_each(|row_i| { - let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); - (0..mat_cols_out).for_each(|j| { - tmp.at_mut(j, row_i)[1 + j] = 1 as i64; - }); - (0..mat_cols_in).for_each(|j| { - module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); - }) - }); + // (0..a.size()).for_each(|row_i| { + // let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size); + // (0..mat_cols_out).for_each(|j| { + // tmp.at_mut(j, row_i)[1 + j] = 1 as i64; + // }); + // (0..mat_cols_in).for_each(|j| { + // module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp); + // }) + // }); - let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); - module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); + // let mut c_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft); + // module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp); - let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); - let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); - (0..mat_cols_out).for_each(|i| { - module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); - module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); + // let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size); + // let mut c_big: VecZnxBig = c_dft.alias_as_vec_znx_big(); + // (0..mat_cols_out).for_each(|i| { + // module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + // module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp); - let mut values_res: Vec = vec![i64::default(); n]; - res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); - res.print(n, i); - println!(); - println!("{:?}", values_res); - println!(); - }); + // let mut values_res: Vec = vec![i64::default(); n]; + // res.decode_vec_i64(i, log_base2k, log_k, &mut values_res); + // res.print(n, i); + // println!(); + // println!("{:?}", values_res); + // println!(); + // }); - module.free(); + // module.free(); } diff --git a/base2k/src/lib.rs b/base2k/src/lib.rs index f33ce60..38d6b4e 100644 --- a/base2k/src/lib.rs +++ b/base2k/src/lib.rs @@ -18,17 +18,10 @@ pub mod vec_znx_dft_ops; pub mod vec_znx_ops; pub mod znx_base; -use std::{ - any::type_name, - ops::{DerefMut, Sub}, -}; - pub use encoding::*; pub use mat_znx_dft::*; pub use mat_znx_dft_ops::*; pub use module::*; -use rand_core::le; -use rand_distr::num_traits::sign; pub use sampling::*; pub use scalar_znx::*; pub use scalar_znx_dft::*; @@ -133,6 +126,8 @@ pub fn alloc_aligned(size: usize) -> Vec { ) } +// Scratch implementation below + pub struct ScratchOwned(Vec); impl ScratchOwned { @@ -141,16 +136,16 @@ impl ScratchOwned { Self(data) } - pub fn borrow(&mut self) -> &mut ScratchBorr { - ScratchBorr::new(&mut self.0) + pub fn borrow(&mut self) -> &mut Scratch { + Scratch::new(&mut self.0) } } -pub struct ScratchBorr { +pub struct Scratch { data: [u8], } -impl ScratchBorr { +impl Scratch { fn new(data: &mut [u8]) -> &mut Self { unsafe { &mut *(data as *mut [u8] as *mut Self) } } @@ -175,14 +170,14 @@ impl ScratchBorr { panic!( "Attempted to take {} from scratch with {} aligned bytes left", take_len, - take_len, + aligned_len, // type_name::(), // aligned_len ); } } - fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { + pub fn tmp_scalar_slice(&mut self, len: usize) -> (&mut [T], &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::()); unsafe { @@ -193,7 +188,7 @@ impl ScratchBorr { } } - fn tmp_vec_znx_dft( + pub fn tmp_vec_znx_dft( &mut self, module: &Module, cols: usize, @@ -207,103 +202,17 @@ impl ScratchBorr { ) } - fn tmp_vec_znx_big From<&'a mut [u8]>, B: Backend>( + pub fn tmp_vec_znx_big( &mut self, module: &Module, cols: usize, size: usize, - ) -> (VecZnxBig, &mut Self) { + ) -> (VecZnxBig<&mut [u8], B>, &mut Self) { let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size)); ( - VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size), + VecZnxBig::from_data(take_slice, module.n(), cols, size), Self::new(rem_slice), ) } } - -// pub struct ScratchBorrowed<'a> { -// data: &'a mut [u8], -// } - -// impl<'a> ScratchBorrowed<'a> { -// fn take_slice(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) { -// let ptr = self.data.as_mut_ptr(); -// let self_len = self.data.len(); - -// //TODO(Jay): print the offset sometimes, just to check -// let aligned_offset = ptr.align_offset(DEFAULTALIGN); -// let aligned_len = self_len.saturating_sub(aligned_offset); - -// let take_len_bytes = take_len * std::mem::size_of::(); - -// if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) { -// unsafe { -// let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes); -// let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len); - -// let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes); - -// return (take_slice, ScratchBorrowed { data: rem_slice }); -// } -// } else { -// panic!( -// "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left", -// take_len_bytes, -// take_len, -// type_name::(), -// aligned_len -// ); -// } -// } - -// fn reborrow(&mut self) -> ScratchBorrowed<'a> { -// //(Jay)TODO: `data: &mut *self.data` does not work because liftime of &mut self is different from 'a. -// // But it feels that there should be a simpler impl. than the one below -// Self { -// data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) }, -// } -// } - -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) { -// let (data, re_scratch) = self.take_slice::(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); -// ( -// VecZnxDft::from_data(data, module.n(), cols, size), -// re_scratch, -// ) -// } - -// pub(crate) fn len(&self) -> usize { -// self.data.len() -// } -// } - -// pub trait Scratch { -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (D, &mut Self); -// } - -// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr { -// fn tmp_vec_znx_dft(&mut self, module: &Module, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) { -// let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size)); -// ( -// data -// rem_scratch, -// ) -// } - -// // fn tmp_vec_znx_big(&mut self, module: &Module, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) { -// // // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size)); -// // // ( -// // // VecZnxBig::from_data(data, module.n(), cols, size), -// // // re_scratch, -// // // ) -// // } - -// // fn scalar_slice(&mut self, len: usize) -> (&mut [T], Self) { -// // self.take_slice::(len) -// // } - -// // fn reborrow(&mut self) -> Self { -// // self.reborrow() -// // } -// } diff --git a/base2k/src/mat_znx_dft_ops.rs b/base2k/src/mat_znx_dft_ops.rs index 5ab44df..658ff5d 100644 --- a/base2k/src/mat_znx_dft_ops.rs +++ b/base2k/src/mat_znx_dft_ops.rs @@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::ffi::vmp; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, }; @@ -25,15 +25,6 @@ pub trait MatZnxDftAlloc { size: usize, bytes: Vec, ) -> MatZnxDftAllocOwned; - - // fn new_mat_znx_dft_from_bytes_borrow( - // &self, - // rows: usize, - // cols_in: usize, - // cols_out: usize, - // size: usize, - // bytes: &mut [u8], - // ) -> MatZnxDft; } pub trait MatZnxDftScratch { @@ -101,7 +92,7 @@ pub trait MatZnxDftOps { b_row: usize, b_col_in: usize, a: &VecZnx, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig]. @@ -118,7 +109,7 @@ pub trait MatZnxDftOps { a: &MatZnxDft, b_row: usize, b_col_in: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft]. @@ -165,7 +156,7 @@ pub trait MatZnxDftOps { /// * `a`: the left operand [VecZnx] of the vector matrix product. /// * `b`: the right operand [MatZnxDft] of the vector matrix product. /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes]. - fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut ScratchBorr); + fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &MatZnxDft, scratch: &mut Scratch); /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft]. /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes]. @@ -197,7 +188,7 @@ pub trait MatZnxDftOps { c: &mut VecZnxDft, a: &VecZnxDft, b: &MatZnxDft, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); } @@ -274,18 +265,14 @@ impl MatZnxDftScratch for Module { } } -impl MatZnxDftOps for Module -where - DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>, - Data: AsRef<[u8]>, -{ +impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module { fn vmp_prepare_row( &self, - b: &mut MatZnxDft, + b: &mut MatZnxDft<&mut [u8], FFT64>, b_row: usize, b_col_in: usize, - a: &VecZnx, - scratch: &mut ScratchBorr, + a: &VecZnx<&[u8]>, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -328,21 +315,19 @@ where let a_size: usize = a.size(); // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size)); - let (mut a_dft, _) = scratch.tmp_scalar_slice(12); - DataMut::from(a_dft); - // let (mut a_dft, _) = scratch.tmp_vec_znx_dft::(self, cols_out, a_size); + let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size); (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i)); - Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft); + Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref()); } fn vmp_extract_row( &self, log_base2k: usize, - b: &mut VecZnx, - a: &MatZnxDft, + b: &mut VecZnx<&mut [u8]>, + a: &MatZnxDft<&[u8], FFT64>, a_row: usize, a_col_in: usize, - mut scratch: &mut ScratchBorr, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -386,12 +371,18 @@ where Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in); let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size); (0..cols_out).for_each(|i| { - >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); + >::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i); self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch); }); } - fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft, b_row: usize, b_col_in: usize, a: &VecZnxDft) { + fn vmp_prepare_row_dft( + &self, + b: &mut MatZnxDft<&mut [u8], FFT64>, + b_row: usize, + b_col_in: usize, + a: &VecZnxDft<&[u8], FFT64>, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -436,7 +427,13 @@ where } } - fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &MatZnxDft, a_row: usize, a_col_in: usize) { + fn vmp_extract_row_dft( + &self, + b: &mut VecZnxDft<&mut [u8], FFT64>, + a: &MatZnxDft<&[u8], FFT64>, + a_row: usize, + a_col_in: usize, + ) { #[cfg(debug_assertions)] { assert_eq!(b.n(), self.n()); @@ -482,10 +479,10 @@ where fn vmp_apply_dft( &self, - c: &mut VecZnxDft, - a: &VecZnx, - b: &MatZnxDft, - mut scratch: &mut ScratchBorr, + c: &mut VecZnxDft<&mut [u8], FFT64>, + a: &VecZnx<&[u8]>, + b: &MatZnxDft<&[u8], FFT64>, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { @@ -547,68 +544,70 @@ where fn vmp_apply_dft_to_dft( &self, - c: &mut VecZnxDft, - a: &VecZnxDft, - b: &MatZnxDft, - mut scratch: &mut ScratchBorr, + c: &mut VecZnxDft<&mut [u8], FFT64>, + a: &VecZnxDft<&[u8], FFT64>, + b: &MatZnxDft<&[u8], FFT64>, + scratch: &mut Scratch, ) { - #[cfg(debug_assertions)] { - assert_eq!(c.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.n(), self.n()); - assert_eq!( - c.cols(), - b.cols_out(), - "c.cols(): {} != b.cols_out: {}", - c.cols(), - b.cols_out() - ); - assert_eq!( - a.cols(), - b.cols_in(), - "a.cols(): {} != b.cols_in: {}", - a.cols(), - b.cols_in() - ); - // assert!( - // tmp_bytes.len() - // >= self.vmp_apply_dft_to_dft_tmp_bytes( - // c.cols(), - // c.size(), - // a.cols(), - // a.size(), - // b.rows(), - // b.cols_in(), - // b.cols_out(), - // b.size() - // ) - // ); - // assert_alignement(tmp_bytes.as_ptr()); - } + #[cfg(debug_assertions)] + { + assert_eq!(c.n(), self.n()); + assert_eq!(b.n(), self.n()); + assert_eq!(a.n(), self.n()); + assert_eq!( + c.cols(), + b.cols_out(), + "c.cols(): {} != b.cols_out: {}", + c.cols(), + b.cols_out() + ); + assert_eq!( + a.cols(), + b.cols_in(), + "a.cols(): {} != b.cols_in: {}", + a.cols(), + b.cols_in() + ); + // assert!( + // tmp_bytes.len() + // >= self.vmp_apply_dft_to_dft_tmp_bytes( + // c.cols(), + // c.size(), + // a.cols(), + // a.size(), + // b.rows(), + // b.cols_in(), + // b.cols_out(), + // b.size() + // ) + // ); + // assert_alignement(tmp_bytes.as_ptr()); + } - let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( - c.cols(), - c.size(), - a.cols(), - a.size(), - b.rows(), - b.cols_in(), - b.cols_out(), - b.size(), - )); - unsafe { - vmp::vmp_apply_dft_to_dft( - self.ptr, - c.as_mut_ptr() as *mut vec_znx_dft_t, - c.poly_count() as u64, - a.as_ptr() as *const vec_znx_dft_t, - a.poly_count() as u64, - b.as_ptr() as *const vmp::vmp_pmat_t, - b.rows() as u64, - (b.size() * b.cols()) as u64, - tmp_bytes.as_mut_ptr(), - ) + let (tmp_bytes, _) = scratch.tmp_scalar_slice(self.vmp_apply_dft_to_dft_tmp_bytes( + c.cols(), + c.size(), + a.cols(), + a.size(), + b.rows(), + b.cols_in(), + b.cols_out(), + b.size(), + )); + unsafe { + vmp::vmp_apply_dft_to_dft( + self.ptr, + c.as_mut_ptr() as *mut vec_znx_dft_t, + c.poly_count() as u64, + a.as_ptr() as *const vec_znx_dft_t, + a.poly_count() as u64, + b.as_ptr() as *const vmp::vmp_pmat_t, + b.rows() as u64, + (b.size() * b.cols()) as u64, + tmp_bytes.as_mut_ptr(), + ) + } } } } @@ -658,27 +657,31 @@ mod tests { module.vec_znx_dft(&mut a_dft, col_out, &a, col_out); }); - // let g = vmpmat_0.to_mut(); - - module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow()); + module.vmp_prepare_row( + &mut vmpmat_0.to_mut(), + row_i, + col_in, + &a.to_ref(), + scratch.borrow(), + ); // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft) - module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft); + module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref()); assert_eq!(vmpmat_0.raw(), vmpmat_1.raw()); // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft) - module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in); + module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in); assert_eq!(a_dft.raw(), b_dft.raw()); // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big) - // module.vmp_extract_row( - // log_base2k, - // &mut b.to_mut(), - // &vmpmat_0.to_ref(), - // row_i, - // col_in, - // scratch.borrow(), - // ); + module.vmp_extract_row( + log_base2k, + &mut b.to_mut(), + &vmpmat_0.to_ref(), + row_i, + col_in, + scratch.borrow(), + ); (0..mat_cols_out).for_each(|col_out| { module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes); diff --git a/base2k/src/vec_znx.rs b/base2k/src/vec_znx.rs index b386604..09b0051 100644 --- a/base2k/src/vec_znx.rs +++ b/base2k/src/vec_znx.rs @@ -217,7 +217,7 @@ pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>; pub type VecZnxRef<'a> = VecZnx<&'a [u8]>; impl VecZnx> { - pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> { VecZnx { data: self.data.as_mut_slice(), n: self.n, @@ -226,7 +226,7 @@ impl VecZnx> { } } - pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> { + pub fn to_ref(&self) -> VecZnx<&[u8]> { VecZnx { data: self.data.as_slice(), n: self.n, @@ -235,3 +235,14 @@ impl VecZnx> { } } } + +impl VecZnx<&mut [u8]> { + pub fn to_ref(&self) -> VecZnx<&[u8]> { + VecZnx { + data: &self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } +} diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 7442f11..fe67516 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -3,7 +3,7 @@ use crate::znx_base::{ZnxInfos, ZnxView}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned}; use std::marker::PhantomData; -const VEC_ZNX_BIG_ROWS: usize = 1; +// const VEC_ZNX_BIG_ROWS: usize = 1; /// VecZnxBig is `Backend` dependent, denoted with backend generic `B` pub struct VecZnxBig { @@ -97,7 +97,7 @@ impl VecZnxBig { pub type VecZnxBigOwned = VecZnxBig, B>; impl VecZnxBig, B> { - pub(crate) fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { + pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> { VecZnxBig { data: self.data.as_mut_slice(), n: self.n, @@ -107,7 +107,7 @@ impl VecZnxBig, B> { } } - pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> { + pub fn to_ref(&self) -> VecZnxBig<&[u8], B> { VecZnxBig { data: self.data.as_slice(), n: self.n, @@ -117,9 +117,3 @@ impl VecZnxBig, B> { } } } - -// impl VecZnxBig { -// pub fn print(&self, n: usize, col: usize) { -// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); -// } -// } diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 20b4f2e..d0e4bd3 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -1,9 +1,6 @@ use crate::ffi::vec_znx; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; -use crate::{ - Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement, - bytes_of_vec_znx_big, -}; +use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big}; pub trait VecZnxBigAlloc { /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values. @@ -129,7 +126,7 @@ pub trait VecZnxBigOps { res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ); /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`. @@ -503,7 +500,7 @@ where res_col: usize, a: &VecZnxBig, a_col: usize, - scratch: &mut ScratchBorr, + scratch: &mut Scratch, ) { #[cfg(debug_assertions)] { diff --git a/base2k/src/vec_znx_dft.rs b/base2k/src/vec_znx_dft.rs index 5d15c00..a4a3242 100644 --- a/base2k/src/vec_znx_dft.rs +++ b/base2k/src/vec_znx_dft.rs @@ -4,7 +4,7 @@ use crate::ffi::vec_znx_dft; use crate::znx_base::ZnxInfos; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned}; -const VEC_ZNX_DFT_ROWS: usize = 1; +// const VEC_ZNX_DFT_ROWS: usize = 1; // VecZnxDft is `Backend` dependent denoted with generic `B` pub struct VecZnxDft { @@ -97,52 +97,36 @@ impl VecZnxDft { } } -// impl ZnxAlloc for VecZnxDft { -// type Scalar = u8; +impl VecZnxDft, B> { + pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + VecZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } -// fn from_bytes_borrow(module: &Module, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self { -// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size)); -// Self { -// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes), -// _marker: PhantomData, -// } -// } + pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: self.data.as_slice(), + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} -// fn bytes_of(module: &Module, _rows: usize, cols: usize, size: usize) -> usize { -// debug_assert_eq!( -// _rows, VEC_ZNX_DFT_ROWS, -// "rows != {} not supported for VecZnxDft", -// VEC_ZNX_DFT_ROWS -// ); -// unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols } -// } -// } - -// impl VecZnxDft { -// pub fn print(&self, n: usize, col: usize) { -// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n])); -// } -// } - -// impl VecZnxDft { -// /// Cast a [VecZnxDft] into a [VecZnxBig]. -// /// The returned [VecZnxBig] shares the backing array -// /// with the original [VecZnxDft]. -// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig { -// assert!( -// self.data().len() == 0, -// "cannot alias VecZnxDft into VecZnxBig if it owns the data" -// ); -// VecZnxBig:: { -// inner: ZnxBase { -// data: Vec::new(), -// ptr: self.ptr(), -// n: self.n(), -// rows: self.rows(), -// cols: self.cols(), -// size: self.size(), -// }, -// _marker: PhantomData, -// } -// } -// } +impl VecZnxDft<&mut [u8], B> { + pub fn to_ref(&self) -> VecZnxDft<&[u8], B> { + VecZnxDft { + data: &self.data, + n: self.n, + cols: self.cols, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/base2k/src/vec_znx_dft_ops.rs b/base2k/src/vec_znx_dft_ops.rs index 9a1db2a..e894ef4 100644 --- a/base2k/src/vec_znx_dft_ops.rs +++ b/base2k/src/vec_znx_dft_ops.rs @@ -22,19 +22,6 @@ pub trait VecZnxDftAlloc { /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxDftOwned; - // /// Returns a new [VecZnxDft] with the provided bytes array as backing array. - // /// - // /// Behavior: the backing array is only borrowed. - // /// - // /// # Arguments - // /// - // /// * `cols`: the number of cols of the [VecZnxDft]. - // /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft]. - // /// - // /// # Panics - // /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft]. - // fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft; - /// Returns a new [VecZnxDft] with the provided bytes array as backing array. /// /// # Arguments diff --git a/base2k/src/vec_znx_ops.rs b/base2k/src/vec_znx_ops.rs index d647860..a8edb12 100644 --- a/base2k/src/vec_znx_ops.rs +++ b/base2k/src/vec_znx_ops.rs @@ -23,19 +23,6 @@ pub trait VecZnxAlloc { /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec) -> VecZnxOwned; - // /// Instantiates a new [VecZnx] from a slice of bytes. - // /// The returned [VecZnx] does take ownership of the slice of bytes. - // /// - // /// # Arguments - // /// - // /// * `cols`: the number of polynomials. - // /// * `size`: the number small polynomials per column. - // /// - // /// # Panic - // /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx]. - // fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx; - // (Jay)TODO - /// Returns the number of bytes necessary to allocate /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes] /// or [VecZnxOps::new_vec_znx_from_bytes_borrow]. @@ -140,7 +127,6 @@ pub trait VecZnxScratch { } impl VecZnxAlloc for Module { - //(Jay)TODO: One must define the Scalar generic param here. fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned { VecZnxOwned::new::(self.n(), cols, size) } diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index 69afef8..9eea5bb 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -1,4 +1,5 @@ use itertools::izip; +use rand_distr::num_traits::Zero; use std::cmp::min; pub trait ZnxInfos { @@ -157,7 +158,7 @@ pub fn switch_degree + ZnxZero, D: ZnxView use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub}; -use crate::{ScratchBorr, cast_mut}; +use crate::Scratch; pub trait Integer: Copy + Default @@ -183,32 +184,15 @@ impl Integer for i128 { const BITS: u32 = 128; } -// (Jay)TODO: implement rsh for VecZnx, VecZnxBig -// pub trait ZnxRsh: ZnxZero { -// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) { -// rsh(k, log_base2k, self, col, carry) -// } -// } -pub fn rsh(k: usize, log_base2k: usize, a: &mut V, a_col: usize, scratch: &mut ScratchBorr) +//(Jay)Note: `rsh` impl. ignores the column +pub fn rsh(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch) where - V::Scalar: From + Integer, + V::Scalar: From + Integer + Zero, { let n: usize = a.n(); - let size: usize = a.size(); + let _size: usize = a.size(); let cols: usize = a.cols(); - // #[cfg(debug_assertions)] - // { - // assert!( - // tmp_bytes.len() >= rsh_tmp_bytes::(n), - // "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})", - // tmp_bytes.len() / size_of::(), - // n, - // size, - // ); - // assert_alignement(tmp_bytes.as_ptr()); - // } - let size: usize = a.size(); let steps: usize = k / log_base2k; @@ -240,7 +224,7 @@ where *xi = (*xi - *ci) >> k_rem_t; }); }); - //TODO: ZERO CARRYcarry + carry.iter_mut().for_each(|r| *r = V::Scalar::zero()); }) } }