diff --git a/math/src/poly.rs b/math/src/poly.rs index c5988bc..8faa8f1 100644 --- a/math/src/poly.rs +++ b/math/src/poly.rs @@ -38,12 +38,12 @@ where self.0.resize(n, O::default()); } - pub fn set_all(&mut self, v: &O) { + pub fn fill(&mut self, v: &O) { self.0.fill(*v) } pub fn zero(&mut self) { - self.set_all(&O::default()) + self.fill(&O::default()) } pub fn copy_from(&mut self, other: &Poly) { @@ -134,12 +134,12 @@ where &mut self.0[level] } - pub fn set_all(&mut self, v: &O) { - (0..self.level() + 1).for_each(|i| self.at_mut(i).set_all(v)) + pub fn fill(&mut self, v: &O) { + (0..self.level() + 1).for_each(|i| self.at_mut(i).fill(v)) } pub fn zero(&mut self) { - self.set_all(&O::default()) + self.fill(&O::default()) } pub fn copy(&mut self, other: &PolyRNS) { diff --git a/math/src/ring/impl_u64/packing.rs b/math/src/ring/impl_u64/packing.rs index 1f3ef4d..45c00aa 100644 --- a/math/src/ring/impl_u64/packing.rs +++ b/math/src/ring/impl_u64/packing.rs @@ -1,10 +1,9 @@ use crate::modulus::barrett::Barrett; -use crate::modulus::{WordOps, ONCE}; +use crate::modulus::ONCE; use crate::poly::Poly; use crate::ring::Ring; use std::cmp::min; use std::collections::HashSet; -use std::mem::transmute; impl Ring { // Generates a vector storing {X^{2^0}, X^{2^1}, .., X^{2^log_n}}. @@ -64,14 +63,17 @@ impl Ring { let set: HashSet<_> = indices.into_iter().collect(); - let max_pow2_gap_divisor: usize = 1 << gap.trailing_zeros(); - if !ZEROGARBAGE { if gap > 0 { - log_end -= max_pow2_gap_divisor; + log_end -= gap.trailing_zeros() as usize; } } + assert!( + log_start < log_end, + "invalid input polys: gap between non None value is smaller than 2^log_gap" + ); + let n_inv: Barrett = self .modulus .barrett @@ -102,31 +104,35 @@ impl Ring { if let Some(poly_lo) = polys_lo[j].as_mut() { self.a_sub_b_into_c::<1, ONCE>(poly_lo, poly_hi, &mut tmpa); self.a_add_b_into_b::(poly_hi, poly_lo); - } else { - std::mem::swap(&mut polys_lo[j], &mut polys_hi[j]); } } if let Some(poly_lo) = polys_lo[j].as_mut() { - let gal_el: usize = self.galois_element(1 << (i - 1), i == 0, log_nth_root); + let gal_el: usize = self.galois_element((1 << i) >> 1, i == 0, log_nth_root); if !polys_hi[j].is_none() { self.automorphism::(&tmpa, gal_el, 2 << self.log_n(), &mut tmpb); + self.a_add_b_into_b::(&tmpb, poly_lo); } else { self.automorphism::(poly_lo, gal_el, nth_root, &mut tmpa); + self.a_add_b_into_b::(&tmpa, poly_lo); } - - self.a_add_b_into_b::(&tmpa, poly_lo); } else if let Some(poly_hi) = polys_hi[j].as_mut() { - let gal_el: usize = self.galois_element(1 << (i - 1), i == 0, log_nth_root); - + let gal_el: usize = self.galois_element((1 << i) >> 1, i == 0, log_nth_root); self.automorphism::(poly_hi, gal_el, nth_root, &mut tmpa); - self.a_sub_b_into_a::<1, ONCE>(&tmpa, poly_hi) + self.a_sub_b_into_a::<1, ONCE>(&tmpa, poly_hi); + std::mem::swap(&mut polys_lo[j], &mut polys_hi[j]); } } polys.truncate(t); } + + if !NTT { + if let Some(poly) = polys[0].as_mut() { + self.intt_inplace::(poly); + } + } } } @@ -135,7 +141,7 @@ fn max_gap(vec: &[usize]) -> usize { let mut gap: usize = usize::MAX; for i in 1..vec.len() { let (l, r) = (vec[i - 1], vec[i]); - assert!(l > r, "invalid input vec: not sorted"); + assert!(r > l, "invalid input vec: not sorted"); gap = min(gap, r - l); if gap == 1 { break; diff --git a/math/src/ring/impl_u64/ring.rs b/math/src/ring/impl_u64/ring.rs index ef0c9ae..2f090e3 100644 --- a/math/src/ring/impl_u64/ring.rs +++ b/math/src/ring/impl_u64/ring.rs @@ -51,12 +51,12 @@ impl Ring { gal_el = gal_el.wrapping_mul(gen_1_pow); } - gen_1_pow *= gen_1_pow; + gen_1_pow = gen_1_pow.wrapping_mul(gen_1_pow); e >>= 1; } let nth_root = 1 << log_nth_root; - gal_el &= (nth_root - 1); + gal_el &= nth_root - 1; if gen_2 { return nth_root - gal_el; diff --git a/math/tests/automorphism.rs b/math/tests/automorphism.rs index 93b95e4..421e9a5 100644 --- a/math/tests/automorphism.rs +++ b/math/tests/automorphism.rs @@ -1,5 +1,6 @@ use itertools::izip; use math::poly::Poly; +use math::ring::impl_u64::ring; use math::ring::Ring; #[test] @@ -51,3 +52,86 @@ fn test_automorphism_u64(ring: &Ring, nth_root: usize) { izip!(p0.0, p1.0).for_each(|(a, b)| assert_eq!(a, b)); } + +#[test] +fn packing_u64() { + let n: usize = 1 << 5; + let q_base: u64 = 65537u64; + let q_power: usize = 1usize; + let ring: Ring = Ring::new(n, q_base, q_power); + + sub_test("test_packing_u64::", || { + test_packing_full_u64::(&ring) + }); + sub_test("test_packing_u64::", || { + test_packing_full_u64::(&ring) + }); + sub_test("test_packing_sparse_u64::", || { + test_packing_sparse_u64::(&ring) + }); + sub_test("test_packing_sparse_u64::", || { + test_packing_sparse_u64::(&ring) + }); +} + +fn test_packing_full_u64(ring: &Ring) { + let n: usize = ring.n(); + + let mut result: Vec>> = vec![None; n]; + + for i in 0..n { + let mut poly: Poly = ring.new_poly(); + poly.fill(&(1 + i as u64)); + if NTT { + ring.ntt_inplace::(&mut poly); + } + + result[i] = Some(poly); + } + + ring.pack::(&mut result, ring.log_n()); + + if let Some(poly) = result[0].as_mut() { + if NTT { + ring.intt_inplace::(poly); + } + + poly.0 + .iter() + .enumerate() + .for_each(|(i, x)| assert_eq!(*x, 1 + i as u64)); + } +} + +fn test_packing_sparse_u64(ring: &Ring) { + let n: usize = ring.n(); + + let mut result: Vec>> = vec![None; n]; + + let gap: usize = 3; + + for i in (0..n).step_by(gap) { + let mut poly: Poly = ring.new_poly(); + poly.fill(&(1 + i as u64)); + if NTT { + ring.ntt_inplace::(&mut poly); + } + result[i] = Some(poly); + } + + ring.pack::(&mut result, ring.log_n()); + + if let Some(poly) = result[0].as_mut() { + if NTT { + ring.intt_inplace::(poly); + } + + poly.0.iter().enumerate().for_each(|(i, x)| { + if i % gap == 0 { + assert_eq!(*x, 1 + i as u64) + } else { + assert_eq!(*x, 0u64) + } + }); + } +}