refactoring of code

Jean-Philippe Bossuat
2025-01-27 12:47:05 +01:00
parent 72e0e38827
commit 250d1a4942
9 changed files with 332 additions and 258 deletions

View File

@@ -1,12 +1,16 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
-use spqlios::bindings::*;
+use spqlios::bindings::{
+    new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp,
+    reim_fft_precomp_get_buffer, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp,
+    reim_ifft_precomp_get_buffer,
+};
 use std::ffi::c_void;

 fn fft(c: &mut Criterion) {
     fn forward<'a>(
         m: u32,
         log_bound: u32,
-        reim_fft_precomp: *mut spqlios::reim_fft_precomp,
+        reim_fft_precomp: *mut reim_fft_precomp,
         a: &'a [i64],
     ) -> Box<dyn FnMut() + 'a> {
         unsafe {
@@ -29,8 +33,6 @@ fn fft(c: &mut Criterion) {
         })
     }
-    let q: u64 = 0x1fffffffffe00001u64;
     let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
         c.benchmark_group("fft");
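The forward helper above returns a boxed closure that the (elided) benchmark body drives inside the criterion group. A self-contained sketch of that pattern, with make_runner as a hypothetical stand-in for forward (the spqlios FFT calls are deliberately not reproduced here):

```rust
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};

// Hypothetical stand-in for `forward`: build a reusable closure up front so the
// timed loop only pays for the work itself.
fn make_runner(log_n: u32) -> Box<dyn FnMut()> {
    let mut acc: u64 = 0;
    Box::new(move || {
        acc = acc.wrapping_add(1u64 << log_n);
        std::hint::black_box(acc);
    })
}

fn bench_closure(c: &mut Criterion) {
    let mut group = c.benchmark_group("fft");
    for log_n in [10u32, 11, 12] {
        let mut runner = make_runner(log_n);
        group.bench_with_input(BenchmarkId::new("forward", log_n), &log_n, |b, _| {
            b.iter(|| runner())
        });
    }
    group.finish();
}

criterion_group!(benches, bench_closure);
criterion_main!(benches);
```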

View File

@@ -1,12 +1,12 @@
 use itertools::izip;
 use sampling::source::Source;
-use spqlios::module::{Module, FFT64, VECZNXBIG};
-use spqlios::poly::Poly;
+use spqlios::module::{Module, FFT64};
 use spqlios::scalar::Scalar;
+use spqlios::vector::Vector;

 fn main() {
     let n: usize = 16;
-    let log_base2k: usize = 15;
+    let log_base2k: usize = 40;
     let prec: usize = 54;
     let log_scale: usize = 18;
     let module: Module = Module::new::<FFT64>(n);
@@ -16,23 +16,20 @@ fn main() {
     let seed: [u8; 32] = [0; 32];
     let mut source: Source = Source::new(seed);
-    let mut res: Poly = Poly::new(n, log_base2k, prec);
+    let mut res: Vector = Vector::new(n, log_base2k, prec);
-    // Allocates a buffer to store DFT(s)
-    module.new_svp_ppol();
     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
     let mut s: Scalar = Scalar::new(n);
     s.fill_ternary_prob(0.5, &mut source);
     // Buffer to store s in the DFT domain
-    let mut s_ppol: spqlios::module::SVPPOL = module.new_svp_ppol();
+    let mut s_ppol: spqlios::module::SVPPOL = module.svp_new_ppol();
     // s_ppol <- DFT(s)
     module.svp_prepare(&mut s_ppol, &s);
     // a <- Z_{2^prec}[X]/(X^{N}+1)
-    let mut a: Poly = Poly::new(n, log_base2k, prec);
+    let mut a: Vector = Vector::new(n, log_base2k, prec);
     a.fill_uniform(&mut source);
     // Scratch space for DFT values
@@ -47,7 +44,7 @@ fn main() {
     // buf_big <- IDFT(buf_dft) (not normalized)
     module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.limbs());
-    let mut m: Poly = Poly::new(n, log_base2k, prec - log_scale);
+    let mut m: Vector = Vector::new(n, log_base2k, prec - log_scale);
     let mut want: Vec<i64> = vec![0; n];
     want.iter_mut()
         .for_each(|x| *x = source.next_u64n(16, 15) as i64);
@@ -60,7 +57,7 @@ fn main() {
     module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
     // b <- normalize(buf_big) + e
-    let mut b: Poly = Poly::new(n, log_base2k, prec);
+    let mut b: Vector = Vector::new(n, log_base2k, prec);
     module.vec_znx_big_normalize(&mut b, &buf_big, &mut carry);
     b.add_normal(&mut source, 3.2, 19.0);
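The hunks above use buf_dft, buf_big and carry, whose allocation falls outside the displayed context. A minimal sketch of how they are presumably set up, using only the Module API introduced in this commit; the limb counts and the use of vec_znx_big_normalize_tmp_bytes to size carry are my assumptions, not taken from the elided hunk:

```rust
use spqlios::module::{Module, VECZNXBIG, VECZNXDFT};
use spqlios::vector::Vector;

// Hypothetical helper: allocate the scratch objects referenced by the example.
fn alloc_scratch(module: &Module, a: &Vector) -> (VECZNXDFT, VECZNXBIG, Vec<u8>) {
    // One limb per limb of `a`, matching the assertions in svp_apply_dft and
    // vec_znx_idft_tmp_a.
    let buf_dft: VECZNXDFT = module.new_vec_znx_dft(a.limbs());
    let buf_big: VECZNXBIG = module.new_vec_znx_big(a.limbs());
    // Byte scratch for vec_znx_big_normalize, sized by the module itself.
    let carry: Vec<u8> = vec![0u8; module.vec_znx_big_normalize_tmp_bytes()];
    (buf_dft, buf_big, carry)
}
```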

View File

@@ -1,6 +1,6 @@
 pub mod module;
-pub mod poly;
 pub mod scalar;
+pub mod vector;
 #[allow(
     non_camel_case_types,
@@ -13,7 +13,14 @@ pub mod bindings {
     include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
 }
-pub use bindings::*;
+pub mod vec_znx_arithmetic;
+pub use vec_znx_arithmetic::*;
+pub mod vec_znx_big_arithmetic;
+pub use vec_znx_big_arithmetic::*;
+pub mod vec_znx_dft;
+pub use vec_znx_dft::*;
+pub mod scalar_vector_product;
+pub use scalar_vector_product::*;
 fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
     let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
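Although the Module methods now live in several files, they are all inherent impls on the same type re-exported from the crate root, so call sites keep a single import path. A small sketch, assuming the modules above compile as declared:

```rust
use spqlios::module::{Module, FFT64};

fn demo() {
    let module: Module = Module::new::<FFT64>(16);
    let _ppol = module.svp_new_ppol();    // defined in scalar_vector_product.rs
    let _dft = module.new_vec_znx_dft(2); // defined in vec_znx_dft.rs
}
```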

View File

@@ -1,12 +1,12 @@
-use crate::bindings::*;
-use crate::poly::Poly;
-use crate::scalar::Scalar;
+use crate::bindings::{
+    module_info_t, new_module_info, svp_ppol_t, vec_znx_bigcoeff_t, vec_znx_dft_t, MODULE,
+};
 pub type MODULETYPE = u8;
 pub const FFT64: u8 = 0;
 pub const NTT120: u8 = 1;
-pub struct Module(*mut MODULE);
+pub struct Module(pub *mut MODULE);
 impl Module {
     // Instantiates a new module.
@@ -19,236 +19,13 @@ impl Module {
             Self(m)
         }
     }
-    // Prepares a scalar polynomial (1 limb) for a scalar x vector product.
-    // Method will panic if a.limbs() != 1.
-    pub fn svp_prepare(&self, svp_ppol: &mut SVPPOL, a: &Scalar) {
-        unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
-    }
-    // Allocates a scalar-vector-product prepared-poly (SVPPOL).
-    pub fn new_svp_ppol(&self) -> SVPPOL {
-        unsafe { SVPPOL(new_svp_ppol(self.0)) }
-    }
-    // Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
-    pub fn new_vec_znx_dft(&self, limbs: usize) -> VECZNXDFT {
-        unsafe { VECZNXDFT(new_vec_znx_dft(self.0, limbs as u64), limbs) }
-    }
-    // Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
-    pub fn new_vec_znx_big(&self, limbs: usize) -> VECZNXBIG {
-        unsafe { VECZNXBIG(new_vec_znx_big(self.0, limbs as u64), limbs) }
-    }
-    // Applies a scalar x vector product: res <- a (ppol) x b
-    pub fn svp_apply_dft(&self, c: &mut VECZNXDFT, a: &SVPPOL, b: &Poly) {
-        let limbs: u64 = b.limbs() as u64;
-        assert!(
-            c.limbs() as u64 >= limbs,
-            "invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
-            c.limbs(),
-            limbs
-        );
-        unsafe { svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
-    }
-    // b <- IDFT(a), uses a as scratch space.
-    pub fn vec_znx_idft_tmp_a(&self, b: &mut VECZNXBIG, a: &mut VECZNXDFT, a_limbs: usize) {
-        assert!(
-            b.limbs() >= a_limbs,
-            "invalid c_vector: b_vector.limbs()={} < a_limbs={}",
-            b.limbs(),
-            a_limbs
-        );
-        unsafe { vec_znx_idft_tmp_a(self.0, b.0, a_limbs as u64, a.0, a_limbs as u64) }
-    }
-    // Returns the size of the scratch space for [vec_znx_idft].
-    pub fn vec_znx_idft_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx_idft_tmp_bytes(self.0) as usize }
-    }
-    // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
-    pub fn vec_znx_idft(
-        &self,
-        b_vector: &mut VECZNXBIG,
-        a_vector: &mut VECZNXDFT,
-        a_limbs: usize,
-        tmp_bytes: &mut [u8],
-    ) {
-        assert!(
-            b_vector.limbs() >= a_limbs,
-            "invalid c_vector: b_vector.limbs()={} < a_limbs={}",
-            b_vector.limbs(),
-            a_limbs
-        );
-        assert!(
-            a_vector.limbs() >= a_limbs,
-            "invalid c_vector: c_vector.limbs()={} < a_limbs={}",
-            a_vector.limbs(),
-            a_limbs
-        );
-        assert!(
-            tmp_bytes.len() <= self.vec_znx_idft_tmp_bytes(),
-            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
-            tmp_bytes.len(),
-            self.vec_znx_idft_tmp_bytes()
-        );
-        unsafe {
-            vec_znx_idft(
-                self.0,
-                b_vector.0,
-                a_limbs as u64,
-                a_vector.0,
-                a_limbs as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
-    // c <- b - a
-    pub fn vec_znx_big_sub_small_a(&self, c: &mut VECZNXBIG, a: &Poly, b: &VECZNXBIG) {
-        let limbs: usize = a.limbs();
-        assert!(
-            b.limbs() >= limbs,
-            "invalid c: b.limbs()={} < a.limbs()={}",
-            b.limbs(),
-            limbs
-        );
-        assert!(
-            c.limbs() >= limbs,
-            "invalid c: c.limbs()={} < a.limbs()={}",
-            c.limbs(),
-            limbs
-        );
-        unsafe {
-            vec_znx_big_sub_small_a(
-                self.0,
-                c.0,
-                c.limbs() as u64,
-                a.as_ptr(),
-                limbs as u64,
-                a.n() as u64,
-                b.0,
-                b.limbs() as u64,
-            )
-        }
-    }
-    // b <- b - a
-    pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VECZNXBIG, a: &Poly) {
-        let limbs: usize = a.limbs();
-        assert!(
-            b.limbs() >= limbs,
-            "invalid c_vector: b.limbs()={} < a.limbs()={}",
-            b.limbs(),
-            limbs
-        );
-        unsafe {
-            vec_znx_big_sub_small_a(
-                self.0,
-                b.0,
-                b.limbs() as u64,
-                a.as_ptr(),
-                limbs as u64,
-                a.n() as u64,
-                b.0,
-                b.limbs() as u64,
-            )
-        }
-    }
-    // c <- b + a
-    pub fn vec_znx_big_add_small(&self, c: &mut VECZNXBIG, a: &Poly, b: &VECZNXBIG) {
-        let limbs: usize = a.limbs();
-        assert!(
-            b.limbs() >= limbs,
-            "invalid c: b.limbs()={} < a.limbs()={}",
-            b.limbs(),
-            limbs
-        );
-        assert!(
-            c.limbs() >= limbs,
-            "invalid c: c.limbs()={} < a.limbs()={}",
-            c.limbs(),
-            limbs
-        );
-        unsafe {
-            vec_znx_big_add_small(
-                self.0,
-                c.0,
-                limbs as u64,
-                b.0,
-                limbs as u64,
-                a.as_ptr(),
-                limbs as u64,
-                a.n() as u64,
-            )
-        }
-    }
-    // b <- b + a
-    pub fn vec_znx_big_add_small_inplace(&self, b: &mut VECZNXBIG, a: &Poly) {
-        let limbs: usize = a.limbs();
-        assert!(
-            b.limbs() >= limbs,
-            "invalid c_vector: b.limbs()={} < a.limbs()={}",
-            b.limbs(),
-            limbs
-        );
-        unsafe {
-            vec_znx_big_add_small(
-                self.0,
-                b.0,
-                limbs as u64,
-                b.0,
-                limbs as u64,
-                a.as_ptr(),
-                limbs as u64,
-                a.n() as u64,
-            )
-        }
-    }
-    pub fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
-        unsafe { vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize }
-    }
-    // b <- normalize(a)
-    pub fn vec_znx_big_normalize(&self, b: &mut Poly, a: &VECZNXBIG, tmp_bytes: &mut [u8]) {
-        let limbs: usize = b.limbs();
-        assert!(
-            b.limbs() >= limbs,
-            "invalid c_vector: b.limbs()={} < a.limbs()={}",
-            b.limbs(),
-            limbs
-        );
-        assert!(
-            tmp_bytes.len() <= self.vec_znx_big_normalize_tmp_bytes(),
-            "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
-            tmp_bytes.len(),
-            self.vec_znx_big_normalize_tmp_bytes()
-        );
-        unsafe {
-            vec_znx_big_normalize_base2k(
-                self.0,
-                b.log_base2k as u64,
-                b.as_mut_ptr(),
-                limbs as u64,
-                b.n() as u64,
-                a.0,
-                limbs as u64,
-                tmp_bytes.as_mut_ptr(),
-            )
-        }
-    }
 }
-pub struct SVPPOL(*mut svp_ppol_t);
-pub struct VECZNXDFT(*mut vec_znx_dft_t, usize);
-pub struct VECZNXBIG(*mut vec_znx_bigcoeff_t, usize);
+pub struct SVPPOL(pub *mut svp_ppol_t);
+pub struct VECZNXBIG(pub *mut vec_znx_bigcoeff_t, pub usize);
+// Stores a vector of
 impl VECZNXBIG {
     pub fn as_vec_znx_dft(&mut self) -> VECZNXDFT {
         VECZNXDFT(self.0 as *mut vec_znx_dft_t, self.1)
@@ -258,6 +35,8 @@ impl VECZNXBIG {
     }
 }
+pub struct VECZNXDFT(pub *mut vec_znx_dft_t, pub usize);
 impl VECZNXDFT {
     pub fn as_vec_znx_big(&mut self) -> VECZNXBIG {
         VECZNXBIG(self.0 as *mut vec_znx_bigcoeff_t, self.1)
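A short sketch of how the pointer casts on VECZNXBIG and VECZNXDFT can be used; that buffer reuse is their intended purpose is my reading of the casts, not something stated in the commit:

```rust
use spqlios::module::{Module, FFT64, VECZNXBIG, VECZNXDFT};

fn reuse_buffer(limbs: usize) {
    let module: Module = Module::new::<FFT64>(16);
    let mut buf_dft: VECZNXDFT = module.new_vec_znx_dft(limbs);
    // ... fill buf_dft in the DFT domain (e.g. via svp_apply_dft) ...
    // Reinterpret the same allocation as a big-coefficient vector so it can be
    // handed to the vec_znx_big_* routines without a second allocation.
    let _buf_big: VECZNXBIG = buf_dft.as_vec_znx_big();
}
```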

View File

@@ -0,0 +1,29 @@
use crate::bindings::{new_svp_ppol, svp_apply_dft, svp_prepare};
use crate::module::{Module, SVPPOL, VECZNXDFT};
use crate::scalar::Scalar;
use crate::vector::Vector;

impl Module {
    // Prepares a scalar polynomial (1 limb) for a scalar x vector product.
    // Method will panic if a.limbs() != 1.
    pub fn svp_prepare(&self, svp_ppol: &mut SVPPOL, a: &Scalar) {
        unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
    }

    // Allocates a scalar-vector-product prepared-poly (SVPPOL).
    pub fn svp_new_ppol(&self) -> SVPPOL {
        unsafe { SVPPOL(new_svp_ppol(self.0)) }
    }

    // Applies a scalar x vector product: res <- a (ppol) x b
    pub fn svp_apply_dft(&self, c: &mut VECZNXDFT, a: &SVPPOL, b: &Vector) {
        let limbs: u64 = b.limbs() as u64;
        assert!(
            c.limbs() as u64 >= limbs,
            "invalid c_vector: c_vector.limbs()={} < b.limbs()={}",
            c.limbs(),
            limbs
        );
        unsafe { svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
    }
}
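A self-contained usage sketch of the SVP methods above, mirroring the example program earlier in this commit; the concrete parameters (n = 16, base 2^40, 54-bit precision) are illustrative only:

```rust
use sampling::source::Source;
use spqlios::module::{Module, FFT64};
use spqlios::scalar::Scalar;
use spqlios::vector::Vector;

fn svp_demo() {
    let (n, log_base2k, prec) = (16usize, 40usize, 54usize);
    let module = Module::new::<FFT64>(n);
    let mut source = Source::new([0u8; 32]);
    // Scalar is a single-limb polynomial, as required by svp_prepare.
    let mut s = Scalar::new(n);
    s.fill_ternary_prob(0.5, &mut source);
    let mut s_ppol = module.svp_new_ppol();
    module.svp_prepare(&mut s_ppol, &s); // s_ppol <- DFT(s)
    let mut a = Vector::new(n, log_base2k, prec);
    a.fill_uniform(&mut source);
    // The output must have at least a.limbs() limbs or svp_apply_dft panics.
    let mut res = module.new_vec_znx_dft(a.limbs());
    module.svp_apply_dft(&mut res, &s_ppol, &a); // res <- DFT(s) x DFT(a)
}
```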

View File

@@ -0,0 +1,35 @@
use crate::bindings::vec_znx_automorphism;
use crate::module::Module;
use crate::vector::Vector;

impl Module {
    pub fn vec_znx_automorphism(&self, gal_el: i64, b: &mut Vector, a: &Vector) {
        unsafe {
            vec_znx_automorphism(
                self.0,
                gal_el,
                b.as_mut_ptr(),
                b.limbs() as u64,
                b.n() as u64,
                a.as_ptr(),
                a.limbs() as u64,
                a.n() as u64,
            );
        }
    }

    pub fn vec_znx_automorphism_inplace(&self, gal_el: i64, a: &mut Vector) {
        unsafe {
            vec_znx_automorphism(
                self.0,
                gal_el,
                a.as_mut_ptr(),
                a.limbs() as u64,
                a.n() as u64,
                a.as_ptr(),
                a.limbs() as u64,
                a.n() as u64,
            );
        }
    }
}
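A minimal usage sketch of the automorphism wrappers above; the Galois element 5 and the ring parameters are illustrative, and the in-place variant simply reuses the same Vector for input and output:

```rust
use sampling::source::Source;
use spqlios::module::{Module, FFT64};
use spqlios::vector::Vector;

fn automorphism_demo() {
    let (n, log_base2k, prec) = (16usize, 40usize, 54usize);
    let module = Module::new::<FFT64>(n);
    let mut source = Source::new([0u8; 32]);
    let mut a = Vector::new(n, log_base2k, prec);
    a.fill_uniform(&mut source);
    // Out-of-place: write the image of `a` under the Galois element 5 into `b`.
    let mut b = Vector::new(n, log_base2k, prec);
    module.vec_znx_automorphism(5, &mut b, &a);
    // In-place variant, applied to `b` itself.
    module.vec_znx_automorphism_inplace(5, &mut b);
}
```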

View File

@@ -0,0 +1,162 @@
use crate::bindings::{
    new_vec_znx_big, vec_znx_big_add_small, vec_znx_big_automorphism, vec_znx_big_normalize_base2k,
    vec_znx_big_normalize_base2k_tmp_bytes, vec_znx_big_sub_small_a,
};
use crate::module::{Module, VECZNXBIG};
use crate::vector::Vector;

impl Module {
    // Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
    pub fn new_vec_znx_big(&self, limbs: usize) -> VECZNXBIG {
        unsafe { VECZNXBIG(new_vec_znx_big(self.0, limbs as u64), limbs) }
    }

    // b <- b - a
    pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VECZNXBIG, a: &Vector) {
        let limbs: usize = a.limbs();
        assert!(
            b.limbs() >= limbs,
            "invalid c_vector: b.limbs()={} < a.limbs()={}",
            b.limbs(),
            limbs
        );
        unsafe {
            vec_znx_big_sub_small_a(
                self.0,
                b.0,
                b.limbs() as u64,
                a.as_ptr(),
                limbs as u64,
                a.n() as u64,
                b.0,
                b.limbs() as u64,
            )
        }
    }

    // c <- b - a
    pub fn big_sub_small_a(&self, c: &mut VECZNXBIG, a: &Vector, b: &VECZNXBIG) {
        let limbs: usize = a.limbs();
        assert!(
            b.limbs() >= limbs,
            "invalid c: b.limbs()={} < a.limbs()={}",
            b.limbs(),
            limbs
        );
        assert!(
            c.limbs() >= limbs,
            "invalid c: c.limbs()={} < a.limbs()={}",
            c.limbs(),
            limbs
        );
        unsafe {
            vec_znx_big_sub_small_a(
                self.0,
                c.0,
                c.limbs() as u64,
                a.as_ptr(),
                limbs as u64,
                a.n() as u64,
                b.0,
                b.limbs() as u64,
            )
        }
    }

    // c <- b + a
    pub fn vec_znx_big_add_small(&self, c: &mut VECZNXBIG, a: &Vector, b: &VECZNXBIG) {
        let limbs: usize = a.limbs();
        assert!(
            b.limbs() >= limbs,
            "invalid c: b.limbs()={} < a.limbs()={}",
            b.limbs(),
            limbs
        );
        assert!(
            c.limbs() >= limbs,
            "invalid c: c.limbs()={} < a.limbs()={}",
            c.limbs(),
            limbs
        );
        unsafe {
            vec_znx_big_add_small(
                self.0,
                c.0,
                limbs as u64,
                b.0,
                limbs as u64,
                a.as_ptr(),
                limbs as u64,
                a.n() as u64,
            )
        }
    }

    // b <- b + a
    pub fn vec_znx_big_add_small_inplace(&self, b: &mut VECZNXBIG, a: &Vector) {
        let limbs: usize = a.limbs();
        assert!(
            b.limbs() >= limbs,
            "invalid c_vector: b.limbs()={} < a.limbs()={}",
            b.limbs(),
            limbs
        );
        unsafe {
            vec_znx_big_add_small(
                self.0,
                b.0,
                limbs as u64,
                b.0,
                limbs as u64,
                a.as_ptr(),
                limbs as u64,
                a.n() as u64,
            )
        }
    }

    pub fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
        unsafe { vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize }
    }

    // b <- normalize(a)
    pub fn vec_znx_big_normalize(&self, b: &mut Vector, a: &VECZNXBIG, tmp_bytes: &mut [u8]) {
        let limbs: usize = b.limbs();
        assert!(
            b.limbs() >= limbs,
            "invalid c_vector: b.limbs()={} < a.limbs()={}",
            b.limbs(),
            limbs
        );
        assert!(
            tmp_bytes.len() <= self.vec_znx_big_normalize_tmp_bytes(),
            "invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
            tmp_bytes.len(),
            self.vec_znx_big_normalize_tmp_bytes()
        );
        unsafe {
            vec_znx_big_normalize_base2k(
                self.0,
                b.log_base2k as u64,
                b.as_mut_ptr(),
                limbs as u64,
                b.n() as u64,
                a.0,
                limbs as u64,
                tmp_bytes.as_mut_ptr(),
            )
        }
    }

    pub fn big_automorphism(&self, gal_el: i64, b: &mut VECZNXBIG, a: &VECZNXBIG) {
        unsafe {
            vec_znx_big_automorphism(self.0, gal_el, b.0, b.limbs() as u64, a.0, a.limbs() as u64);
        }
    }

    pub fn big_automorphism_inplace(&self, gal_el: i64, a: &mut VECZNXBIG) {
        unsafe {
            vec_znx_big_automorphism(self.0, gal_el, a.0, a.limbs() as u64, a.0, a.limbs() as u64);
        }
    }
}
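A short sketch of the accumulate-then-normalize pattern these methods support, using only calls defined in this commit; that new_vec_znx_big returns a zero-initialized buffer is an assumption on my part:

```rust
use sampling::source::Source;
use spqlios::module::{Module, FFT64};
use spqlios::vector::Vector;

fn big_accumulate_demo() {
    let (n, log_base2k, prec) = (16usize, 40usize, 54usize);
    let module = Module::new::<FFT64>(n);
    let mut source = Source::new([0u8; 32]);
    let mut m = Vector::new(n, log_base2k, prec);
    m.fill_uniform(&mut source);
    // Big-coefficient accumulator with as many limbs as the operand.
    let mut acc = module.new_vec_znx_big(m.limbs());
    module.vec_znx_big_add_small_inplace(&mut acc, &m); // acc <- acc + m
    // Fold the accumulator back into normalized base-2^log_base2k limbs.
    let mut out = Vector::new(n, log_base2k, prec);
    let mut tmp = vec![0u8; module.vec_znx_big_normalize_tmp_bytes()];
    module.vec_znx_big_normalize(&mut out, &acc, &mut tmp);
}
```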

View File

@@ -0,0 +1,63 @@
use crate::bindings::{new_vec_znx_dft, vec_znx_idft, vec_znx_idft_tmp_a, vec_znx_idft_tmp_bytes};
use crate::module::{Module, VECZNXBIG, VECZNXDFT};

impl Module {
    // Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
    pub fn new_vec_znx_dft(&self, limbs: usize) -> VECZNXDFT {
        unsafe { VECZNXDFT(new_vec_znx_dft(self.0, limbs as u64), limbs) }
    }

    // b <- IDFT(a), uses a as scratch space.
    pub fn vec_znx_idft_tmp_a(&self, b: &mut VECZNXBIG, a: &mut VECZNXDFT, a_limbs: usize) {
        assert!(
            b.limbs() >= a_limbs,
            "invalid c_vector: b_vector.limbs()={} < a_limbs={}",
            b.limbs(),
            a_limbs
        );
        unsafe { vec_znx_idft_tmp_a(self.0, b.0, a_limbs as u64, a.0, a_limbs as u64) }
    }

    // Returns the size of the scratch space for [vec_znx_idft].
    pub fn vec_znx_idft_tmp_bytes(&self) -> usize {
        unsafe { vec_znx_idft_tmp_bytes(self.0) as usize }
    }

    // b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
    pub fn vec_znx_idft(
        &self,
        b_vector: &mut VECZNXBIG,
        a_vector: &mut VECZNXDFT,
        a_limbs: usize,
        tmp_bytes: &mut [u8],
    ) {
        assert!(
            b_vector.limbs() >= a_limbs,
            "invalid c_vector: b_vector.limbs()={} < a_limbs={}",
            b_vector.limbs(),
            a_limbs
        );
        assert!(
            a_vector.limbs() >= a_limbs,
            "invalid c_vector: c_vector.limbs()={} < a_limbs={}",
            a_vector.limbs(),
            a_limbs
        );
        assert!(
            tmp_bytes.len() <= self.vec_znx_idft_tmp_bytes(),
            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
            tmp_bytes.len(),
            self.vec_znx_idft_tmp_bytes()
        );
        unsafe {
            vec_znx_idft(
                self.0,
                b_vector.0,
                a_limbs as u64,
                a_vector.0,
                a_limbs as u64,
                tmp_bytes.as_mut_ptr(),
            )
        }
    }
}
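A minimal sketch of the explicit-scratch IDFT path above; filling of the DFT vector is elided, and the buffer size follows the assertions in vec_znx_idft:

```rust
use spqlios::module::{Module, FFT64};

fn idft_demo() {
    let module = Module::new::<FFT64>(16);
    let limbs = 2usize;
    let mut a_dft = module.new_vec_znx_dft(limbs);
    // ... fill a_dft, e.g. with svp_apply_dft ...
    let mut b_big = module.new_vec_znx_big(limbs);
    // Scratch buffer sized exactly as reported by the module.
    let mut tmp = vec![0u8; module.vec_znx_idft_tmp_bytes()];
    module.vec_znx_idft(&mut b_big, &mut a_dft, limbs, &mut tmp);
}
```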

View File

@@ -1,20 +1,20 @@
-use crate::{
-    cast_mut_u8_to_mut_i64_slice, znx_automorphism_i64, znx_automorphism_inplace_i64,
-    znx_normalize, znx_zero_i64_ref,
-};
+use crate::bindings::{
+    znx_automorphism_i64, znx_automorphism_inplace_i64, znx_normalize, znx_zero_i64_ref,
+};
+use crate::cast_mut_u8_to_mut_i64_slice;
 use itertools::izip;
 use rand_distr::{Distribution, Normal};
 use sampling::source::Source;
 use std::cmp::min;
-pub struct Poly {
+pub struct Vector {
     pub n: usize,
     pub log_base2k: usize,
     pub prec: usize,
     pub data: Vec<i64>,
 }
-impl Poly {
+impl Vector {
     pub fn new(n: usize, log_base2k: usize, prec: usize) -> Self {
         Self {
             n: n,
@@ -173,7 +173,7 @@ impl Poly {
             })
         }
     }
-    pub fn automorphism(&mut self, gal_el: i64, a: &mut Poly) {
+    pub fn automorphism(&mut self, gal_el: i64, a: &mut Vector) {
         unsafe {
             (0..self.limbs()).for_each(|i| {
                 znx_automorphism_i64(self.n as u64, gal_el, a.at_mut_ptr(i), self.at_ptr(i))
@@ -260,7 +260,7 @@ impl Poly {
 #[cfg(test)]
 mod tests {
-    use crate::poly::Poly;
+    use crate::vector::Vector;
     use itertools::izip;
     use sampling::source::Source;
@@ -269,7 +269,7 @@ mod tests {
         let n: usize = 32;
         let k: usize = 19;
         let prec: usize = 128;
-        let mut a: Poly = Poly::new(n, k, prec);
+        let mut a: Vector = Vector::new(n, k, prec);
         let mut have: Vec<i64> = vec![i64::default(); n];
         have.iter_mut()
             .enumerate()
@@ -285,7 +285,7 @@ mod tests {
         let n: usize = 8;
         let k: usize = 17;
         let prec: usize = 84;
-        let mut a: Poly = Poly::new(n, k, prec);
+        let mut a: Vector = Vector::new(n, k, prec);
         let mut have: Vec<i64> = vec![i64::default(); n];
         let mut source = Source::new([1; 32]);
         have.iter_mut().for_each(|x| {
@@ -305,7 +305,7 @@ mod tests {
         let n: usize = 8;
         let k: usize = 17;
         let prec: usize = 84;
-        let mut a: Poly = Poly::new(n, k, prec);
+        let mut a: Vector = Vector::new(n, k, prec);
         let mut have: Vec<i64> = vec![i64::default(); n];
         let mut source = Source::new([1; 32]);
         have.iter_mut().for_each(|x| {