diff --git a/src/backend.rs b/src/backend.rs index 32021c5..845db22 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -42,10 +42,9 @@ impl Modulus for u64 { 0 } fn map_element_to_i64(&self, v: &Self::Element) -> i64 { - assert!(v < self); - + assert!(v <= self, "{v} must be <= {self}"); if *v > (self >> 1) { - ToPrimitive::to_i64(&(self - v)).unwrap() + -ToPrimitive::to_i64(&(self - v)).unwrap() } else { ToPrimitive::to_i64(v).unwrap() } diff --git a/src/decomposer.rs b/src/decomposer.rs index 2db3f58..a490019 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -1,5 +1,5 @@ use itertools::Itertools; -use num_traits::{AsPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; +use num_traits::{AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; use std::{fmt::Debug, marker::PhantomData, ops::Rem}; use crate::backend::{ArithmeticOps, ModularOpsU64}; @@ -92,7 +92,9 @@ impl DefaultDecomposer { } } -impl Decomposer for DefaultDecomposer { +impl Decomposer + for DefaultDecomposer +{ type Element = T; fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { @@ -117,45 +119,34 @@ impl Decomposer for DefaultDecompose } } + /// Signed BNAF decomposition. Only returns most significant `d` + /// decomposition limbs + /// + /// Implements algorithm 3 of https://eprint.iacr.org/2021/1161.pdf fn decompose(&self, value: &T) -> Vec { - let value = round_value(*value, self.ignore_bits); + let mut value = round_value(*value, self.ignore_bits); let q = self.q; - // if value >= (q >> 1) { - // value = value.wrapping_sub(&q); - // } - let logb = self.logb; - let b = T::one() << logb; // base - let b_by2 = T::one() << (logb - 1); - // let neg_b_by2_modq = q - b_by2; - let full_mask = (T::one() << logb) - T::one(); - // let half_mask = b_by2 - T::one(); - let mut carry = T::zero(); - let mut out = Vec::::with_capacity(self.d); - for i in 0..self.d { - let mut limb = ((value >> (logb * i)) & full_mask) + carry; - carry = T::zero(); - if limb >= b_by2 { - limb = (q + limb) - b; - carry = T::one(); - } - - // carry = ((q + g - limb) % q) >> logb; + let b = T::one() << logb; + let full_mask = b - T::one(); + let bby2 = b >> 1; - // carry = limb & b_by2; - // limb = (q + limb) - (carry << 1); - // if limb > q { - // limb = limb - q; - // } - out.push(limb); - - // carry = carry >> (logb - 1); + if value > (q >> 1) { + value = !(q - value) + T::one() } - out[self.d - 1] = out[self.d - 1] + (carry << logb); - if out[self.d - 1] > q { - out[self.d - 1] = out[self.d - 1] - q; + let mut out = Vec::with_capacity(self.d); + for _ in 0..self.d { + let k_i = value & full_mask; + value = (value - k_i) >> logb; + + if k_i > bby2 || (k_i == bby2 && ((value & full_mask) >= bby2)) { + out.push(q - (b - k_i)); + value = value + T::one(); + } else { + out.push(k_i) + } } return out; @@ -215,27 +206,29 @@ fn round_value(value: T, ignore_bits: usize) -> T { #[cfg(test)] mod tests { + use num_traits::Float; use rand::{thread_rng, Rng}; use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::round_value, - utils::generate_prime, + utils::{generate_prime, Stats, TryConvertFrom1}, }; use super::{Decomposer, DefaultDecomposer}; #[test] fn decomposition_works() { - let logq = 55; - let logb = 9; - let d = 6; + let logq = 50; + let logb = 5; + let d = 10; let mut rng = thread_rng(); + let mut stats = Stats { samples: vec![] }; // q is prime of bits logq and i is true, other q = 1<::try_convert_from(&limbs, &q)); assert_eq!( rounded_value, value_back, "Expected {rounded_value} got {value_back} for q={q}" ); } } + println!("Mean: {}", stats.mean()); + println!("Std: {}", stats.std_dev()); } } diff --git a/src/rgsw.rs b/src/rgsw.rs index 02c2079..a2eea9d 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -2122,7 +2122,7 @@ pub(crate) mod tests { #[test] fn some_work() { let logq = 50; - let ring_size = 1 << 10; + let ring_size = 1 << 11; let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); let d_rgsw = 10; let logb = 5;