Mirror of https://github.com/arnaucube/poulpy.git
amend rlwe_encrypt example and minor changes at multiple places
@@ -1,6 +1,6 @@
 use base2k::{
-Encoding, FFT64, Module, Sampling, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
-VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
+Encoding, FFT64, Module, Sampling, ScalarAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxAlloc, VecZnxBigOps,
+VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -13,24 +13,24 @@ fn main() {
 let log_scale: usize = msg_size * log_base2k - 5;
 let module: Module<FFT64> = Module::<FFT64>::new(n);

-let mut tmp_bytes_norm: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
-let mut tmp_bytes_dft = alloc_aligned(module.bytes_of_vec_znx_dft(1, ct_size));
+let mut scratch =
+ScratchOwned::new((2 * module.bytes_of_vec_znx_dft(1, ct_size)) + 2 * module.vec_znx_big_normalize_tmp_bytes());

 let seed: [u8; 32] = [0; 32];
 let mut source: Source = Source::new(seed);

 // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-let mut s: Scalar = module.new_scalar(1);
+let mut s = module.new_scalar(1);
 s.fill_ternary_prob(0, 0.5, &mut source);

 // Buffer to store s in the DFT domain
-let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft(s.cols());
+let mut s_dft = module.new_scalar_znx_dft(s.cols());

 // s_dft <- DFT(s)
 module.svp_prepare(&mut s_dft, 0, &s, 0);

 // Allocates a VecZnx with two columns: ct=(0, 0)
-let mut ct: VecZnx = module.new_vec_znx(
+let mut ct = module.new_vec_znx(
 2, // Number of columns
 ct_size, // Number of small poly per column
 );
@@ -39,11 +39,8 @@ fn main() {
 module.fill_uniform(log_base2k, &mut ct, 1, ct_size, &mut source);

 // Scratch space for DFT values
-let mut buf_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(
-1, // Number of columns
-ct.size(), // Number of polynomials per column
-&mut tmp_bytes_dft,
-);
+let scratch = scratch.borrow();
+let (mut buf_dft, scratch) = scratch.tmp_vec_znx_dft(&module, 1, ct_size);

 // Applies DFT(ct[1]) * DFT(s)
 module.svp_apply_dft(
@@ -56,13 +53,14 @@ fn main() {
 );

 // Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
-let mut buf_big: VecZnxBig<FFT64> = buf_dft.alias_as_vec_znx_big();
+let (mut buf_big, scratch) = scratch.tmp_vec_znx_big(&module, 1, ct_size);

 // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
-module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
+// Note: Since `vec_znx_idft_tmp_a` takes no argument for generic `Data` a full qualified path seems necessary
+<Module<_> as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0);

 // Creates a plaintext: VecZnx with 1 column
-let mut m: VecZnx = module.new_vec_znx(
+let mut m = module.new_vec_znx(
 1, // Number of columns
 msg_size, // Number of small polynomials
 );
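The note in the hunk above about needing a fully qualified path is a general Rust inference issue rather than anything specific to this crate: when a trait carries a generic parameter that never appears in a method's arguments or return type, the compiler cannot infer it at the call site. A minimal, self-contained sketch of the same situation (the `Ops` and `Mod` names are made up for illustration, not part of base2k):

```rust
// Hypothetical sketch: a trait generic over `Data`, where one method never
// mentions `Data`, so a plain method call cannot infer which impl is meant.
trait Ops<Data> {
    fn idft(&self, x: &mut i64);
}

struct Mod;

impl<'a> Ops<&'a [u8]> for Mod {
    fn idft(&self, x: &mut i64) {
        *x += 1;
    }
}

impl Ops<Vec<u8>> for Mod {
    fn idft(&self, x: &mut i64) {
        *x += 2;
    }
}

fn main() {
    let m = Mod;
    let mut v = 0i64;
    // m.idft(&mut v); // does not compile: type annotations needed, `Data` is ambiguous
    <Mod as Ops<&[u8]>>::idft(&m, &mut v); // the fully qualified path pins `Data`
    println!("{v}");
}
```

In the example, `Data` in `VecZnxDftOps<DataMut, Data, B>` plays the same role, hence the `<Module<_> as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(...)` spelling.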
@@ -70,10 +68,11 @@ fn main() {
 want.iter_mut()
 .for_each(|x| *x = source.next_u64n(16, 15) as i64);
 m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
-m.normalize(log_base2k, 0, &mut tmp_bytes_norm);
+let (tmp_bytes_norm, scratch) = scratch.tmp_scalar_slice(n * std::mem::size_of::<i64>());
+m.normalize(log_base2k, 0, tmp_bytes_norm);

 // m - BIG(ct[1] * s)
-module.vec_znx_big_sub_small_a_inplace(
+module.vec_znx_big_sub_small_b_inplace(
 &mut buf_big,
 0, // Selects the first column of the receiver
 &m,
@@ -83,12 +82,9 @@ fn main() {
 // Normalizes back to VecZnx
 // ct[0] <- m - BIG(c1 * s)
 module.vec_znx_big_normalize(
-log_base2k,
-&mut ct,
-0, // Selects the first column of ct (ct[0])
-&buf_big,
-0, // Selects the first column of buf_big
-&mut tmp_bytes_norm,
+log_base2k, &mut ct, 0, // Selects the first column of ct (ct[0])
+&buf_big, 0, // Selects the first column of buf_big
+scratch,
 );

 // Add noise to ct[0]
@@ -118,14 +114,14 @@ fn main() {
 );

 // BIG(c1 * s) = IDFT(DFT(c1 * s))
-module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
+<Module<_> as VecZnxDftOps<_, &[u8], _>>::vec_znx_idft_tmp_a(&module, &mut buf_big, 0, &mut buf_dft, 0);

 // BIG(c1 * s) + ct[0]
 module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);

 // m + e <- BIG(ct[1] * s + ct[0])
-let mut res: VecZnx = module.new_vec_znx(1, ct_size);
-module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, &mut tmp_bytes_norm);
+let mut res = module.new_vec_znx(1, ct_size);
+module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch);

 // have = m * 2^{log_scale} + e
 let mut have: Vec<i64> = vec![i64::default(); n];
@@ -136,5 +132,7 @@ fn main() {
 .enumerate()
 .for_each(|(i, (a, b))| {
 println!("{}: {} {}", i, a, (*b as f64) / scale);
-})
+});
+
+module.free();
 }
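For orientation, the comments in the amended example trace the usual RLWE relation: ct[1] is uniform, ct[0] = m - ct[1]*s + e, and decryption computes ct[0] + ct[1]*s = m + e before normalizing and rounding. Below is a minimal self-contained sketch of that relation using plain wrapping u64 arithmetic and schoolbook negacyclic multiplication instead of the base2k/FFT64 machinery; the parameters and "randomness" are toy values and none of these names are the library's API.

```rust
// Toy illustration only: the RLWE relation from the example's comments,
// over Z_{2^64}[X]/(X^N + 1) via wrapping u64 arithmetic.
fn negacyclic_mul(a: &[u64], b: &[u64]) -> Vec<u64> {
    let n = a.len();
    let mut c = vec![0u64; n];
    for i in 0..n {
        for j in 0..n {
            let p = a[i].wrapping_mul(b[j]);
            if i + j < n {
                c[i + j] = c[i + j].wrapping_add(p);
            } else {
                c[i + j - n] = c[i + j - n].wrapping_sub(p); // X^N = -1
            }
        }
    }
    c
}

fn main() {
    let n = 16usize;
    let log_scale = 56; // message in the top bits, noise stays in the low bits

    let s: Vec<u64> = (0..n).map(|i| [0u64, 1, u64::MAX][i % 3]).collect(); // ternary, -1 = u64::MAX
    let c1: Vec<u64> = (0..n).map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)).collect(); // "uniform"
    let e: Vec<u64> = (0..n).map(|i| (i % 5) as u64).collect(); // small noise
    let m: Vec<u64> = (0..n).map(|i| ((i % 16) as u64) << log_scale).collect(); // scaled message

    // ct = (ct[0], ct[1]) with ct[0] = m - ct[1]*s + e
    let c1s = negacyclic_mul(&c1, &s);
    let c0: Vec<u64> = (0..n)
        .map(|i| m[i].wrapping_sub(c1s[i]).wrapping_add(e[i]))
        .collect();

    // decryption: ct[0] + ct[1]*s = m + e, then drop the noise bits
    for i in 0..n {
        let dec = c0[i].wrapping_add(c1s[i]);
        assert_eq!(dec >> log_scale, m[i] >> log_scale);
    }
    println!("toy RLWE round trip ok");
}
```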
@@ -1,78 +1,78 @@
-use base2k::{
-Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
-ZnxInfos, ZnxLayout, alloc_aligned,
-};
+// use base2k::{
+// Encoding, FFT64, MatZnxDft, MatZnxDftOps, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps,
+// ZnxInfos, ZnxLayout, alloc_aligned,
+// };

 fn main() {
-let log_n: i32 = 5;
-let n: usize = 1 << log_n;
+// let log_n: i32 = 5;
+// let n: usize = 1 << log_n;

-let module: Module<FFT64> = Module::<FFT64>::new(n);
-let log_base2k: usize = 15;
+// let module: Module<FFT64> = Module::<FFT64>::new(n);
+// let log_base2k: usize = 15;

-let a_cols: usize = 2;
-let a_size: usize = 5;
+// let a_cols: usize = 2;
+// let a_size: usize = 5;

-let log_k: usize = log_base2k * a_size - 5;
+// let log_k: usize = log_base2k * a_size - 5;

-let mat_rows: usize = a_size;
-let mat_cols_in: usize = a_cols;
-let mat_cols_out: usize = 2;
-let mat_size: usize = a_size + 1;
+// let mat_rows: usize = a_size;
+// let mat_cols_in: usize = a_cols;
+// let mat_cols_out: usize = 2;
+// let mat_size: usize = a_size + 1;

-let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
-module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
-| module.vmp_apply_dft_tmp_bytes(
-a_size,
-a_size,
-mat_rows,
-mat_cols_in,
-mat_cols_out,
-mat_size,
-),
-);
+// let mut tmp_bytes_vmp: Vec<u8> = alloc_aligned(
+// module.vmp_prepare_row_tmp_bytes(mat_cols_out, mat_size)
+// | module.vmp_apply_dft_tmp_bytes(
+// a_size,
+// a_size,
+// mat_rows,
+// mat_cols_in,
+// mat_cols_out,
+// mat_size,
+// ),
+// );

-let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));
+// let mut tmp_bytes_dft: Vec<u8> = alloc_aligned(module.bytes_of_vec_znx_dft(mat_cols_out, mat_size));

-let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);
+// let mut a: VecZnx = module.new_vec_znx(a_cols, a_size);

-(0..a_cols).for_each(|i| {
-let mut values: Vec<i64> = vec![i64::default(); n];
-values[1 + i] = (1 << log_base2k) + 1;
-a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
-a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
-a.print(n, i);
-println!();
-});
+// (0..a_cols).for_each(|i| {
+// let mut values: Vec<i64> = vec![i64::default(); n];
+// values[1 + i] = (1 << log_base2k) + 1;
+// a.encode_vec_i64(i, log_base2k, log_k, &values, 32);
+// a.normalize(log_base2k, i, &mut tmp_bytes_vmp);
+// a.print(n, i);
+// println!();
+// });

-let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
+// let mut mat_znx_dft: MatZnxDft<FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);

-(0..a.size()).for_each(|row_i| {
-let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
-(0..mat_cols_out).for_each(|j| {
-tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
-});
-(0..mat_cols_in).for_each(|j| {
-module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
-})
-});
+// (0..a.size()).for_each(|row_i| {
+// let mut tmp: VecZnx = module.new_vec_znx(mat_cols_out, mat_size);
+// (0..mat_cols_out).for_each(|j| {
+// tmp.at_mut(j, row_i)[1 + j] = 1 as i64;
+// });
+// (0..mat_cols_in).for_each(|j| {
+// module.vmp_prepare_row(&mut mat_znx_dft, row_i, j, &tmp, &mut tmp_bytes_vmp);
+// })
+// });

-let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
-module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);
+// let mut c_dft: VecZnxDft<FFT64> = module.new_vec_znx_dft_from_bytes_borrow(mat_cols_out, mat_size, &mut tmp_bytes_dft);
+// module.vmp_apply_dft(&mut c_dft, &a, &mat_znx_dft, &mut tmp_bytes_vmp);

-let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
-let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
-(0..mat_cols_out).for_each(|i| {
-module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
-module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);
+// let mut res: VecZnx = module.new_vec_znx(mat_cols_out, a_size);
+// let mut c_big: VecZnxBig<FFT64> = c_dft.alias_as_vec_znx_big();
+// (0..mat_cols_out).for_each(|i| {
+// module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
+// module.vec_znx_big_normalize(log_base2k, &mut res, i, &c_big, i, &mut tmp_bytes_vmp);

-let mut values_res: Vec<i64> = vec![i64::default(); n];
-res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
-res.print(n, i);
-println!();
-println!("{:?}", values_res);
-println!();
-});
+// let mut values_res: Vec<i64> = vec![i64::default(); n];
+// res.decode_vec_i64(i, log_base2k, log_k, &mut values_res);
+// res.print(n, i);
+// println!();
+// println!("{:?}", values_res);
+// println!();
+// });

-module.free();
+// module.free();
 }
@@ -18,17 +18,10 @@ pub mod vec_znx_dft_ops;
 pub mod vec_znx_ops;
 pub mod znx_base;

-use std::{
-any::type_name,
-ops::{DerefMut, Sub},
-};
-
 pub use encoding::*;
 pub use mat_znx_dft::*;
 pub use mat_znx_dft_ops::*;
 pub use module::*;
-use rand_core::le;
-use rand_distr::num_traits::sign;
 pub use sampling::*;
 pub use scalar_znx::*;
 pub use scalar_znx_dft::*;
@@ -133,6 +126,8 @@ pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
 )
 }

+// Scratch implementation below
+
 pub struct ScratchOwned(Vec<u8>);

 impl ScratchOwned {
@@ -141,16 +136,16 @@ impl ScratchOwned {
 Self(data)
 }

-pub fn borrow(&mut self) -> &mut ScratchBorr {
-ScratchBorr::new(&mut self.0)
+pub fn borrow(&mut self) -> &mut Scratch {
+Scratch::new(&mut self.0)
 }
 }

-pub struct ScratchBorr {
+pub struct Scratch {
 data: [u8],
 }

-impl ScratchBorr {
+impl Scratch {
 fn new(data: &mut [u8]) -> &mut Self {
 unsafe { &mut *(data as *mut [u8] as *mut Self) }
 }
@@ -175,14 +170,14 @@ impl ScratchBorr {
 panic!(
 "Attempted to take {} from scratch with {} aligned bytes left",
 take_len,
-take_len,
+aligned_len,
 // type_name::<T>(),
 // aligned_len
 );
 }
 }

-fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
+pub fn tmp_scalar_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
 let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());

 unsafe {
@@ -193,7 +188,7 @@ impl ScratchBorr {
 }
 }

-fn tmp_vec_znx_dft<B: Backend>(
+pub fn tmp_vec_znx_dft<B: Backend>(
 &mut self,
 module: &Module<B>,
 cols: usize,
@@ -207,103 +202,17 @@ impl ScratchBorr {
 )
 }

-fn tmp_vec_znx_big<D: for<'a> From<&'a mut [u8]>, B: Backend>(
+pub fn tmp_vec_znx_big<B: Backend>(
 &mut self,
 module: &Module<B>,
 cols: usize,
 size: usize,
-) -> (VecZnxBig<D, B>, &mut Self) {
+) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
 let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));

 (
-VecZnxBig::from_data(D::from(take_slice), module.n(), cols, size),
+VecZnxBig::from_data(take_slice, module.n(), cols, size),
 Self::new(rem_slice),
 )
 }
 }
-
-// pub struct ScratchBorrowed<'a> {
-// data: &'a mut [u8],
-// }
-
-// impl<'a> ScratchBorrowed<'a> {
-// fn take_slice<T>(&mut self, take_len: usize) -> (&mut [T], ScratchBorrowed<'_>) {
-// let ptr = self.data.as_mut_ptr();
-// let self_len = self.data.len();
-
-// //TODO(Jay): print the offset sometimes, just to check
-// let aligned_offset = ptr.align_offset(DEFAULTALIGN);
-// let aligned_len = self_len.saturating_sub(aligned_offset);
-
-// let take_len_bytes = take_len * std::mem::size_of::<T>();
-
-// if let Some(rem_len) = aligned_len.checked_sub(take_len_bytes) {
-// unsafe {
-// let rem_ptr = ptr.add(aligned_offset).add(take_len_bytes);
-// let rem_slice = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
-
-// let take_slice = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset) as *mut T, take_len_bytes);
-
-// return (take_slice, ScratchBorrowed { data: rem_slice });
-// }
-// } else {
-// panic!(
-// "Attempted to take {} (={} elements of {}) from scratch with {} aligned bytes left",
-// take_len_bytes,
-// take_len,
-// type_name::<T>(),
-// aligned_len
-// );
-// }
-// }
-
-// fn reborrow(&mut self) -> ScratchBorrowed<'a> {
-// //(Jay)TODO: `data: &mut *self.data` does not work because liftime of &mut self is different from 'a.
-// // But it feels that there should be a simpler impl. than the one below
-// Self {
-// data: unsafe { &mut *std::ptr::slice_from_raw_parts_mut(self.data.as_mut_ptr(), self.data.len()) },
-// }
-// }
-
-// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxDft<&mut [u8], B>, Self) {
-// let (data, re_scratch) = self.take_slice::<u8>(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
-// (
-// VecZnxDft::from_data(data, module.n(), cols, size),
-// re_scratch,
-// )
-// }
-
-// pub(crate) fn len(&self) -> usize {
-// self.data.len()
-// }
-// }
-
-// pub trait Scratch<D> {
-// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (D, &mut Self);
-// }
-
-// impl<'a> Scratch<&'a mut [u8]> for ScratchBorr {
-// fn tmp_vec_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (&'a mut [u8], &mut Self) {
-// let (data, rem_scratch) = self.tmp_scalar_slice(vec_znx_dft::bytes_of_vec_znx_dft(module, cols, size));
-// (
-// data
-// rem_scratch,
-// )
-// }
-
-// // fn tmp_vec_znx_big<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnxBig<&mut [u8], B>, Self) {
-// // // let (data, re_scratch) = self.take_slice(vec_znx_big::bytes_of_vec_znx_big(module, cols, size));
-// // // (
-// // // VecZnxBig::from_data(data, module.n(), cols, size),
-// // // re_scratch,
-// // // )
-// // }
-
-// // fn scalar_slice<T>(&mut self, len: usize) -> (&mut [T], Self) {
-// // self.take_slice::<T>(len)
-// // }
-
-// // fn reborrow(&mut self) -> Self {
-// // self.reborrow()
-// // }
-// }
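The `Scratch` type introduced above is an unsized wrapper over a byte slice: `ScratchOwned` owns the buffer, `borrow` reinterprets `&mut [u8]` as `&mut Scratch`, and each `tmp_*` method carves off a chunk and hands the remainder back as a fresh `&mut Scratch`. A stand-alone sketch of the same pattern is shown below; `Arena` and `take_bytes` are hypothetical names, there is no alignment handling, and this is not the crate's API.

```rust
// Stand-alone sketch of the pattern: an unsized wrapper over [u8] so that
// "take a chunk, keep the rest" composes without explicit lifetimes at every
// call site. The crate's `Scratch` additionally handles alignment via
// `take_slice_aligned`.
#[repr(transparent)]
pub struct Arena {
    data: [u8],
}

impl Arena {
    fn new(data: &mut [u8]) -> &mut Self {
        // Same cast as `Scratch::new` above; `repr(transparent)` makes the
        // layout assumption explicit in this sketch.
        unsafe { &mut *(data as *mut [u8] as *mut Self) }
    }

    // Split off `len` bytes and return the remainder as a fresh `&mut Arena`.
    fn take_bytes(&mut self, len: usize) -> (&mut [u8], &mut Self) {
        let (take, rest) = self.data.split_at_mut(len);
        (take, Self::new(rest))
    }
}

fn main() {
    let mut backing = vec![0u8; 64];
    let arena = Arena::new(&mut backing);
    let (first, rest) = arena.take_bytes(16);
    let (second, _rest) = rest.take_bytes(8);
    first.fill(1);
    second.fill(2);
    println!("{} {}", first.len(), second.len());
}
```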
@@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, ScratchBorr, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
+Backend, FFT64, MatZnxDft, MatZnxDftAllocOwned, Module, Scratch, VecZnx, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
 VecZnxDftAlloc, VecZnxDftOps,
 };

@@ -25,15 +25,6 @@ pub trait MatZnxDftAlloc<B> {
 size: usize,
 bytes: Vec<u8>,
 ) -> MatZnxDftAllocOwned<B>;
-
-// fn new_mat_znx_dft_from_bytes_borrow(
-// &self,
-// rows: usize,
-// cols_in: usize,
-// cols_out: usize,
-// size: usize,
-// bytes: &mut [u8],
-// ) -> MatZnxDft<FFT64>;
 }

 pub trait MatZnxDftScratch {
@@ -101,7 +92,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
 b_row: usize,
 b_col_in: usize,
 a: &VecZnx<Data>,
-scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 );

 /// Extracts the ith-row of [MatZnxDft] into a [VecZnxBig].
@@ -118,7 +109,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
 a: &MatZnxDft<Data, B>,
 b_row: usize,
 b_col_in: usize,
-scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 );

 /// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
@@ -165,7 +156,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
 /// * `a`: the left operand [VecZnx] of the vector matrix product.
 /// * `b`: the right operand [MatZnxDft] of the vector matrix product.
 /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_tmp_bytes].
-fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut ScratchBorr);
+fn vmp_apply_dft(&self, c: &mut VecZnxDft<DataMut, B>, a: &VecZnx<Data>, b: &MatZnxDft<Data, B>, scratch: &mut Scratch);

 /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
 /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -197,7 +188,7 @@ pub trait MatZnxDftOps<DataMut, Data, B: Backend> {
 c: &mut VecZnxDft<DataMut, B>,
 a: &VecZnxDft<Data, B>,
 b: &MatZnxDft<Data, B>,
-scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 );
 }

@@ -274,18 +265,14 @@ impl<B: Backend> MatZnxDftScratch for Module<B> {
 }
 }

-impl<DataMut, Data> MatZnxDftOps<DataMut, Data, FFT64> for Module<FFT64>
-where
-DataMut: AsMut<[u8]> + AsRef<[u8]> + for<'a> From<&'a mut [u8]>,
-Data: AsRef<[u8]>,
-{
+impl MatZnxDftOps<&mut [u8], &[u8], FFT64> for Module<FFT64> {
 fn vmp_prepare_row(
 &self,
-b: &mut MatZnxDft<DataMut, FFT64>,
+b: &mut MatZnxDft<&mut [u8], FFT64>,
 b_row: usize,
 b_col_in: usize,
-a: &VecZnx<Data>,
-scratch: &mut ScratchBorr,
+a: &VecZnx<&[u8]>,
+scratch: &mut Scratch,
 ) {
 #[cfg(debug_assertions)]
 {
@@ -328,21 +315,19 @@ where
 let a_size: usize = a.size();

 // let (tmp_bytes_a_dft, _) = tmp_bytes.split_at_mut(self.bytes_of_vec_znx_dft(cols_out, a_size));
-let (mut a_dft, _) = scratch.tmp_scalar_slice(12);
-DataMut::from(a_dft);
-// let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<DataMut, _>(self, cols_out, a_size);
+let (mut a_dft, _) = scratch.tmp_vec_znx_dft::<_>(self, cols_out, a_size);
 (0..cols_out).for_each(|i| self.vec_znx_dft(&mut a_dft, i, &a, i));
-Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft);
+Self::vmp_prepare_row_dft(&self, b, b_row, b_col_in, &a_dft.to_ref());
 }

 fn vmp_extract_row(
 &self,
 log_base2k: usize,
-b: &mut VecZnx<DataMut>,
-a: &MatZnxDft<Data, FFT64>,
+b: &mut VecZnx<&mut [u8]>,
+a: &MatZnxDft<&[u8], FFT64>,
 a_row: usize,
 a_col_in: usize,
-mut scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 ) {
 #[cfg(debug_assertions)]
 {
@@ -386,12 +371,18 @@ where
 Self::vmp_extract_row_dft(&self, &mut b_dft, a, a_row, a_col_in);
 let (mut b_big, scratch) = scratch.tmp_vec_znx_big(self, cols_out, size);
 (0..cols_out).for_each(|i| {
-<Self as VecZnxDftOps<DataMut, Data, FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
+<Self as VecZnxDftOps<&mut [u8], &[u8], FFT64>>::vec_znx_idft_tmp_a(self, &mut b_big, i, &mut b_dft, i);
 self.vec_znx_big_normalize(log_base2k, b, i, &b_big, i, scratch);
 });
 }

-fn vmp_prepare_row_dft(&self, b: &mut MatZnxDft<DataMut, FFT64>, b_row: usize, b_col_in: usize, a: &VecZnxDft<Data, FFT64>) {
+fn vmp_prepare_row_dft(
+&self,
+b: &mut MatZnxDft<&mut [u8], FFT64>,
+b_row: usize,
+b_col_in: usize,
+a: &VecZnxDft<&[u8], FFT64>,
+) {
 #[cfg(debug_assertions)]
 {
 assert_eq!(b.n(), self.n());
@@ -436,7 +427,13 @@ where
 }
 }

-fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<DataMut, FFT64>, a: &MatZnxDft<Data, FFT64>, a_row: usize, a_col_in: usize) {
+fn vmp_extract_row_dft(
+&self,
+b: &mut VecZnxDft<&mut [u8], FFT64>,
+a: &MatZnxDft<&[u8], FFT64>,
+a_row: usize,
+a_col_in: usize,
+) {
 #[cfg(debug_assertions)]
 {
 assert_eq!(b.n(), self.n());
@@ -482,10 +479,10 @@ where

 fn vmp_apply_dft(
 &self,
-c: &mut VecZnxDft<DataMut, FFT64>,
-a: &VecZnx<Data>,
-b: &MatZnxDft<Data, FFT64>,
-mut scratch: &mut ScratchBorr,
+c: &mut VecZnxDft<&mut [u8], FFT64>,
+a: &VecZnx<&[u8]>,
+b: &MatZnxDft<&[u8], FFT64>,
+scratch: &mut Scratch,
 ) {
 #[cfg(debug_assertions)]
 {
@@ -547,11 +544,12 @@ where

 fn vmp_apply_dft_to_dft(
 &self,
-c: &mut VecZnxDft<DataMut, FFT64>,
-a: &VecZnxDft<Data, FFT64>,
-b: &MatZnxDft<Data, FFT64>,
-mut scratch: &mut ScratchBorr,
+c: &mut VecZnxDft<&mut [u8], FFT64>,
+a: &VecZnxDft<&[u8], FFT64>,
+b: &MatZnxDft<&[u8], FFT64>,
+scratch: &mut Scratch,
 ) {
+{
 #[cfg(debug_assertions)]
 {
 assert_eq!(c.n(), self.n());
@@ -611,6 +609,7 @@ where
 )
 }
 }
+}
 }

 #[cfg(test)]
@@ -658,27 +657,31 @@ mod tests {
 module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
 });

-// let g = vmpmat_0.to_mut();
-
-module.vmp_prepare_row(&mut vmpmat_0.to_mut(), row_i, col_in, &a, scratch.borrow());
+module.vmp_prepare_row(
+&mut vmpmat_0.to_mut(),
+row_i,
+col_in,
+&a.to_ref(),
+scratch.borrow(),
+);

 // Checks that prepare(mat_znx_dft, a) = prepare_dft(mat_znx_dft, a_dft)
-module.vmp_prepare_row_dft(&mut vmpmat_1, row_i, col_in, &a_dft);
+module.vmp_prepare_row_dft(&mut vmpmat_1.to_mut(), row_i, col_in, &a_dft.to_ref());
 assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());

 // Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
-module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i, col_in);
+module.vmp_extract_row_dft(&mut b_dft.to_mut(), &vmpmat_0.to_ref(), row_i, col_in);
 assert_eq!(a_dft.raw(), b_dft.raw());

 // Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
-// module.vmp_extract_row(
-// log_base2k,
-// &mut b.to_mut(),
-// &vmpmat_0.to_ref(),
-// row_i,
-// col_in,
-// scratch.borrow(),
-// );
+module.vmp_extract_row(
+log_base2k,
+&mut b.to_mut(),
+&vmpmat_0.to_ref(),
+row_i,
+col_in,
+scratch.borrow(),
+);

 (0..mat_cols_out).for_each(|col_out| {
 module.vec_znx_idft(&mut a_big, col_out, &a_dft, col_out, &mut tmp_bytes);
@@ -217,7 +217,7 @@ pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
 pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;

 impl VecZnx<Vec<u8>> {
-pub(crate) fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+pub fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
 VecZnx {
 data: self.data.as_mut_slice(),
 n: self.n,
@@ -226,7 +226,7 @@ impl VecZnx<Vec<u8>> {
 }
 }

-pub(crate) fn to_ref(&self) -> VecZnx<&[u8]> {
+pub fn to_ref(&self) -> VecZnx<&[u8]> {
 VecZnx {
 data: self.data.as_slice(),
 n: self.n,
@@ -235,3 +235,14 @@ impl VecZnx<Vec<u8>> {
 }
 }
 }
+
+impl VecZnx<&mut [u8]> {
+pub fn to_ref(&self) -> VecZnx<&[u8]> {
+VecZnx {
+data: &self.data,
+n: self.n,
+cols: self.cols,
+size: self.size,
+}
+}
+}
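These `to_mut`/`to_ref` helpers, now public, follow the common pattern of one type parameterized by its backing storage (`Vec<u8>` when owned, `&[u8]` or `&mut [u8]` when borrowed) with cheap conversions between the forms. A minimal sketch of the pattern with a made-up `Poly` type (not the crate's API):

```rust
// Illustrative sketch of the owned/borrowed view pattern used by VecZnx,
// VecZnxBig and VecZnxDft in this commit. `Poly` is a hypothetical stand-in.
struct Poly<D> {
    data: D,
    n: usize,
}

impl Poly<Vec<u8>> {
    fn to_mut(&mut self) -> Poly<&mut [u8]> {
        Poly { data: self.data.as_mut_slice(), n: self.n }
    }
    fn to_ref(&self) -> Poly<&[u8]> {
        Poly { data: self.data.as_slice(), n: self.n }
    }
}

// A mutable view can still be downgraded to a shared view, mirroring the new
// `impl VecZnx<&mut [u8]> { to_ref }` added above.
impl<'a> Poly<&'a mut [u8]> {
    fn to_ref(&self) -> Poly<&[u8]> {
        Poly { data: &self.data, n: self.n }
    }
}

fn main() {
    let mut owned = Poly { data: vec![0u8; 8], n: 8 };
    {
        let view = owned.to_mut();
        view.data[0] = 42;
        let shared = view.to_ref();
        println!("n = {}, first = {}", shared.n, shared.data[0]);
    }
    println!("owned again: {}", owned.to_ref().data[0]);
}
```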
@@ -3,7 +3,7 @@ use crate::znx_base::{ZnxInfos, ZnxView};
 use crate::{Backend, DataView, DataViewMut, FFT64, Module, alloc_aligned};
 use std::marker::PhantomData;

-const VEC_ZNX_BIG_ROWS: usize = 1;
+// const VEC_ZNX_BIG_ROWS: usize = 1;

 /// VecZnxBig is `Backend` dependent, denoted with backend generic `B`
 pub struct VecZnxBig<D, B> {
@@ -97,7 +97,7 @@ impl<D, B> VecZnxBig<D, B> {
 pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;

 impl<B> VecZnxBig<Vec<u8>, B> {
-pub(crate) fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
+pub fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
 VecZnxBig {
 data: self.data.as_mut_slice(),
 n: self.n,
@@ -107,7 +107,7 @@ impl<B> VecZnxBig<Vec<u8>, B> {
 }
 }

-pub(crate) fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+pub fn to_ref(&self) -> VecZnxBig<&[u8], B> {
 VecZnxBig {
 data: self.data.as_slice(),
 n: self.n,
@@ -117,9 +117,3 @@ impl<B> VecZnxBig<Vec<u8>, B> {
 }
 }
 }
-
-// impl VecZnxBig<FFT64> {
-// pub fn print(&self, n: usize, col: usize) {
-// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
-// }
-// }
@@ -1,9 +1,6 @@
 use crate::ffi::vec_znx;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
-use crate::{
-Backend, DataView, FFT64, Module, ScratchBorr, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, assert_alignement,
-bytes_of_vec_znx_big,
-};
+use crate::{Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxScratch, bytes_of_vec_znx_big};

 pub trait VecZnxBigAlloc<B> {
 /// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
@@ -129,7 +126,7 @@ pub trait VecZnxBigOps<DataMut, Data, B> {
 res_col: usize,
 a: &VecZnxBig<Data, B>,
 a_col: usize,
-scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 );

 /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
@@ -503,7 +500,7 @@ where
 res_col: usize,
 a: &VecZnxBig<Data, FFT64>,
 a_col: usize,
-scratch: &mut ScratchBorr,
+scratch: &mut Scratch,
 ) {
 #[cfg(debug_assertions)]
 {
@@ -4,7 +4,7 @@ use crate::ffi::vec_znx_dft;
 use crate::znx_base::ZnxInfos;
 use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxView, alloc_aligned};

-const VEC_ZNX_DFT_ROWS: usize = 1;
+// const VEC_ZNX_DFT_ROWS: usize = 1;

 // VecZnxDft is `Backend` dependent denoted with generic `B`
 pub struct VecZnxDft<D, B> {
@@ -97,52 +97,36 @@ impl<D, B> VecZnxDft<D, B> {
 }
 }

-// impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
-// type Scalar = u8;
+impl<B> VecZnxDft<Vec<u8>, B> {
+pub fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
+VecZnxDft {
+data: self.data.as_mut_slice(),
+n: self.n,
+cols: self.cols,
+size: self.size,
+_phantom: PhantomData,
+}
+}

-// fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
-// debug_assert_eq!(bytes.len(), Self::bytes_of(module, _rows, cols, size));
-// Self {
-// inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
-// _marker: PhantomData,
-// }
-// }
+pub fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+VecZnxDft {
+data: self.data.as_slice(),
+n: self.n,
+cols: self.cols,
+size: self.size,
+_phantom: PhantomData,
+}
+}
+}

-// fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, size: usize) -> usize {
-// debug_assert_eq!(
-// _rows, VEC_ZNX_DFT_ROWS,
-// "rows != {} not supported for VecZnxDft",
-// VEC_ZNX_DFT_ROWS
-// );
-// unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
-// }
-// }
-
-// impl VecZnxDft<FFT64> {
-// pub fn print(&self, n: usize, col: usize) {
-// (0..self.size()).for_each(|i| println!("{}: {:?}", i, &self.at(col, i)[..n]));
-// }
-// }
-
-// impl<B: Backend> VecZnxDft<B> {
-// /// Cast a [VecZnxDft] into a [VecZnxBig].
-// /// The returned [VecZnxBig] shares the backing array
-// /// with the original [VecZnxDft].
-// pub fn alias_as_vec_znx_big(&mut self) -> VecZnxBig<B> {
-// assert!(
-// self.data().len() == 0,
-// "cannot alias VecZnxDft into VecZnxBig if it owns the data"
-// );
-// VecZnxBig::<B> {
-// inner: ZnxBase {
-// data: Vec::new(),
-// ptr: self.ptr(),
-// n: self.n(),
-// rows: self.rows(),
-// cols: self.cols(),
-// size: self.size(),
-// },
-// _marker: PhantomData,
-// }
-// }
-// }
+impl<B> VecZnxDft<&mut [u8], B> {
+pub fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+VecZnxDft {
+data: &self.data,
+n: self.n,
+cols: self.cols,
+size: self.size,
+_phantom: PhantomData,
+}
+}
+}
@@ -22,19 +22,6 @@ pub trait VecZnxDftAlloc<B> {
 /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
 fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;

-// /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
-// ///
-// /// Behavior: the backing array is only borrowed.
-// ///
-// /// # Arguments
-// ///
-// /// * `cols`: the number of cols of the [VecZnxDft].
-// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
-// ///
-// /// # Panics
-// /// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
-// fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, size: usize, bytes: &mut [u8]) -> VecZnxDft<B>;
-
 /// Returns a new [VecZnxDft] with the provided bytes array as backing array.
 ///
 /// # Arguments
@@ -23,19 +23,6 @@ pub trait VecZnxAlloc {
 /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
 fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;

-// /// Instantiates a new [VecZnx] from a slice of bytes.
-// /// The returned [VecZnx] does take ownership of the slice of bytes.
-// ///
-// /// # Arguments
-// ///
-// /// * `cols`: the number of polynomials.
-// /// * `size`: the number small polynomials per column.
-// ///
-// /// # Panic
-// /// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
-// fn new_vec_znx_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnx;
-// (Jay)TODO
-
 /// Returns the number of bytes necessary to allocate
 /// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes]
 /// or [VecZnxOps::new_vec_znx_from_bytes_borrow].
@@ -140,7 +127,6 @@ pub trait VecZnxScratch {
 }

 impl<B: Backend> VecZnxAlloc for Module<B> {
-//(Jay)TODO: One must define the Scalar generic param here.
 fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
 VecZnxOwned::new::<i64>(self.n(), cols, size)
 }
@@ -1,4 +1,5 @@
 use itertools::izip;
+use rand_distr::num_traits::Zero;
 use std::cmp::min;

 pub trait ZnxInfos {
@@ -157,7 +158,7 @@ pub fn switch_degree<S: Copy, DMut: ZnxViewMut<Scalar = S> + ZnxZero, D: ZnxView

 use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};

-use crate::{ScratchBorr, cast_mut};
+use crate::Scratch;
 pub trait Integer:
 Copy
 + Default
@@ -183,32 +184,15 @@ impl Integer for i128 {
 const BITS: u32 = 128;
 }

-// (Jay)TODO: implement rsh for VecZnx, VecZnxBig
-// pub trait ZnxRsh: ZnxZero {
-// fn rsh(&mut self, k: usize, log_base2k: usize, col: usize, carry: &mut [u8]) {
-// rsh(k, log_base2k, self, col, carry)
-// }
-// }
-pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, a_col: usize, scratch: &mut ScratchBorr)
 //(Jay)Note: `rsh` impl. ignores the column
+pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
 where
-V::Scalar: From<usize> + Integer,
+V::Scalar: From<usize> + Integer + Zero,
 {
 let n: usize = a.n();
-let size: usize = a.size();
+let _size: usize = a.size();
 let cols: usize = a.cols();

-// #[cfg(debug_assertions)]
-// {
-// assert!(
-// tmp_bytes.len() >= rsh_tmp_bytes::<V::Scalar>(n),
-// "invalid carry: carry.len()/size_ofSelf::Scalar={} < rsh_tmp_bytes({}, {})",
-// tmp_bytes.len() / size_of::<V::Scalar>(),
-// n,
-// size,
-// );
-// assert_alignement(tmp_bytes.as_ptr());
-// }

 let size: usize = a.size();
 let steps: usize = k / log_base2k;
@@ -240,7 +224,7 @@ where
 *xi = (*xi - *ci) >> k_rem_t;
 });
 });
-//TODO: ZERO CARRYcarry
+carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
 })
 }
 }
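The one-line fix above replaces the `//TODO: ZERO CARRY` marker with an actual reset of the carry buffer. The point generalizes: in a limbwise right shift, the carry holds the low bits that spill into the next limb, and if the buffer is reused for another column or another input without being zeroed, those bits leak into unrelated data. A self-contained toy illustration (made-up `rsh_digits`, unsigned base-2^k digits, not the crate's `rsh`):

```rust
// Toy illustration: right-shift a little-endian base-2^K number by r bits
// (0 < r < K <= 63), threading a carry from the most significant digit down.
// The carry must be reset between independent inputs, which is the same class
// of bug the `carry.iter_mut()...zero()` line above addresses per column.
fn rsh_digits(d: &mut [u64], k: usize, r: usize, carry: &mut u64) {
    for di in d.iter_mut().rev() {
        let low = *di & ((1u64 << r) - 1); // bits that fall into the next digit down
        *di = (*di >> r) | (*carry << (k - r));
        *carry = low;
    }
}

fn main() {
    let (k, r) = (16usize, 3usize);
    let mut a = [0x1234u64, 0x0ABC, 0x0007]; // 0x0007_0ABC_1234 in base 2^16, little endian
    let mut b = [0xFFFFu64, 0x0001, 0x0000]; // 0x0001_FFFF
    let mut carry = 0u64;

    rsh_digits(&mut a, k, r, &mut carry);
    carry = 0; // forget this and the top digit of `b` inherits bits from `a`
    rsh_digits(&mut b, k, r, &mut carry);

    let val = |d: &[u64]| d.iter().rev().fold(0u64, |acc, &x| (acc << k) | x);
    assert_eq!(val(&a), 0x0007_0ABC_1234u64 >> r);
    assert_eq!(val(&b), 0x0001_FFFFu64 >> r);
    println!("{:#x} {:#x}", val(&a), val(&b));
}
```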