use itertools::Itertools; use num_traits::{AsPrimitive, One, PrimInt, ToPrimitive, WrappingSub, Zero}; use std::{fmt::Debug, marker::PhantomData, ops::Rem}; use crate::backend::{ArithmeticOps, ModularOpsU64}; pub fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); let ignored_limbs = d_ideal - d; (ignored_limbs..ignored_limbs + d) .into_iter() .map(|i| T::one() << (logb * i)) .collect_vec() } pub trait Decomposer { type Element; //FIXME(Jay): there's no reason why it returns a vec instead of an iterator fn decompose(&self, v: &Self::Element) -> Vec; fn d(&self) -> usize; } // TODO(Jay): Shouldn't Decompose also return corresponding gadget vector ? pub struct DefaultDecomposer { q: T, logq: usize, logb: usize, d: usize, ignore_bits: usize, ignore_limbs: usize, } pub trait NumInfo { const BITS: u32; } impl NumInfo for u64 { const BITS: u32 = u64::BITS; } impl NumInfo for u32 { const BITS: u32 = u32::BITS; } impl NumInfo for u128 { const BITS: u32 = u128::BITS; } impl DefaultDecomposer { pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1. let logq = if q & (q - T::one()) == T::zero() { (T::BITS - q.leading_zeros() - 1) as usize } else { (T::BITS - q.leading_zeros()) as usize }; let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); let ignore_limbs = (d_ideal - d); let ignore_bits = (d_ideal - d) * logb; DefaultDecomposer { q, logq, logb, d, ignore_bits, ignore_limbs, } } fn recompose(&self, limbs: &[T], modq_op: &Op) -> T where Op: ArithmeticOps, { let mut value = T::zero(); for i in 0..self.d { value = modq_op.add( &value, &(modq_op.mul( &limbs[i], &(T::one() << (self.logb * (i + self.ignore_limbs))), )), ) } value } } impl Decomposer for DefaultDecomposer { type Element = T; fn decompose(&self, value: &T) -> Vec { let mut value = round_value(*value, self.ignore_bits); let q = self.q; // if value >= (q >> 1) { // value = value.wrapping_sub(&q); // } let logb = self.logb; let b = T::one() << logb; // base let b_by2 = T::one() << (logb - 1); // let neg_b_by2_modq = q - b_by2; let full_mask = (T::one() << logb) - T::one(); // let half_mask = b_by2 - T::one(); let mut carry = T::zero(); let mut out = Vec::::with_capacity(self.d); for i in 0..self.d { let mut limb = ((value >> (logb * i)) & full_mask) + carry; carry = T::zero(); if limb >= b_by2 { limb = (q + limb) - b; carry = T::one(); } // carry = ((q + g - limb) % q) >> logb; // carry = limb & b_by2; // limb = (q + limb) - (carry << 1); // if limb > q { // limb = limb - q; // } out.push(limb); // carry = carry >> (logb - 1); } out[self.d - 1] = out[self.d - 1] + (carry << logb); if out[self.d - 1] > q { out[self.d - 1] = out[self.d - 1] - q; } return out; } fn d(&self) -> usize { self.d } } fn round_value(value: T, ignore_bits: usize) -> T { if ignore_bits == 0 { return value; } let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); (value >> ignore_bits) + ignored_msb } #[cfg(test)] mod tests { use rand::{thread_rng, Rng}; use crate::{ backend::{ModInit, ModularOpsU64}, decomposer::round_value, utils::generate_prime, }; use super::{Decomposer, DefaultDecomposer}; #[test] fn decomposition_works() { let logq = 55; let logb = 9; let d = 6; let mut rng = thread_rng(); // q is prime of bits logq and i is true, other q = 1<