added from_bytes to vec_znx_[dft/big]

This commit is contained in:
Jean-Philippe Bossuat
2025-02-10 15:13:02 +01:00
parent 6b154e64a4
commit 83fa66f8f4
5 changed files with 65 additions and 32 deletions

View File

@@ -32,16 +32,15 @@ fn main() {
a.print_limbs(a.limbs(), n); a.print_limbs(a.limbs(), n);
println!(); println!();
let mut vecznx: Vec<VecZnx>= Vec::new(); let mut vecznx: Vec<VecZnx> = Vec::new();
(0..rows).for_each(|_|{ (0..rows).for_each(|_| {
vecznx.push(module.new_vec_znx(cols)); vecznx.push(module.new_vec_znx(cols));
}); });
(0..rows).for_each(|i|{ (0..rows).for_each(|i| {
vecznx[i].data[i*n+1] = 1 as i64; vecznx[i].data[i * n + 1] = 1 as i64;
}); });
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols); let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(rows, cols);
module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf); module.vmp_prepare_dblptr(&mut vmp_pmat, &vecznx, &mut buf);
@@ -57,7 +56,6 @@ fn main() {
let mut values_res: Vec<i64> = vec![i64::default(); n]; let mut values_res: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(log_base2k, log_k, &mut values_res); res.decode_vec_i64(log_base2k, log_k, &mut values_res);
res.print_limbs(res.limbs(), n); res.print_limbs(res.limbs(), n);
module.free(); module.free();
@@ -67,7 +65,6 @@ fn main() {
//println!("{:?}", values_res) //println!("{:?}", values_res)
} }
/* /*
use base2k::{ use base2k::{

View File

@@ -92,9 +92,5 @@ unsafe extern "C" {
} }
unsafe extern "C" { unsafe extern "C" {
pub unsafe fn vmp_prepare_tmp_bytes( pub unsafe fn vmp_prepare_tmp_bytes(module: *const MODULE, nrows: u64, ncols: u64) -> u64;
module: *const MODULE,
nrows: u64,
ncols: u64,
) -> u64;
} }

View File

@@ -415,7 +415,6 @@ pub trait VecZnxOps {
/// and that b.n() * b.len() <= a.n() /// and that b.n() * b.len() <= a.n()
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx); fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx);
/// Merges the subrings a into b. /// Merges the subrings a into b.
/// ///
/// # Panics /// # Panics
@@ -423,11 +422,9 @@ pub trait VecZnxOps {
/// This method requires that all [VecZnx] of a have the same ring degree /// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n() /// and that a.n() * a.len() <= b.n()
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>); fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>);
} }
impl VecZnxOps for Module { impl VecZnxOps for Module {
fn new_vec_znx(&self, limbs: usize) -> VecZnx { fn new_vec_znx(&self, limbs: usize) -> VecZnx {
VecZnx::new(self.n(), limbs) VecZnx::new(self.n(), limbs)
} }
@@ -592,34 +589,48 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx){ fn vec_znx_split(&self, b: &mut Vec<VecZnx>, a: &VecZnx, buf: &mut VecZnx) {
let (n_in, n_out) = (a.n(), b[0].n()); let (n_in, n_out) = (a.n(), b[0].n());
assert!(n_out < n_in, "invalid a: output ring degree should be smaller"); assert!(
b[1..].iter().for_each(|bi|{ n_out < n_in,
assert_eq!(bi.n(), n_out, "invalid input a: all VecZnx must have the same degree") "invalid a: output ring degree should be smaller"
);
b[1..].iter().for_each(|bi| {
assert_eq!(
bi.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
}); });
b.iter_mut().enumerate().for_each(|(i, bi)|{ b.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0{ if i == 0 {
a.switch_degree(bi); a.switch_degree(bi);
self.vec_znx_rotate(-1, buf, a); self.vec_znx_rotate(-1, buf, a);
}else{ } else {
buf.switch_degree(bi); buf.switch_degree(bi);
self.vec_znx_rotate_inplace(-1, buf); self.vec_znx_rotate_inplace(-1, buf);
} }
}) })
} }
fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>){ fn vec_znx_merge(&self, b: &mut VecZnx, a: &Vec<VecZnx>) {
let (n_in, n_out) = (b.n(), a[0].n()); let (n_in, n_out) = (b.n(), a[0].n());
assert!(n_out < n_in, "invalid a: output ring degree should be smaller"); assert!(
a[1..].iter().for_each(|ai|{ n_out < n_in,
assert_eq!(ai.n(), n_out, "invalid input a: all VecZnx must have the same degree") "invalid a: output ring degree should be smaller"
);
a[1..].iter().for_each(|ai| {
assert_eq!(
ai.n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
}); });
a.iter().enumerate().for_each(|(i, ai)|{ a.iter().enumerate().for_each(|(i, ai)| {
ai.switch_degree(b); ai.switch_degree(b);
self.vec_znx_rotate_inplace(-1, b); self.vec_znx_rotate_inplace(-1, b);
}); });

View File

@@ -5,6 +5,16 @@ use crate::{Infos, Module, VecZnx, VecZnxDft};
pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize); pub struct VecZnxBig(pub *mut vec_znx_big::vec_znx_bigcoeff_t, pub usize);
impl VecZnxBig { impl VecZnxBig {
/// Casts a contiguous array of [u8] into as a [VecZnxDft].
/// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
pub fn from_bytes(&self, limbs: usize, data: &mut [u8]) -> VecZnxBig {
VecZnxBig(
data.as_mut_ptr() as *mut vec_znx_big::vec_znx_bigcoeff_t,
limbs,
)
}
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft { pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1) VecZnxDft(self.0 as *mut vec_znx_dft::vec_znx_dft_t, self.1)
} }
@@ -19,6 +29,10 @@ impl Module {
unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) } unsafe { VecZnxBig(vec_znx_big::new_vec_znx_big(self.0, limbs as u64), limbs) }
} }
pub fn bytes_of_vec_znx_big(&self, limbs: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.0, limbs as u64) as usize }
}
// b <- b - a // b <- b - a
pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) { pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
let limbs: usize = a.limbs(); let limbs: usize = a.limbs();

View File

@@ -1,10 +1,21 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::ffi::vec_znx_dft; use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::{Module, VecZnxBig}; use crate::{Module, VecZnxBig};
pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize); pub struct VecZnxDft(pub *mut vec_znx_dft::vec_znx_dft_t, pub usize);
impl VecZnxDft { impl VecZnxDft {
/// Casts a contiguous array of [u8] into as a [VecZnxDft].
/// User must ensure that data is properly alligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
pub fn from_bytes(&self, limbs: usize, data: &mut [u8]) -> VecZnxDft {
VecZnxDft(data.as_mut_ptr() as *mut vec_znx_dft::vec_znx_dft_t, limbs)
}
/// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
pub fn as_vec_znx_big(&mut self) -> VecZnxBig { pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1) VecZnxBig(self.0 as *mut vec_znx_big::vec_znx_bigcoeff_t, self.1)
} }
@@ -19,6 +30,10 @@ impl Module {
unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) } unsafe { VecZnxDft(vec_znx_dft::new_vec_znx_dft(self.0, limbs as u64), limbs) }
} }
pub fn bytes_of_vec_znx_dft(&self, limbs: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.0, limbs as u64) as usize }
}
// b <- IDFT(a), uses a as scratch space. // b <- IDFT(a), uses a as scratch space.
pub fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) { pub fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_limbs: usize) {
assert!( assert!(