use itertools::{izip, Itertools}; use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub}; use std::fmt::{Debug, Display}; use crate::{ backend::ArithmeticOps, parameters::{ DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams, }, utils::log2, }; fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { assert!(logq >= (logb * d)); let ignored_bits = logq - (logb * d); (0..d) .into_iter() .map(|i| T::one() << (logb * i + ignored_bits)) .collect_vec() } pub trait RlweDecomposer { type Element; type D: Decomposer; /// Decomposer for RLWE Part A fn a(&self) -> &Self::D; /// Decomposer for RLWE Part B fn b(&self) -> &Self::D; } impl RlweDecomposer for (D, D) where D: Decomposer, { type D = D; type Element = D::Element; fn a(&self) -> &Self::D { &self.0 } fn b(&self) -> &Self::D { &self.1 } } impl DoubleDecomposerParams for D where D: RlweDecomposer, { type Base = DecompostionLogBase; type Count = DecompositionCount; fn decomposition_base(&self) -> Self::Base { assert!( Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b()) ); Decomposer::decomposition_base(self.a()) } fn decomposition_count_a(&self) -> Self::Count { Decomposer::decomposition_count(self.a()) } fn decomposition_count_b(&self) -> Self::Count { Decomposer::decomposition_count(self.b()) } } impl SingleDecomposerParams for D where D: Decomposer, { type Base = DecompostionLogBase; type Count = DecompositionCount; fn decomposition_base(&self) -> Self::Base { Decomposer::decomposition_base(self) } fn decomposition_count(&self) -> Self::Count { Decomposer::decomposition_count(self) } } pub trait Decomposer { type Element; type Iter: Iterator; fn new(q: Self::Element, logb: usize, d: usize) -> Self; fn decompose_to_vec(&self, v: &Self::Element) -> Vec; fn decompose_iter(&self, v: &Self::Element) -> Self::Iter; fn decomposition_count(&self) -> DecompositionCount; fn decomposition_base(&self) -> DecompostionLogBase; fn gadget_vector(&self) -> Vec; } pub struct DefaultDecomposer { /// Ciphertext modulus q: T, /// Log of ciphertext modulus logq: usize, /// Log of base B logb: usize, /// base B b: T, /// (B - 1). To simulate (% B) as &(B-1), that is extract least significant /// logb bits b_mask: T, /// B/2 bby2: T, /// Decomposition count d: usize, /// No. of bits to ignore in rounding ignore_bits: usize, } pub trait NumInfo { const BITS: u32; } impl NumInfo for u64 { const BITS: u32 = u64::BITS; } impl NumInfo for u32 { const BITS: u32 = u32::BITS; } impl NumInfo for u128 { const BITS: u32 = u128::BITS; } impl DefaultDecomposer { fn recompose(&self, limbs: &[T], modq_op: &Op) -> T where Op: ArithmeticOps, { let mut value = T::zero(); let gadget_vector = gadget_vector(self.logq, self.logb, self.d); assert!(limbs.len() == gadget_vector.len()); izip!(limbs.iter(), gadget_vector.iter()) .for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta))); value } } impl< T: PrimInt + ToPrimitive + FromPrimitive + WrappingSub + WrappingAdd + NumInfo + From + Display + Debug, > Decomposer for DefaultDecomposer { type Element = T; type Iter = DecomposerIter; fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. let logq = log2(&q); assert!( logq >= (logb * d), "Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}" ); let ignore_bits = logq - (logb * d); DefaultDecomposer { q, logq, logb, b: T::one() << logb, b_mask: (T::one() << logb) - T::one(), bby2: T::one() << (logb - 1), d, ignore_bits, } } fn decompose_to_vec(&self, value: &T) -> Vec { let q = self.q; let logb = self.logb; let b = T::one() << logb; let full_mask = b - T::one(); let bby2 = b >> 1; let mut value = *value; if value >= (q >> 1) { value = !(q - value) + T::one() } value = round_value(value, self.ignore_bits); let mut out = Vec::with_capacity(self.d); for _ in 0..(self.d) { let k_i = value & full_mask; value = (value - k_i) >> logb; if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) { out.push(q - (b - k_i)); value = value + T::one(); } else { out.push(k_i); } } return out; } fn decomposition_count(&self) -> DecompositionCount { DecompositionCount(self.d) } fn decomposition_base(&self) -> DecompostionLogBase { DecompostionLogBase(self.logb) } fn decompose_iter(&self, value: &T) -> DecomposerIter { let mut value = *value; if value >= (self.q >> 1) { value = !(self.q - value) + T::one() } value = round_value(value, self.ignore_bits); DecomposerIter { value, q: self.q, logq: self.logq, logb: self.logb, b: self.b, bby2: self.bby2, b_mask: self.b_mask, steps_left: self.d, } } fn gadget_vector(&self) -> Vec { return gadget_vector(self.logq, self.logb, self.d); } } impl DefaultDecomposer {} pub struct DecomposerIter { /// Value to decompose value: T, steps_left: usize, /// (1 << logb) - 1 (for % (1< + WrappingSub + Display> Iterator for DecomposerIter { type Item = T; fn next(&mut self) -> Option { if self.steps_left != 0 { self.steps_left -= 1; let k_i = self.value & self.b_mask; self.value = (self.value - k_i) >> self.logb; // if k_i > self.bby2 || (k_i == self.bby2 && ((self.value & // T::one()) == T::one())) { self.value = self.value // + T::one(); Some(self.q + k_i - self.b) // } else { // Some(k_i) // } // Following is without branching impl of the commented version above. It // happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other // parameters as well but I haven't tested) by roughly 15ms. // Suprisingly the improvement does not show up when I benchmark // `decomposer_iter` in isolation. Putting this remark here as a // future task to investiage (TODO). let carry_bool = k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one())); let carry = >::from(carry_bool); let neg_carry = (T::zero().wrapping_sub(&carry)); self.value = self.value + carry; Some((neg_carry & self.q) + k_i - (carry << self.logb)) // Some( // (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i // - (carry << self.logb), ) // Some(k_i) } else { None } } } fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; } let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); (value >> ignore_bits).wrapping_add(&ignored_msb) } #[cfg(test)] mod tests { use itertools::Itertools; use rand::{thread_rng, Rng}; use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::round_value, utils::generate_prime, }; use super::{Decomposer, DefaultDecomposer}; #[test] fn decomposition_works() { let ring_size = 1 << 11; let mut rng = thread_rng(); for logq in [37, 55] { let logb = 11; let d = 3; // let mut stats = vec![Stats::new(); d]; for i in [true, false] { let q = if i { generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap() } else { 1u64 << logq }; let decomposer = DefaultDecomposer::new(q, logb, d); let modq_op = ModularOpsU64::new(q); for _ in 0..1000000 { let value = rng.gen_range(0..q); let limbs = decomposer.decompose_to_vec(&value); let limbs_from_iter = decomposer.decompose_iter(&value).collect_vec(); assert_eq!(limbs, limbs_from_iter); let value_back = round_value( decomposer.recompose(&limbs, &modq_op), decomposer.ignore_bits, ); let rounded_value = round_value(value, decomposer.ignore_bits); assert!((rounded_value as i64 - value_back as i64).abs() <= 1,); // izip!(stats.iter_mut(), limbs.iter()).for_each(|(s, l)| { // s.add_more(&vec![q.map_element_to_i64(l)]); // }); } } // stats.iter().enumerate().for_each(|(index, s)| { // println!( // "Limb {index} - Mean: {}, Std: {}", // s.mean(), // s.std_dev().abs().log2() // ); // }); } } }