diff --git a/Cargo.lock b/Cargo.lock
index 659ed82..f4fc43f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -803,7 +803,11 @@
 name = "spqlios"
 version = "0.1.0"
 dependencies = [
  "bindgen",
+ "criterion",
  "itertools 0.14.0",
+ "rand",
+ "rand_core",
+ "rand_distr",
  "sampling",
 ]
diff --git a/spqlios/Cargo.toml b/spqlios/Cargo.toml
index 16048d1..f355948 100644
--- a/spqlios/Cargo.toml
+++ b/spqlios/Cargo.toml
@@ -4,8 +4,16 @@
 version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+rand = "0.8.5"
+rand_core = "0.6.4"
 itertools = "0.14.0"
+criterion = "0.5.1"
+rand_distr = "0.4.3"
 sampling = { path = "../sampling" }
 
 [build-dependencies]
 bindgen = "0.71.1"
+
+[[bench]]
+name = "fft"
+harness = false
\ No newline at end of file
diff --git a/spqlios/benches/fft.rs b/spqlios/benches/fft.rs
new file mode 100644
index 0000000..51674ea
--- /dev/null
+++ b/spqlios/benches/fft.rs
@@ -0,0 +1,69 @@
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use spqlios::bindings::*;
+use std::ffi::c_void;
+
+fn fft(c: &mut Criterion) {
+    // Both runners perform the base2k -> f64 conversion once, outside the
+    // timed closure, so that only the (i)FFT itself is measured.
+    fn forward<'a>(
+        m: u32,
+        log_bound: u32,
+        reim_fft_precomp: *mut reim_fft_precomp,
+        a: &'a [i64],
+    ) -> Box<dyn FnMut() + 'a> {
+        unsafe {
+            let buf_a: *mut f64 = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
+            reim_from_znx64_simple(m, log_bound, buf_a as *mut c_void, a.as_ptr());
+            Box::new(move || reim_fft(reim_fft_precomp, buf_a))
+        }
+    }
+
+    fn backward<'a>(
+        m: u32,
+        log_bound: u32,
+        reim_ifft_precomp: *mut reim_ifft_precomp,
+        a: &'a [i64],
+    ) -> Box<dyn FnMut() + 'a> {
+        unsafe {
+            let buf_a: *mut f64 = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
+            reim_from_znx64_simple(m, log_bound, buf_a as *mut c_void, a.as_ptr());
+            Box::new(move || reim_ifft(reim_ifft_precomp, buf_a))
+        }
+    }
+
+    let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =
+        c.benchmark_group("fft");
+
+    for log_n in 10..17 {
+        let n: usize = 1 << log_n;
+        let m: usize = n >> 1;
+        let log_bound: u32 = 19;
+
+        let mut a: Vec<i64> = vec![i64::default(); n];
+        a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
+
+        unsafe {
+            let reim_fft_precomp: *mut reim_fft_precomp = new_reim_fft_precomp(m as u32, 1);
+            let reim_ifft_precomp: *mut reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
+
+            let runners: [(String, Box<dyn FnMut() + '_>); 2] = [
+                ("forward".to_string(), forward(m as u32, log_bound, reim_fft_precomp, &a)),
+                ("backward".to_string(), backward(m as u32, log_bound, reim_ifft_precomp, &a)),
+            ];
+
+            for (name, mut runner) in runners {
+                let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", n));
+                b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
+                    b.iter(&mut runner)
+                });
+            }
+        }
+    }
+}
+
+criterion_group!(benches, fft);
+criterion_main!(benches);
diff --git a/spqlios/examples/rlwe_encrypt.rs b/spqlios/examples/rlwe_encrypt.rs
new file mode 100644
index 0000000..2805668
--- /dev/null
+++ b/spqlios/examples/rlwe_encrypt.rs
@@ -0,0 +1,89 @@
+use itertools::izip;
+use sampling::source::Source;
+use spqlios::module::{Module, FFT64, SVPPOL, VECZNXBIG, VECZNXDFT};
+use spqlios::poly::Poly;
+use spqlios::scalar::Scalar;
+
+fn main() {
+    let n: usize = 16;
+    let log_base2k: usize = 15;
+    let prec: usize = 54;
+    let log_scale: usize = 18;
+    let module: Module = Module::new::<FFT64>(n);
+
+    let mut carry: Vec<u8> = vec![0; module.vec_znx_big_normalize_tmp_bytes()];
+
+    let seed: [u8; 32] = [0; 32];
+    let mut source: Source = Source::new(seed);
+
+    let mut res: Poly = Poly::new(n, log_base2k, prec);
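+
+    // What follows is a textbook RLWE round trip: sample a ternary secret s
+    // and a uniform a, publish (a, b) with b = m - a*s + e, then decrypt by
+    // recomputing b + a*s = m + e.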
+
+    // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
+    let mut s: Scalar = Scalar::new(n);
+    s.fill_ternary_prob(0.5, &mut source);
+
+    // Buffer storing s in the DFT domain
+    let mut s_ppol: SVPPOL = module.new_svp_ppol();
+
+    // s_ppol <- DFT(s)
+    module.svp_prepare(&mut s_ppol, &s);
+
+    // a <- Z_{2^prec}[X]/(X^{N}+1)
+    let mut a: Poly = Poly::new(n, log_base2k, prec);
+    a.fill_uniform(&mut source);
+
+    // Scratch space for DFT values
+    let mut buf_dft: VECZNXDFT = module.new_vec_znx_dft(a.limbs());
+
+    // buf_dft <- DFT(s) * DFT(a)
+    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
+
+    // Aliases the DFT scratch space as big-coefficient scratch space
+    let mut buf_big: VECZNXBIG = buf_dft.as_vec_znx_big();
+
+    // buf_big <- IDFT(buf_dft) (not normalized)
+    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.limbs());
+
+    // want <- n uniform 4-bit messages
+    let mut m: Poly = Poly::new(n, log_base2k, prec - log_scale);
+    let mut want: Vec<i64> = vec![0; n];
+    want.iter_mut()
+        .for_each(|x| *x = source.next_u64n(16, 15) as i64);
+
+    // m <- want
+    m.set_i64(&want, 4);
+    m.normalize(&mut carry);
+
+    // buf_big <- m - buf_big
+    module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
+
+    // b <- normalize(buf_big) + e
+    let mut b: Poly = Poly::new(n, log_base2k, prec);
+    module.vec_znx_big_normalize(&mut b, &buf_big, &mut carry);
+    b.add_normal(&mut source, 3.2, 19.0);
+
+    // Decrypt
+
+    // buf_big <- a * s
+    module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
+    module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, b.limbs());
+
+    // buf_big <- a * s + b
+    module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
+
+    // res <- normalize(buf_big)
+    module.vec_znx_big_normalize(&mut res, &buf_big, &mut carry);
+
+    // have = m * 2^{log_scale} + e
+    let mut have: Vec<i64> = vec![i64::default(); n];
+    res.get_i64(&mut have);
+
+    let scale: f64 = (1 << log_scale) as f64;
+    izip!(want.iter(), have.iter())
+        .enumerate()
+        .for_each(|(i, (a, b))| {
+            println!("{}: {} {}", i, a, (*b as f64) / scale);
+        })
+}
diff --git a/spqlios/src/lib.rs b/spqlios/src/lib.rs
index 945716e..b5e8d53 100644
--- a/spqlios/src/lib.rs
+++ b/spqlios/src/lib.rs
@@ -1,5 +1,6 @@
 pub mod module;
 pub mod poly;
+pub mod scalar;
 
 #[allow(
     non_camel_case_types,
@@ -13,3 +14,15 @@ pub mod bindings {
 }
 
 pub use bindings::*;
+
+// Reinterprets a &mut [u64] as raw bytes; always valid, since u8 has
+// alignment 1.
+fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
+    let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
+    let len: usize = data.len() * std::mem::size_of::<u64>();
+    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
+}
+
+// Reinterprets raw bytes as a &mut [i64]; the buffer must be 8-byte aligned,
+// which is checked in debug builds.
+fn cast_mut_u8_to_mut_i64_slice(data: &mut [u8]) -> &mut [i64] {
+    let ptr: *mut i64 = data.as_mut_ptr() as *mut i64;
+    debug_assert!(ptr as usize % std::mem::align_of::<i64>() == 0);
+    let len: usize = data.len() / std::mem::size_of::<i64>();
+    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
+}
diff --git a/spqlios/src/mod.rs b/spqlios/src/mod.rs
deleted file mode 100644
index a94a210..0000000
--- a/spqlios/src/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod module;
\ No newline at end of file
diff --git a/spqlios/src/module.rs b/spqlios/src/module.rs
index 84b6136..59ba1fb 100644
--- a/spqlios/src/module.rs
+++ b/spqlios/src/module.rs
@@ -1,91 +1,268 @@
 use crate::bindings::*;
+use crate::poly::Poly;
+use crate::scalar::Scalar;
 
-pub fn create_module(N: u64, mtype: module_type_t) -> *mut MODULE {
-    unsafe {
-        let m = new_module_info(N, mtype);
-        if m.is_null() {
-            println!("Failed to create module.");
-        }
-        m
-    }
-}
-
-#[test]
-fn test_new_module_info() {
-    let N: u64 = 1024;
-    let module_ptr: *mut module_info_t = create_module(N, module_type_t_FFT64);
-    assert!(!module_ptr.is_null());
-    println!("{:?}", module_ptr);
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use std::ffi::c_void;
-    use std::time::Instant;
-    //use test::Bencher;
-
-    #[test]
-    fn test_fft() {
-        let log_bound: usize = 19;
-
-        let n: usize = 2048;
-        let m: usize = n >> 1;
-
-        let mut a: Vec<i64> = vec![i64::default(); n];
-        let mut b: Vec<i64> = vec![i64::default(); n];
-        let mut c: Vec<i64> = vec![i64::default(); n];
-
-        a.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
-        b[1] = 1;
-
-        println!("{:?}", b);
-
-        unsafe {
-            let reim_fft_precomp = new_reim_fft_precomp(m as u32, 2);
-            let reim_ifft_precomp = new_reim_ifft_precomp(m as u32, 1);
-
-            let buf_a = reim_fft_precomp_get_buffer(reim_fft_precomp, 0);
-            let buf_b = reim_fft_precomp_get_buffer(reim_fft_precomp, 1);
-            let buf_c = reim_ifft_precomp_get_buffer(reim_ifft_precomp, 0);
-
-            let now = Instant::now();
-            (0..1024).for_each(|i| {
-                reim_from_znx64_simple(
-                    m as u32,
-                    log_bound as u32,
-                    buf_a as *mut c_void,
-                    a.as_ptr(),
-                );
-                reim_fft(reim_fft_precomp, buf_a);
-
-                reim_from_znx64_simple(
-                    m as u32,
-                    log_bound as u32,
-                    buf_b as *mut c_void,
-                    b.as_ptr(),
-                );
-                reim_fft(reim_fft_precomp, buf_b);
-
-                reim_fftvec_mul_simple(
-                    m as u32,
-                    buf_c as *mut c_void,
-                    buf_a as *mut c_void,
-                    buf_b as *mut c_void,
-                );
-                reim_ifft(reim_ifft_precomp, buf_c);
-
-                reim_to_znx64_simple(
-                    m as u32,
-                    m as f64,
-                    log_bound as u32,
-                    c.as_mut_ptr(),
-                    buf_c as *mut c_void,
-                )
-            });
-
-            println!("time: {}us", now.elapsed().as_micros());
-            println!("{:?}", &c[..16]);
-        }
-    }
-}
+pub type MODULETYPE = u8;
+pub const FFT64: MODULETYPE = 0;
+pub const NTT120: MODULETYPE = 1;
+
+pub struct Module(*mut MODULE);
+
+impl Module {
+    // Instantiates a new module over Z[X]/(X^n+1) with the given backend.
+    pub fn new<const MTYPE: MODULETYPE>(n: usize) -> Self {
+        unsafe {
+            let m: *mut module_info_t = new_module_info(n as u64, MTYPE as u32);
+            if m.is_null() {
+                panic!("Failed to create module.");
+            }
+            Self(m)
+        }
+    }
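+
+    // NOTE: the handle returned by new_module_info is currently never freed.
+    // Assuming the bindings expose spqlios' delete_module_info, a Drop impl
+    // could release it:
+    //
+    // impl Drop for Module {
+    //     fn drop(&mut self) {
+    //         unsafe { delete_module_info(self.0) }
+    //     }
+    // }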
+
+    // Prepares a scalar polynomial (a Scalar is always a single limb) for a
+    // scalar x vector product.
+    pub fn svp_prepare(&self, svp_ppol: &mut SVPPOL, a: &Scalar) {
+        unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
+    }
+
+    // Allocates a scalar-vector-product prepared-poly (SVPPOL).
+    pub fn new_svp_ppol(&self) -> SVPPOL {
+        unsafe { SVPPOL(new_svp_ppol(self.0)) }
+    }
+
+    // Allocates a vector over Z[X]/(X^N+1) that stores normalized values in
+    // the DFT domain.
+    pub fn new_vec_znx_dft(&self, limbs: usize) -> VECZNXDFT {
+        unsafe { VECZNXDFT(new_vec_znx_dft(self.0, limbs as u64), limbs) }
+    }
+
+    // Allocates a vector over Z[X]/(X^N+1) that stores non-normalized values.
+    pub fn new_vec_znx_big(&self, limbs: usize) -> VECZNXBIG {
+        unsafe { VECZNXBIG(new_vec_znx_big(self.0, limbs as u64), limbs) }
+    }
+
+    // Applies a scalar x vector product: c <- a (ppol) x b.
+    pub fn svp_apply_dft(&self, c: &mut VECZNXDFT, a: &SVPPOL, b: &Poly) {
+        let limbs: u64 = b.limbs() as u64;
+        assert!(
+            c.limbs() as u64 >= limbs,
+            "invalid c: c.limbs()={} < b.limbs()={}",
+            c.limbs(),
+            limbs
+        );
+        unsafe { svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
+    }
+
+    // b <- IDFT(a), using a as scratch space.
+    pub fn vec_znx_idft_tmp_a(&self, b: &mut VECZNXBIG, a: &mut VECZNXDFT, a_limbs: usize) {
+        assert!(
+            b.limbs() >= a_limbs,
+            "invalid b: b.limbs()={} < a_limbs={}",
+            b.limbs(),
+            a_limbs
+        );
+        unsafe { vec_znx_idft_tmp_a(self.0, b.0, a_limbs as u64, a.0, a_limbs as u64) }
+    }
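+
+    // The methods below that take a tmp_bytes argument expect a caller-provided
+    // scratch buffer sized by the matching *_tmp_bytes helper, e.g.
+    // (hypothetical bindings):
+    //
+    // let mut tmp: Vec<u8> = vec![0; module.vec_znx_idft_tmp_bytes()];
+    // module.vec_znx_idft(&mut big, &mut dft, limbs, &mut tmp);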
+
+    // Returns the size of the scratch space required by [vec_znx_idft].
+    pub fn vec_znx_idft_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx_idft_tmp_bytes(self.0) as usize }
+    }
+
+    // b <- IDFT(a); the scratch space size is obtained with [vec_znx_idft_tmp_bytes].
+    pub fn vec_znx_idft(
+        &self,
+        b_vector: &mut VECZNXBIG,
+        a_vector: &mut VECZNXDFT,
+        a_limbs: usize,
+        tmp_bytes: &mut [u8],
+    ) {
+        assert!(
+            b_vector.limbs() >= a_limbs,
+            "invalid b_vector: b_vector.limbs()={} < a_limbs={}",
+            b_vector.limbs(),
+            a_limbs
+        );
+        assert!(
+            a_vector.limbs() >= a_limbs,
+            "invalid a_vector: a_vector.limbs()={} < a_limbs={}",
+            a_vector.limbs(),
+            a_limbs
+        );
+        assert!(
+            tmp_bytes.len() >= self.vec_znx_idft_tmp_bytes(),
+            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
+            tmp_bytes.len(),
+            self.vec_znx_idft_tmp_bytes()
+        );
+        unsafe {
+            vec_znx_idft(
+                self.0,
+                b_vector.0,
+                a_limbs as u64,
+                a_vector.0,
+                a_limbs as u64,
+                tmp_bytes.as_mut_ptr(),
+            )
+        }
+    }
+
+    // c <- b - a
+    pub fn vec_znx_big_sub_small_a(&self, c: &mut VECZNXBIG, a: &Poly, b: &VECZNXBIG) {
+        let limbs: usize = a.limbs();
+        assert!(
+            b.limbs() >= limbs,
+            "invalid b: b.limbs()={} < a.limbs()={}",
+            b.limbs(),
+            limbs
+        );
+        assert!(
+            c.limbs() >= limbs,
+            "invalid c: c.limbs()={} < a.limbs()={}",
+            c.limbs(),
+            limbs
+        );
+        unsafe {
+            vec_znx_big_sub_small_a(
+                self.0,
+                c.0,
+                c.limbs() as u64,
+                a.as_ptr(),
+                limbs as u64,
+                a.n() as u64,
+                b.0,
+                b.limbs() as u64,
+            )
+        }
+    }
+
+    // b <- b - a
+    pub fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VECZNXBIG, a: &Poly) {
+        let limbs: usize = a.limbs();
+        assert!(
+            b.limbs() >= limbs,
+            "invalid b: b.limbs()={} < a.limbs()={}",
+            b.limbs(),
+            limbs
+        );
+        unsafe {
+            vec_znx_big_sub_small_a(
+                self.0,
+                b.0,
+                b.limbs() as u64,
+                a.as_ptr(),
+                limbs as u64,
+                a.n() as u64,
+                b.0,
+                b.limbs() as u64,
+            )
+        }
+    }
+
+    // c <- b + a
+    pub fn vec_znx_big_add_small(&self, c: &mut VECZNXBIG, a: &Poly, b: &VECZNXBIG) {
+        let limbs: usize = a.limbs();
+        assert!(
+            b.limbs() >= limbs,
+            "invalid b: b.limbs()={} < a.limbs()={}",
+            b.limbs(),
+            limbs
+        );
+        assert!(
+            c.limbs() >= limbs,
+            "invalid c: c.limbs()={} < a.limbs()={}",
+            c.limbs(),
+            limbs
+        );
+        unsafe {
+            vec_znx_big_add_small(
+                self.0,
+                c.0,
+                limbs as u64,
+                b.0,
+                limbs as u64,
+                a.as_ptr(),
+                limbs as u64,
+                a.n() as u64,
+            )
+        }
+    }
+
+    // b <- b + a
+    pub fn vec_znx_big_add_small_inplace(&self, b: &mut VECZNXBIG, a: &Poly) {
+        let limbs: usize = a.limbs();
+        assert!(
+            b.limbs() >= limbs,
+            "invalid b: b.limbs()={} < a.limbs()={}",
+            b.limbs(),
+            limbs
+        );
+        unsafe {
+            vec_znx_big_add_small(
+                self.0,
+                b.0,
+                limbs as u64,
+                b.0,
+                limbs as u64,
+                a.as_ptr(),
+                limbs as u64,
+                a.n() as u64,
+            )
+        }
+    }
+
+    // Returns the size of the scratch space required by [vec_znx_big_normalize].
+    pub fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
+        unsafe { vec_znx_big_normalize_base2k_tmp_bytes(self.0) as usize }
+    }
+
+    // b <- normalize(a)
+    pub fn vec_znx_big_normalize(&self, b: &mut Poly, a: &VECZNXBIG, tmp_bytes: &mut [u8]) {
+        let limbs: usize = b.limbs();
+        assert!(
+            a.limbs() >= limbs,
+            "invalid a: a.limbs()={} < b.limbs()={}",
+            a.limbs(),
+            limbs
+        );
+        assert!(
+            tmp_bytes.len() >= self.vec_znx_big_normalize_tmp_bytes(),
+            "invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_big_normalize_tmp_bytes()={}",
+            tmp_bytes.len(),
+            self.vec_znx_big_normalize_tmp_bytes()
+        );
+        unsafe {
+            vec_znx_big_normalize_base2k(
+                self.0,
+                b.log_base2k as u64,
+                b.as_mut_ptr(),
+                limbs as u64,
+                b.n() as u64,
+                a.0,
+                limbs as u64,
+                tmp_bytes.as_mut_ptr(),
+            )
+        }
+    }
+}
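+
+// Thin wrappers around buffers allocated by spqlios; the usize field caches
+// the limb count. as_vec_znx_big/as_vec_znx_dft reinterpret the same
+// allocation, which is what lets examples/rlwe_encrypt.rs reuse its DFT
+// buffer as big-coefficient scratch space.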
+pub struct SVPPOL(*mut svp_ppol_t);
+pub struct VECZNXDFT(*mut vec_znx_dft_t, usize);
+pub struct VECZNXBIG(*mut vec_znx_bigcoeff_t, usize);
+
+impl VECZNXBIG {
+    pub fn as_vec_znx_dft(&mut self) -> VECZNXDFT {
+        VECZNXDFT(self.0 as *mut vec_znx_dft_t, self.1)
+    }
+
+    pub fn limbs(&self) -> usize {
+        self.1
+    }
+}
+
+impl VECZNXDFT {
+    pub fn as_vec_znx_big(&mut self) -> VECZNXBIG {
+        VECZNXBIG(self.0 as *mut vec_znx_bigcoeff_t, self.1)
+    }
+
+    pub fn limbs(&self) -> usize {
+        self.1
+    }
+}
diff --git a/spqlios/src/poly.rs b/spqlios/src/poly.rs
index a7c940f..9da1f89 100644
--- a/spqlios/src/poly.rs
+++ b/spqlios/src/poly.rs
@@ -1,46 +1,51 @@
-use crate::{znx_normalize, znx_zero_i64_ref};
+use crate::{
+    cast_mut_u8_to_mut_i64_slice, znx_automorphism_i64, znx_automorphism_inplace_i64,
+    znx_normalize, znx_zero_i64_ref,
+};
 use itertools::izip;
+use rand_distr::{Distribution, Normal};
+use sampling::source::Source;
 use std::cmp::min;
 
 pub struct Poly {
     pub n: usize,
-    pub k: usize,
+    pub log_base2k: usize,
     pub prec: usize,
     pub data: Vec<i64>,
 }
 
 impl Poly {
-    pub fn new(n: usize, k: usize, prec: usize) -> Self {
+    pub fn new(n: usize, log_base2k: usize, prec: usize) -> Self {
         Self {
             n: n,
-            k: k,
+            log_base2k: log_base2k,
             prec: prec,
-            data: vec![i64::default(); Self::buffer_size(n, k, prec)],
+            data: vec![i64::default(); Self::buffer_size(n, log_base2k, prec)],
         }
     }
 
-    pub fn buffer_size(n: usize, k: usize, prec: usize) -> usize {
-        n * ((prec + k - 1) / k)
+    pub fn buffer_size(n: usize, log_base2k: usize, prec: usize) -> usize {
+        n * ((prec + log_base2k - 1) / log_base2k)
    }
 
-    pub fn from_buffer(&mut self, n: usize, k: usize, prec: usize, buf: &[i64]) {
-        let size = Self::buffer_size(n, k, prec);
+    pub fn from_buffer(&mut self, n: usize, log_base2k: usize, prec: usize, buf: &[i64]) {
+        let size = Self::buffer_size(n, log_base2k, prec);
         assert!(
             buf.len() >= size,
-            "invalid buffer: buf.len()={} < self.buffer_size(n={}, k={}, prec={})={}",
+            "invalid buffer: buf.len()={} < Self::buffer_size(n={}, log_base2k={}, prec={})={}",
             buf.len(),
             n,
-            k,
+            log_base2k,
             prec,
             size
         );
         self.n = n;
-        self.k = k;
+        self.log_base2k = log_base2k;
         self.prec = prec;
         self.data = Vec::from(&buf[..size])
     }
 
-    pub fn log_n(&self) -> usize {
+    pub fn log_n(&self) -> u64 {
         (u64::BITS - (self.n - 1).leading_zeros()) as _
     }
@@ -48,10 +53,22 @@ impl Poly {
         self.n
     }
 
+    pub fn prec(&self) -> usize {
+        self.prec
+    }
+
     pub fn limbs(&self) -> usize {
         self.data.len() / self.n
     }
 
+    pub fn as_ptr(&self) -> *const i64 {
+        self.data.as_ptr()
+    }
+
+    pub fn as_mut_ptr(&mut self) -> *mut i64 {
+        self.data.as_mut_ptr()
+    }
+
     pub fn at(&self, i: usize) -> &[i64] {
         &self.data[i * self.n..(i + 1) * self.n]
     }
@@ -70,54 +87,57 @@ impl Poly {
 
     pub fn set_i64(&mut self, data: &[i64], log_max: usize) {
         let size: usize = min(data.len(), self.n());
-        let k_rem: usize = self.k - (self.prec % self.k);
+        let k_rem: usize = self.log_base2k - (self.prec % self.log_base2k);
 
-        // If 2^{base} * 2^{k_rem} < 2^{63}-1, then we can simply copy
+        // If 2^{log_max} * 2^{k_rem} < 2^{63}-1, then we can simply copy
         // values on the last limb.
-        // Else we decompose values base k.
+        // Else we decompose values in base 2^{log_base2k}.
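+        // e.g. with log_base2k = 15 and prec = 54 (k_rem = 6), a 63-bit
+        // value is cut into 15-bit digits, least-significant digit in the
+        // last limb, and each touched limb is then shifted left by 6.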
-        if log_max + k_rem < 63 || k_rem == self.k {
+        if log_max + k_rem < 63 || k_rem == self.log_base2k {
             self.at_mut(self.limbs() - 1).copy_from_slice(&data[..size]);
         } else {
-            let mask: i64 = (1 << self.k) - 1;
+            let mask: i64 = (1 << self.log_base2k) - 1;
             let limbs = self.limbs();
-            let steps: usize = min(limbs, (log_max + k_rem + self.k - 1) / self.k);
+            let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
             (limbs - steps..limbs)
                 .rev()
                 .enumerate()
                 .for_each(|(i, i_rev)| {
-                    let shift: usize = i * self.k;
+                    let shift: usize = i * self.log_base2k;
                     izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter())
                         .for_each(|(y, x)| *y = (x >> shift) & mask);
                 })
         }
 
-        // Case where self.prec % self.k != 0.
-        if k_rem != self.k {
+        // Case where self.prec % self.log_base2k != 0.
+        if k_rem != self.log_base2k {
             let limbs = self.limbs();
-            let steps: usize = min(limbs, (log_max + k_rem + self.k - 1) / self.k);
+            let steps: usize = min(limbs, (log_max + self.log_base2k - 1) / self.log_base2k);
             (limbs - steps..limbs).rev().for_each(|i| {
                 self.at_mut(i)[..size].iter_mut().for_each(|x| *x <<= k_rem);
             })
         }
     }
 
-    pub fn normalize(&mut self, carry: &mut [i64]) {
-        assert!(
-            carry.len() >= self.n,
-            "invalid carry: carry.len()={} < self.n()={}",
-            carry.len(),
-            self.n()
-        );
+    pub fn normalize(&mut self, carry: &mut [u8]) {
+        assert!(
+            carry.len() >= self.n * 8,
+            "invalid carry: carry.len()={} < 8 * self.n()={}",
+            carry.len(),
+            8 * self.n()
+        );
+
+        let carry_i64: &mut [i64] = cast_mut_u8_to_mut_i64_slice(carry);
+
         unsafe {
-            znx_zero_i64_ref(self.n() as u64, carry.as_mut_ptr());
+            znx_zero_i64_ref(self.n() as u64, carry_i64.as_mut_ptr());
             (0..self.limbs()).rev().for_each(|i| {
                 znx_normalize(
                     self.n as u64,
-                    self.k as u64,
+                    self.log_base2k as u64,
                     self.at_mut_ptr(i),
-                    carry.as_mut_ptr(),
+                    carry_i64.as_mut_ptr(),
                     self.at_mut_ptr(i),
-                    carry.as_mut_ptr(),
+                    carry_i64.as_mut_ptr(),
                 )
             });
         }
     }
@@ -131,20 +151,111 @@ impl Poly {
             self.n
         );
         data.copy_from_slice(self.at(0));
-        let rem: usize = self.k - (self.prec % self.k);
+        let rem: usize = self.log_base2k - (self.prec % self.log_base2k);
         (1..self.limbs()).for_each(|i| {
-            if i == self.limbs() - 1 && rem != self.k {
-                let k_rem: usize = self.k - rem;
+            if i == self.limbs() - 1 && rem != self.log_base2k {
+                let k_rem: usize = self.log_base2k - rem;
                 izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| {
                     *y = (*y << k_rem) + (x >> rem);
                 });
             } else {
                 izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << self.k) + x;
+                    *y = (*y << self.log_base2k) + x;
                 });
             }
         })
     }
+
+    pub fn automorphism_inplace(&mut self, gal_el: i64) {
+        unsafe {
+            (0..self.limbs()).for_each(|i| {
+                znx_automorphism_inplace_i64(self.n as u64, gal_el, self.at_mut_ptr(i))
+            })
+        }
+    }
+
+    // a <- self(X^{gal_el})
+    pub fn automorphism(&mut self, gal_el: i64, a: &mut Poly) {
+        unsafe {
+            (0..self.limbs()).for_each(|i| {
+                znx_automorphism_i64(self.n as u64, gal_el, a.at_mut_ptr(i), self.at_ptr(i))
+            })
+        }
+    }
+
+    pub fn fill_uniform(&mut self, source: &mut Source) {
+        let mut base2k: u64 = 1 << self.log_base2k;
+        let mut mask: u64 = base2k - 1;
+        let mut base2k_half: i64 = (base2k >> 1) as i64;
+
+        // All limbs but the last carry log_base2k bits each.
+        let size: usize = self.n() * (self.limbs() - 1);
+
+        self.data[..size]
+            .iter_mut()
+            .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
+
+        let log_base2k_rem: usize = self.prec % self.log_base2k;
+
+        if log_base2k_rem != 0 {
+            base2k = 1 << log_base2k_rem;
+            mask = (base2k - 1) << (self.log_base2k - log_base2k_rem);
+            base2k_half = ((mask >> 1) + 1) as i64;
+        }
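+        // The last limb only carries prec % log_base2k significant bits,
+        // sampled into the top of the limb through the shifted mask.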
+        self.data[size..]
+            .iter_mut()
+            .for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
+    }
+
+    pub fn add_dist_f64<T: Distribution<f64>>(&mut self, source: &mut Source, dist: T, bound: f64) {
+        // When prec is a multiple of log_base2k the shift is 0, i.e. a no-op.
+        let shift: usize = self.prec % self.log_base2k;
+
+        self.at_mut(self.limbs() - 1).iter_mut().for_each(|a| {
+            // Rejection-sample until the draw falls within the bound.
+            let mut dist_f64: f64 = dist.sample(source);
+            while dist_f64.abs() > bound {
+                dist_f64 = dist.sample(source)
+            }
+            *a += (dist_f64.round() as i64) << shift
+        });
+    }
+
+    pub fn add_normal(&mut self, source: &mut Source, sigma: f64, bound: f64) {
+        self.add_dist_f64(source, Normal::new(0.0, sigma).unwrap(), bound);
+    }
+
+    pub fn trunc_pow2(&mut self, k: usize) {
+        if k == 0 {
+            return;
+        }
+
+        assert!(
+            k <= self.prec,
+            "invalid argument k: k={} > self.prec()={}",
+            k,
+            self.prec()
+        );
+
+        self.prec -= k;
+        self.data
+            .truncate((self.limbs() - k / self.log_base2k) * self.n());
+
+        let k_rem: usize = k % self.log_base2k;
+
+        if k_rem != 0 {
+            let mask: i64 = ((1 << (self.log_base2k - k_rem - 1)) - 1) << k_rem;
+            self.at_mut(self.limbs() - 1)
+                .iter_mut()
+                .for_each(|x: &mut i64| *x &= mask)
+        }
+    }
 }
 
 #[cfg(test)]
@@ -171,9 +282,9 @@ mod tests {
 
     #[test]
     fn test_set_get_i64_hi_norm() {
-        let n: usize = 1;
-        let k: usize = 19;
-        let prec: usize = 128;
+        let n: usize = 8;
+        let k: usize = 17;
+        let prec: usize = 84;
         let mut a: Poly = Poly::new(n, k, prec);
         let mut have: Vec<i64> = vec![i64::default(); n];
         let mut source = Source::new([1; 32]);
@@ -183,8 +294,35 @@ mod tests {
                 .wrapping_sub(u64::MAX / 2 + 1) as i64;
         });
         a.set_i64(&have, 63);
+
         let mut want = vec![i64::default(); n];
         a.get_i64(&mut want);
-        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b));
+        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
+    }
+
+    #[test]
+    fn test_normalize() {
+        let n: usize = 8;
+        let k: usize = 17;
+        let prec: usize = 84;
+        let mut a: Poly = Poly::new(n, k, prec);
+        let mut have: Vec<i64> = vec![i64::default(); n];
+        let mut source = Source::new([1; 32]);
+        have.iter_mut().for_each(|x| {
+            *x = source
+                .next_u64n(u64::MAX, u64::MAX)
+                .wrapping_sub(u64::MAX / 2 + 1) as i64;
+        });
+        a.set_i64(&have, 63);
+
+        let mut carry: Vec<u8> = vec![u8::default(); n * 8];
+        a.normalize(&mut carry);
+
+        // After normalization every digit lies in [-2^(k-1), 2^(k-1)].
+        let base_half = 1 << (k - 1);
+        a.data
+            .iter()
+            .for_each(|x| assert!(x.abs() <= base_half, "|x|={} > 2^(k-1)={}", x, base_half));
+
+        let mut want = vec![i64::default(); n];
+        a.get_i64(&mut want);
+        izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
     }
 }
diff --git a/spqlios/src/scalar.rs b/spqlios/src/scalar.rs
new file mode 100644
index 0000000..0e1c230
--- /dev/null
+++ b/spqlios/src/scalar.rs
@@ -0,0 +1,48 @@
+use rand::distributions::{Distribution, WeightedIndex};
+use rand::seq::SliceRandom;
+use rand_core::RngCore;
+use sampling::source::Source;
+
+pub struct Scalar(pub Vec<i64>);
+
+impl Scalar {
+    pub fn new(n: usize) -> Self {
+        Self(vec![i64::default(); Self::buffer_size(n)])
+    }
+
+    pub fn buffer_size(n: usize) -> usize {
+        n
+    }
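+
+    // Typical usage, as in examples/rlwe_encrypt.rs:
+    //
+    // let mut s: Scalar = Scalar::new(n);
+    // s.fill_ternary_prob(0.5, &mut source);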
+
+    pub fn from_buffer(&mut self, n: usize, buf: &[i64]) {
+        let size = Self::buffer_size(n);
+        assert!(
+            buf.len() >= size,
+            "invalid buffer: buf.len()={} < Self::buffer_size(n={})={}",
+            buf.len(),
+            n,
+            size
+        );
+        self.0 = Vec::from(&buf[..size])
+    }
+
+    pub fn as_ptr(&self) -> *const i64 {
+        self.0.as_ptr()
+    }
+
+    // Samples each coefficient as +/-1 with probability prob/2 each and 0
+    // with probability 1 - prob.
+    pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
+        let choices: [i64; 3] = [-1, 0, 1];
+        let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
+        let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
+        self.0
+            .iter_mut()
+            .for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
+    }
+
+    // Samples a ternary polynomial with exactly hw nonzero (+/-1) coefficients.
+    pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
+        self.0[..hw]
+            .iter_mut()
+            .for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
+        self.0.shuffle(source);
+    }
+}
diff --git a/spqlios/tests/module.rs b/spqlios/tests/module.rs
deleted file mode 100644
index 2fb065a..0000000
--- a/spqlios/tests/module.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-use spqlios::bindings::{module_info_t, module_type_t_FFT64};
-use spqlios::module::create_module;
-
-#[test]
-fn test_new_module_info() {
-    let N: u64 = 1024;
-    let module_ptr: *mut module_info_t = create_module(N, module_type_t_FFT64);
-    assert!(!module_ptr.is_null());
-}