use itertools::Itertools; use num_traits::{AsPrimitive, FromPrimitive, Num, One, PrimInt, ToPrimitive, WrappingSub, Zero}; use std::{fmt::Debug, marker::PhantomData, ops::Rem}; use crate::backend::{ArithmeticOps, ModularOpsU64}; fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); let ignored_limbs = d_ideal - d; (ignored_limbs..ignored_limbs + d) .into_iter() .map(|i| T::one() << (logb * i)) .collect_vec() } pub trait RlweDecomposer { type Element; type D: Decomposer; /// Decomposer for RLWE Part A fn a(&self) -> &Self::D; /// Decomposer for RLWE Part B fn b(&self) -> &Self::D; } impl RlweDecomposer for (D, D) where D: Decomposer, { type D = D; type Element = D::Element; fn a(&self) -> &Self::D { &self.0 } fn b(&self) -> &Self::D { &self.1 } } pub trait Decomposer { type Element; type Iter: Iterator; fn new(q: Self::Element, logb: usize, d: usize) -> Self; fn decompose_to_vec(&self, v: &Self::Element) -> Vec; fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; fn decomposition_count(&self) -> usize; } pub struct DefaultDecomposer { /// Ciphertext modulus q: T, /// Log of ciphertext modulus logq: usize, /// Log of base B logb: usize, /// base B b: T, /// (B - 1). To simulate (% B) as &(B-1), that is extract least significant /// logb bits b_mask: T, /// B/2 bby2: T, /// Decomposition count d: usize, /// No. of bits to ignore in rounding ignore_bits: usize, /// No. of limbs to ignore in rounding. Set to ceil(logq / logb) - d ignore_limbs: usize, } pub trait NumInfo { const BITS: u32; } impl NumInfo for u64 { const BITS: u32 = u64::BITS; } impl NumInfo for u32 { const BITS: u32 = u32::BITS; } impl NumInfo for u128 { const BITS: u32 = u128::BITS; } impl DefaultDecomposer { fn recompose(&self, limbs: &[T], modq_op: &Op) -> T where Op: ArithmeticOps, { let mut value = T::zero(); for i in 0..self.d { value = modq_op.add( &value, &(modq_op.mul( &limbs[i], &(T::one() << (self.logb * (i + self.ignore_limbs))), )), ) } value } pub(crate) fn gadget_vector(&self) -> Vec { return gadget_vector(self.logq, self.logb, self.d); } } impl Decomposer for DefaultDecomposer { type Element = T; type Iter = DecomposerIter; fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. let logq = if q & (q - T::one()) == T::zero() { (T::BITS - q.leading_zeros() - 1) as usize } else { (T::BITS - q.leading_zeros()) as usize }; let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); let ignore_limbs = (d_ideal - d); let ignore_bits = (d_ideal - d) * logb; DefaultDecomposer { q, logq, logb, b: T::one() << logb, b_mask: (T::one() << logb) - T::one(), bby2: T::one() << (logb - 1), d, ignore_bits, ignore_limbs, } } // TODO(Jay): Outline the caveat fn decompose_to_vec(&self, value: &T) -> Vec { let mut value = round_value(*value, self.ignore_bits); let q = self.q; let logb = self.logb; let b = T::one() << logb; let full_mask = b - T::one(); let bby2 = b >> 1; if value >= (q >> 1) { value = !(q - value) + T::one() } let mut out = Vec::with_capacity(self.d); for _ in 0..self.d { let k_i = value & full_mask; value = (value - k_i) >> logb; if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) { out.push(q - (b - k_i)); value = value + T::one(); } else { out.push(k_i); } } return out; } fn decomposition_count(&self) -> usize { self.d } fn decompose_iter(&self, value: &T) -> DecomposerIter { let mut value = round_value(*value, self.ignore_bits); if value >= (self.q >> 1) { value = !(self.q - value) + T::one() } DecomposerIter { value, q: self.q, logb: self.logb, b: self.b, bby2: self.bby2, b_mask: self.b_mask, steps_left: self.d, } } } impl DefaultDecomposer {} pub struct DecomposerIter { /// Value to decompose value: T, steps_left: usize, /// (1 << logb) - 1 (for % (1< Iterator for DecomposerIter { type Item = T; fn next(&mut self) -> Option { if self.steps_left != 0 { self.steps_left -= 1; let k_i = self.value & self.b_mask; self.value = (self.value - k_i) >> self.logb; if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())) { self.value = self.value + T::one(); Some(self.q + k_i - self.b) } else { Some(k_i) } // let carry = >::from( // k_i > self.bby2 || (k_i == self.bby2 && ((self.value & // T::one()) == T::one())), ); // self.value = self.value + carry; // Some((self.q & ((carry << 55) - (T::one() & carry))) + k_i - // (carry << self.logb)) // Some(k_i) } else { None } } } fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; } let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); (value >> ignore_bits) + ignored_msb } #[cfg(test)] mod tests { use itertools::Itertools; use rand::{thread_rng, Rng}; use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::round_value, utils::{generate_prime, Stats, TryConvertFrom1}, }; use super::{Decomposer, DefaultDecomposer}; #[test] fn decomposition_works() { let logq = 55; let logb = 11; let d = 5; let ring_size = 1 << 11; let mut rng = thread_rng(); let mut stats = Stats { samples: vec![] }; for i in [true] { let q = if i { generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap() } else { 1u64 << logq }; let decomposer = DefaultDecomposer::new(q, logb, d); let modq_op = ModularOpsU64::new(q); for _ in 0..100000 { let value = rng.gen_range(0..q); let limbs = decomposer.decompose_to_vec(&value); let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec(); assert_eq!(limbs, limbs_from_iter); let value_back = round_value( decomposer.recompose(&limbs, &modq_op), decomposer.ignore_bits, ); let rounded_value = round_value(value, decomposer.ignore_bits); assert_eq!( rounded_value, value_back, "Expected {rounded_value} got {value_back} for q={q}" ); stats.add_more(&Vec::::try_convert_from(&limbs, &q)); } } println!("Mean: {}", stats.mean()); println!("Std: {}", stats.std_dev().abs().log2()); } }