use crate::dft::ntt::Table; use crate::modulus::barrett::Barrett; use crate::modulus::montgomery::Montgomery; use crate::modulus::prime::Prime; use crate::modulus::VectorOperations; use crate::modulus::{BARRETT, REDUCEMOD}; use crate::poly::Poly; use crate::ring::Ring; use crate::CHUNK; use num_bigint::BigInt; use num_traits::ToPrimitive; impl Ring { pub fn new(n: usize, q_base: u64, q_power: usize) -> Self { let prime: Prime = Prime::::new(q_base, q_power); Self { n: n, modulus: prime.clone(), dft: Box::new(Table::::new(prime, n << 1)), } } pub fn from_bigint(&self, coeffs: &[BigInt], step: usize, a: &mut Poly) { assert!( step <= a.n(), "invalid step: step={} > a.n()={}", step, a.n() ); assert!( coeffs.len() <= a.n() / step, "invalid coeffs: coeffs.len()={} > a.n()/step={}", coeffs.len(), a.n() / step ); let q_big: BigInt = BigInt::from(self.modulus.q); a.0.iter_mut() .step_by(step) .enumerate() .for_each(|(i, v)| *v = (&coeffs[i] % &q_big).to_u64().unwrap()); } } impl Ring { pub fn ntt_inplace(&self, poly: &mut Poly) { match LAZY { true => self.dft.forward_inplace_lazy(&mut poly.0), false => self.dft.forward_inplace(&mut poly.0), } } pub fn intt_inplace(&self, poly: &mut Poly) { match LAZY { true => self.dft.backward_inplace_lazy(&mut poly.0), false => self.dft.backward_inplace(&mut poly.0), } } pub fn ntt(&self, poly_in: &Poly, poly_out: &mut Poly) { poly_out.0.copy_from_slice(&poly_in.0); match LAZY { true => self.dft.forward_inplace_lazy(&mut poly_out.0), false => self.dft.forward_inplace(&mut poly_out.0), } } pub fn intt(&self, poly_in: &Poly, poly_out: &mut Poly) { poly_out.0.copy_from_slice(&poly_in.0); match LAZY { true => self.dft.backward_inplace_lazy(&mut poly_out.0), false => self.dft.backward_inplace(&mut poly_out.0), } } } impl Ring { #[inline(always)] pub fn a_add_b_into_b(&self, a: &Poly, b: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_add_vb_into_vb::(&a.0, &mut b.0); } #[inline(always)] pub fn a_add_b_into_c( &self, a: &Poly, b: &Poly, c: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus .va_add_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] pub fn a_add_b_scalar_into_a(&self, b: &u64, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.va_add_sb_into_va::(b, &mut a.0); } #[inline(always)] pub fn a_add_b_scalar_into_c( &self, a: &Poly, b: &u64, c: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus .va_add_sb_into_vc::(&a.0, b, &mut c.0); } #[inline(always)] pub fn a_add_scalar_b_mul_c_scalar_barrett_into_a( &self, b: &u64, c: &Barrett, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n()); self.modulus .va_add_sb_mul_sc_into_va::(b, c, &mut a.0); } #[inline(always)] pub fn add_scalar_then_mul_scalar_barrett( &self, a: &Poly, b: &u64, c: &Barrett, d: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(d.n() == self.n(), "c.n()={} != n={}", d.n(), self.n()); self.modulus .va_add_sb_mul_sc_into_vd::(&a.0, b, c, &mut d.0); } #[inline(always)] pub fn a_sub_b_into_b( &self, a: &Poly, b: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_sub_vb_into_vb::(&a.0, &mut b.0); } #[inline(always)] pub fn a_sub_b_into_a( &self, b: &Poly, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_sub_vb_into_va::(&b.0, &mut a.0); } #[inline(always)] pub fn a_sub_b_into_c( &self, a: &Poly, b: &Poly, c: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus .va_sub_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] pub fn a_neg_into_b( &self, a: &Poly, b: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_neg_into_vb::(&a.0, &mut b.0); } #[inline(always)] pub fn a_neg_into_a(&self, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus .va_neg_into_va::(&mut a.0); } #[inline(always)] pub fn a_mul_b_montgomery_into_c( &self, a: &Poly>, b: &Poly, c: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus .va_mont_mul_vb_into_vc::(&a.0, &b.0, &mut c.0); } #[inline(always)] pub fn a_mul_b_montgomery_into_a( &self, b: &Poly>, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_mont_mul_vb_into_vb::(&b.0, &mut a.0); } #[inline(always)] pub fn a_mul_b_scalar_into_c( &self, a: &Poly, b: &u64, c: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n()); self.modulus.sa_barrett_mul_vb_into_vc::( &self.modulus.barrett.prepare(*b), &a.0, &mut c.0, ); } #[inline(always)] pub fn a_mul_b_scalar_into_a(&self, b: &u64, a: &mut Poly) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus.sa_barrett_mul_vb_into_vb::( &self .modulus .barrett .prepare(self.modulus.barrett.reduce::(b)), &mut a.0, ); } #[inline(always)] pub fn a_mul_b_scalar_barrett_into_a( &self, b: &Barrett, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); self.modulus .sa_barrett_mul_vb_into_vb::(b, &mut a.0); } #[inline(always)] pub fn a_mul_b_scalar_barrett_into_c( &self, a: &Barrett, b: &Poly, c: &mut Poly, ) { debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .sa_barrett_mul_vb_into_vc::(a, &b.0, &mut c.0); } #[inline(always)] pub fn a_sub_b_mul_c_scalar_barrett_into_d( &self, a: &Poly, b: &Poly, c: &Barrett, d: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(d.n() == self.n(), "d.n()={} != n={}", d.n(), self.n()); self.modulus .va_sub_vb_mul_sc_into_vd::(&a.0, &b.0, c, &mut d.0); } #[inline(always)] pub fn b_sub_a_mul_c_scalar_barrett_into_a( &self, b: &Poly, c: &Barrett, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .va_sub_vb_mul_sc_into_vb::(&b.0, c, &mut a.0); } #[inline(always)] pub fn a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e< const BRANGE: u8, const REDUCE: REDUCEMOD, >( &self, a: &Poly, b: &Poly, c: &u64, d: &Barrett, e: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); debug_assert!(e.n() == self.n(), "e.n()={} != n={}", e.n(), self.n()); self.modulus .vb_sub_va_add_sc_mul_sd_into_ve::(&a.0, &b.0, c, d, &mut e.0); } #[inline(always)] pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a< const BRANGE: u8, const REDUCE: REDUCEMOD, >( &self, b: &Poly, c: &u64, d: &Barrett, a: &mut Poly, ) { debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n()); debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n()); self.modulus .vb_sub_va_add_sc_mul_sd_into_va::(&b.0, c, d, &mut a.0); } }