fixed rounding rescaling

This commit is contained in:
Jean-Philippe Bossuat
2025-01-08 11:06:56 +01:00
parent 3db800f4ce
commit bdd57b91ed
13 changed files with 649 additions and 362 deletions

Cargo.lock generated
View File

@@ -256,6 +256,12 @@ version = "0.2.167"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc"
[[package]]
name = "libm"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
[[package]]
name = "log"
version = "0.4.22"
@@ -270,9 +276,11 @@ dependencies = [
"itertools 0.14.0",
"num",
"num-bigint",
"num-integer",
"num-traits",
"primality-test",
"prime_factorization",
"rand_distr",
"sampling",
]
@@ -353,6 +361,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@@ -469,6 +478,16 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rayon"
version = "1.10.0"

View File

@@ -8,9 +8,11 @@ num = "0.4.3"
primality-test = "0.3.0"
num-bigint = "0.4.6"
num-traits = "0.2.19"
num-integer ="0.1.46"
prime_factorization = "1.0.5"
itertools = "0.14.0"
criterion = "0.5.1"
rand_distr = "0.4.3"
sampling = { path = "../sampling" }
[[bench]]

View File

@@ -8,7 +8,7 @@ fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
let mut b: PolyRNS<u64> = r.new_polyrns();
let mut c: PolyRNS<u64> = r.new_polyrns();
Box::new(move || r.div_floor_by_last_modulus::<true>(&a, &mut b, &mut c))
Box::new(move || r.div_by_last_modulus::<false, true>(&a, &mut b, &mut c))
}
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> =

View File

@@ -6,6 +6,7 @@ pub mod modulus;
pub mod poly;
pub mod ring;
pub mod scalar;
pub mod num_bigint;
pub const CHUNK: usize = 8;
@@ -398,4 +399,59 @@ pub mod macros {
}
};
}
#[macro_export]
macro_rules! apply_vvssv {
($self:expr, $f:expr, $a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $CHUNK:expr) => {
let n: usize = $a.len();
debug_assert!(
$b.len() == n,
"invalid argument b: b.len() = {} != a.len() = {}",
$b.len(),
n
);
debug_assert!(
$e.len() == n,
"invalid argument e: e.len() = {} != a.len() = {}",
$e.len(),
n
);
debug_assert!(
$CHUNK & ($CHUNK - 1) == 0,
"invalid CHUNK argument: not a power of two"
);
match $CHUNK {
8 => {
izip!(
$a.chunks_exact(8),
$b.chunks_exact(8),
$e.chunks_exact_mut(8)
)
.for_each(|(a, b, e)| {
$f(&$self, &a[0], &b[0], $c, $d, &mut e[0]);
$f(&$self, &a[1], &b[1], $c, $d, &mut e[1]);
$f(&$self, &a[2], &b[2], $c, $d, &mut e[2]);
$f(&$self, &a[3], &b[3], $c, $d, &mut e[3]);
$f(&$self, &a[4], &b[4], $c, $d, &mut e[4]);
$f(&$self, &a[5], &b[5], $c, $d, &mut e[5]);
$f(&$self, &a[6], &b[6], $c, $d, &mut e[6]);
$f(&$self, &a[7], &b[7], $c, $d, &mut e[7]);
});
let m = n - (n & 7);
izip!($a[m..].iter(), $b[m..].iter(), $e[m..].iter_mut()).for_each(
|(a, b, e)| {
$f(&$self, a, b, $c, $d, e);
},
);
}
_ => {
izip!($a.iter(), $b.iter(), $e.iter_mut()).for_each(|(a, b, e)| {
$f(&$self, a, b, $c, $d, e);
});
}
}
};
}
}
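A minimal standalone sketch (not part of this commit) of the unroll-by-8 pattern that apply_vvssv! expands to, with a hypothetical closure standing in for the scalar method:
// Process full chunks of 8 elements, then finish the tail element-wise.
fn apply_vvssv_sketch(f: impl Fn(u64, u64, u64, u64) -> u64, a: &[u64], b: &[u64], c: u64, d: u64, e: &mut [u64]) {
    assert!(a.len() == b.len() && a.len() == e.len());
    let m = a.len() - (a.len() & 7);
    for ((a8, b8), e8) in a[..m].chunks_exact(8).zip(b[..m].chunks_exact(8)).zip(e[..m].chunks_exact_mut(8)) {
        for j in 0..8 {
            e8[j] = f(a8[j], b8[j], c, d);
        }
    }
    for j in m..a.len() {
        e[j] = f(a[j], b[j], c, d);
    }
}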

View File

@@ -79,6 +79,9 @@ pub trait ScalarOperations<O> {
// Assigns a - b to b.
fn sa_sub_sb_into_sb<const SARANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &O, b: &mut O);
// Assigns a - b to a.
fn sa_sub_sb_into_sa<const SARANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &O, a: &mut O);
// Assigns -a to a.
fn sa_neg_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut O);
@@ -155,6 +158,24 @@ pub trait ScalarOperations<O> {
c: &barrett::Barrett<u64>,
d: &mut u64
);
// Assigns (b - a + c) * d to e.
fn sb_sub_sa_add_sc_mul_sd_into_se<const SBRANGE: u8,const REDUCE: REDUCEMOD>(
&self,
a: &u64,
b: &u64,
c: &u64,
d: &barrett::Barrett<u64>,
e: &mut u64
);
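// Assigns (b - a + c) * d to a.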
fn sb_sub_sa_add_sc_mul_sd_into_sa<const SBRANGE: u8,const REDUCE: REDUCEMOD>(
&self,
b: &u64,
c: &u64,
d: &barrett::Barrett<u64>,
a: &mut u64
);
}
pub trait VectorOperations<O> {
@@ -187,6 +208,9 @@ pub trait VectorOperations<O> {
// vec(b) <- vec(a) - vec(b).
fn va_sub_vb_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &[O], b: &mut [O]);
// vec(a) <- vec(a) - vec(b).
fn va_sub_vb_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &[O], a: &mut [O]);
// vec(c) <- vec(a) - vec(b).
fn va_sub_vb_into_vc<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
@@ -271,4 +295,23 @@ pub trait VectorOperations<O> {
sc: &barrett::Barrett<u64>,
va: &mut [u64],
);
// vec(e) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
fn vb_sub_va_add_sc_mul_sd_into_ve<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
va: &[u64],
vb: &[u64],
sc: &u64,
sd: &barrett::Barrett<u64>,
ve: &mut [u64],
);
// vec(a) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
fn vb_sub_va_add_sc_mul_sd_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
vb: &[u64],
sc: &u64,
sd: &barrett::Barrett<u64>,
va: &mut [u64],
);
}
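As a plain-arithmetic reference for the two new fused operations (a standalone sketch, not the trait implementation; sd stands for the plain scalar that the Barrett constant encodes, and q < 2^63 is assumed so the i128 intermediates cannot overflow):
// Reference semantics: ve[j] = ((vb[j] - va[j] + sc) * sd) mod q; the *_into_va
// variant writes the same result back into va[j].
fn vb_sub_va_add_sc_mul_sd_reference(va: &[u64], vb: &[u64], sc: u64, sd: u64, q: u64) -> Vec<u64> {
    va.iter()
        .zip(vb.iter())
        .map(|(&a, &b)| {
            let t = (b as i128 - a as i128 + sc as i128).rem_euclid(q as i128);
            ((t * sd as i128) % q as i128) as u64
        })
        .collect()
}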

View File

@@ -3,7 +3,7 @@ use crate::modulus::montgomery::Montgomery;
use crate::modulus::prime::Prime;
use crate::modulus::{REDUCEMOD, NONE};
use crate::modulus::{ScalarOperations, VectorOperations};
use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv, apply_ssv, apply_vssv};
use crate::{apply_sv, apply_svv, apply_v, apply_vsv, apply_vv, apply_vvsv, apply_vvv, apply_ssv, apply_vssv, apply_vvssv};
use itertools::izip;
impl ScalarOperations<u64> for Prime<u64> {
@@ -42,6 +42,17 @@ impl ScalarOperations<u64> for Prime<u64> {
self.sa_reduce_into_sa::<REDUCE>(c)
}
#[inline(always)]
fn sa_sub_sb_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut u64) {
match SBRANGE{
1 =>{*a = *a + self.q - *b}
2 =>{*a = *a + self.two_q - *b}
4 =>{*a = *a + self.four_q - *b}
_ => unreachable!("invalid SBRANGE argument"),
}
self.sa_reduce_into_sa::<REDUCE>(a)
}
#[inline(always)]
fn sa_sub_sb_into_sb<const SBRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut u64) {
match SBRANGE{
@@ -159,6 +170,31 @@ impl ScalarOperations<u64> for Prime<u64> {
*a = self.barrett.mul_external::<REDUCE>(*c, *a + *b);
}
#[inline(always)]
fn sb_sub_sa_add_sc_mul_sd_into_se<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
a: &u64,
b: &u64,
c: &u64,
d: &Barrett<u64>,
e: &mut u64
) {
self.sa_sub_sb_into_sc::<SBRANGE, NONE>(&(b + c), a, e);
self.barrett.mul_external_assign::<REDUCE>(*d, e);
}
#[inline(always)]
fn sb_sub_sa_add_sc_mul_sd_into_sa<const SBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
b: &u64,
c: &u64,
d: &Barrett<u64>,
a: &mut u64
) {
self.sa_sub_sb_into_sb::<SBRANGE, NONE>(&(b + c), a);
self.barrett.mul_external_assign::<REDUCE>(*d, a);
}
}
impl VectorOperations<u64> for Prime<u64> {
@@ -222,6 +258,15 @@ impl VectorOperations<u64> for Prime<u64> {
apply_vvv!(self, Self::sa_sub_sb_into_sc::<VBRANGE, REDUCE>, a, b, c, CHUNK);
}
#[inline(always)]
fn va_sub_vb_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
b: &[u64],
a: &mut [u64],
) {
apply_vv!(self, Self::sa_sub_sb_into_sa::<VBRANGE, REDUCE>, b, a, CHUNK);
}
#[inline(always)]
fn va_sub_vb_into_vb<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
@@ -368,4 +413,45 @@ impl VectorOperations<u64> for Prime<u64> {
CHUNK
);
}
// vec(e) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
fn vb_sub_va_add_sc_mul_sd_into_ve<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
va: &[u64],
vb: &[u64],
sc: &u64,
sd: &Barrett<u64>,
ve: &mut [u64],
){
apply_vvssv!(
self,
Self::sb_sub_sa_add_sc_mul_sd_into_se::<VBRANGE, REDUCE>,
va,
vb,
sc,
sd,
ve,
CHUNK
);
}
// vec(a) <- (vec(b) - vec(a) + scalar(c)) * scalar(d).
fn vb_sub_va_add_sc_mul_sd_into_va<const CHUNK: usize, const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
vb: &[u64],
sc: &u64,
sd: &Barrett<u64>,
va: &mut [u64],
){
apply_vssv!(
self,
Self::sb_sub_sa_add_sc_mul_sd_into_sa::<VBRANGE, REDUCE>,
vb,
sc,
sd,
va,
CHUNK
);
}
}

math/src/num_bigint.rs Normal file
View File

@@ -0,0 +1,34 @@
use num_bigint::BigInt;
use num_bigint::Sign;
use num_integer::Integer;
use num_traits::{Zero, One, Signed};
pub trait Div {
    /// Floor division: rounds the quotient toward negative infinity.
    fn div_floor(&self, other: &Self) -> Self;
    /// Rounded division: rounds the quotient to the nearest integer.
    fn div_round(&self, other: &Self) -> Self;
}
impl Div for BigInt {
    fn div_floor(&self, other: &Self) -> Self {
        let (quo, rem) = self.div_rem(other);
        // div_rem truncates toward zero; adjust toward -infinity when the
        // remainder is non-zero and the operands have opposite signs.
        if !rem.is_zero() && (self.sign() == Sign::Minus) != (other.sign() == Sign::Minus) {
            return quo - BigInt::one();
        }
        quo
    }
    fn div_round(&self, other: &Self) -> Self {
        let (quo, mut rem) = self.div_rem(other);
        rem <<= 1;
        // Adjust by one when |2*rem| > |other|, i.e. the fractional part exceeds 1/2.
        if rem != BigInt::zero() && rem.abs() > other.abs() {
            if self.sign() == other.sign() {
                return quo + BigInt::one();
            }
            return quo - BigInt::one();
        }
        quo
    }
}
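A quick sanity check of the two rounding modes (standalone sketch; assumes the crate is named math, as in the tests below):
use math::num_bigint::Div;
use num_bigint::BigInt;
fn main() {
    let q = BigInt::from(7);
    // Truncating division gives -1 for -10/7; floor and round differ from it:
    assert_eq!(BigInt::from(-10).div_floor(&q), BigInt::from(-2)); // floor(-1.43) = -2
    assert_eq!(BigInt::from(-10).div_round(&q), BigInt::from(-1)); // round(-1.43) = -1
    assert_eq!(BigInt::from(10).div_floor(&q), BigInt::from(1));
    assert_eq!(BigInt::from(10).div_round(&q), BigInt::from(1));   // round(1.43) = 1
}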

View File

@@ -8,7 +8,7 @@ extern crate test;
impl RingRNS<u64> {
/// Updates b to floor(a / q[self.level()]) when ROUND is false, or to round(a / q[self.level()]) when ROUND is true.
pub fn div_floor_by_last_modulus<const NTT: bool>(
pub fn div_by_last_modulus<const ROUND: bool, const NTT: bool>(
&self,
a: &PolyRNS<u64>,
buf: &mut PolyRNS<u64>,
@@ -30,34 +30,76 @@ impl RingRNS<u64> {
let level = self.level();
let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
let r_last: &Ring<u64> = &self.0[level];
if ROUND{
if NTT {
let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
self.0[level].intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
&buf_ntt_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
let q_level_half: u64 = r_last.modulus.q >> 1;
let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
if NTT {
r_last.intt::<false>(a.at(level), &mut buf_q_scaling[0]);
r_last.a_add_b_scalar_into_a::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.a_add_b_scalar_into_c::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.ntt_inplace::<true>(&mut buf_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
&buf_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
} else {
r_last.a_add_b_scalar_into_c::<ONCE>(a.at(self.level()), &q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.a_add_b_scalar_into_c::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
&buf_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
}
} else {
for (i, r) in self.0[0..level].iter().enumerate() {
r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
a.at(level),
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}else{
if NTT {
let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
self.0[level].intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
&buf_ntt_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
} else {
for (i, r) in self.0[0..level].iter().enumerate() {
r.a_sub_b_mul_c_scalar_barrett_into_d::<2, ONCE>(
a.at(level),
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
}
}
}
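The ROUND branch relies on the identity round(a / q_last) = floor((a + h) / q_last) with h = q_last >> 1, which per RNS residue becomes (a_i + h - ((a_last + h) mod q_last)) * q_last^{-1} mod q_i. A standalone brute-force check of this identity with toy moduli (not the library's API):
// Exhaustively checks the per-residue rounded-rescaling identity for two small coprime moduli.
fn modinv(a: u64, m: u64) -> u64 {
    (1..m).find(|&x| (a as u128 * x as u128) % m as u128 == 1).unwrap()
}
fn main() {
    let (qi, q_last) = (97u64, 193u64);
    let h = q_last >> 1;
    let inv = modinv(q_last % qi, qi);
    for a in 0..qi * q_last {
        let expected = ((a + h) / q_last) % qi; // round(a / q_last) mod qi
        let (ai, al) = (a % qi, a % q_last);
        let got = ((ai + h + qi - (al + h) % q_last % qi) % qi * inv) % qi;
        assert_eq!(got, expected);
    }
}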
/// Updates a to floor(a / q[self.level()]) when ROUND is false, or to round(a / q[self.level()]) when ROUND is true.
/// Expects a to be in the NTT domain when NTT is true.
pub fn div_floor_by_last_modulus_inplace<const NTT: bool>(
pub fn div_by_last_modulus_inplace<const ROUND: bool, const NTT: bool>(
&self,
buf: &mut PolyRNS<u64>,
a: &mut PolyRNS<u64>,
@@ -71,32 +113,70 @@ impl RingRNS<u64> {
let level = self.level();
let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
let r_last: &Ring<u64> = &self.0[level];
if NTT {
let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
self.0[level].intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
&buf_ntt_qi_scaling[0],
&rescaling_constants.0[i],
a.at_mut(i),
);
if ROUND{
let q_level_half: u64 = r_last.modulus.q >> 1;
let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
if NTT {
r_last.intt::<false>(a.at(level), &mut buf_q_scaling[0]);
r_last.a_add_b_scalar_into_a::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.a_add_b_scalar_into_c::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.ntt_inplace::<false>(&mut buf_qi_scaling[0]);
r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
&buf_qi_scaling[0],
&rescaling_constants.0[i],
a.at_mut(i),
);
}
} else {
let (a_qi, a_q_last) = a.0.split_at_mut(self.level());
r_last.a_add_b_scalar_into_a::<ONCE>(&q_level_half, &mut a_q_last[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::<1, ONCE>(
&a_q_last[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&rescaling_constants.0[i],
&mut a_qi[i],
);
}
}
} else {
let (a_i, a_level) = a.0.split_at_mut(level);
for (i, r) in self.0[0..level].iter().enumerate() {
r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
&a_level[0],
&rescaling_constants.0[i],
&mut a_i[i],
);
}else{
if NTT {
let (buf_ntt_q_scaling, buf_ntt_qi_scaling) = buf.0.split_at_mut(1);
r_last.intt::<false>(a.at(level), &mut buf_ntt_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r.ntt::<true>(&buf_ntt_q_scaling[0], &mut buf_ntt_qi_scaling[0]);
r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
&buf_ntt_qi_scaling[0],
&rescaling_constants.0[i],
a.at_mut(i),
);
}
}else{
let (a_i, a_level) = a.0.split_at_mut(level);
for (i, r) in self.0[0..level].iter().enumerate() {
r.b_sub_a_mul_c_scalar_barrett_into_a::<2, ONCE>(
&a_level[0],
&rescaling_constants.0[i],
&mut a_i[i],
);
}
}
}
}
/// Updates c to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]) when ROUND is false, or to the rounded quotient when ROUND is true.
pub fn div_floor_by_last_moduli<const NTT: bool>(
pub fn div_by_last_moduli<const ROUND: bool, const NTT: bool>(
&self,
nb_moduli: usize,
a: &PolyRNS<u64>,
@@ -133,38 +213,35 @@ impl RingRNS<u64> {
c.copy(a);
}
} else {
if NTT {
if NTT{
self.intt::<false>(a, buf);
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_floor_by_last_modulus_inplace::<false>(
.div_by_last_modulus_inplace::<ROUND, false>(
&mut PolyRNS::<u64>::default(),
buf,
)
});
self.at_level(self.level() - nb_moduli).ntt::<false>(buf, c);
} else {
let empty_buf: &mut PolyRNS<u64> = &mut PolyRNS::<u64>::default();
if nb_moduli == 1{
self.div_floor_by_last_modulus::<false>(a, empty_buf, c);
}else{
self.div_floor_by_last_modulus::<false>(a, empty_buf, buf);
}
}else{
(1..nb_moduli-1).for_each(|i| {
self.at_level(self.level() - i)
.div_floor_by_last_modulus_inplace::<false>(empty_buf, buf);
});
self.div_by_last_modulus::<ROUND, false>(a, buf, c);
self.at_level(self.level()-nb_moduli+1).div_floor_by_last_modulus::<false>(buf, empty_buf, c);
(1..nb_moduli-1).for_each(|i| {
self.at_level(self.level() - i)
.div_by_last_modulus_inplace::<ROUND, false>(buf, c);
});
self.at_level(self.level()-nb_moduli+1).div_by_last_modulus_inplace::<ROUND, false>(buf, c);
}
}
}
/// Updates a to floor(a / prod_{i = level - nb_moduli + 1}^{level} q[i]) when ROUND is false, or to the rounded quotient when ROUND is true.
pub fn div_floor_by_last_moduli_inplace<const NTT: bool>(
pub fn div_by_last_moduli_inplace<const ROUND:bool, const NTT: bool>(
&self,
nb_moduli: usize,
buf: &mut PolyRNS<u64>,
@@ -185,218 +262,18 @@ impl RingRNS<u64> {
if nb_moduli == 0{
return
}
if NTT {
self.intt::<false>(a, buf);
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_floor_by_last_modulus_inplace::<false>(&mut PolyRNS::<u64>::default(), buf)
});
self.at_level(self.level() - nb_moduli+1).ntt::<false>(buf, a);
} else {
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_floor_by_last_modulus_inplace::<false>(buf, a);
});
}
}
/// Updates b to round(a / q[b.level()]).
/// Expects b to be in the NTT domain.
pub fn div_round_by_last_modulus<const NTT: bool>(
&self,
a: &PolyRNS<u64>,
buf: &mut PolyRNS<u64>,
b: &mut PolyRNS<u64>,
) {
debug_assert!(
self.level() <= a.level(),
"invalid input a: self.level()={} > a.level()={}",
self.level(),
a.level()
);
debug_assert!(
b.level() >= a.level() - 1,
"invalid input b: b.level()={} < a.level()-1={}",
b.level(),
a.level() - 1
);
let level: usize = self.level();
let r_last: &Ring<u64> = &self.0[level];
let q_level_half: u64 = r_last.modulus.q >> 1;
let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
if NTT {
r_last.intt::<false>(a.at(level), &mut buf_q_scaling[0]);
r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.add_scalar::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.ntt_inplace::<true>(&mut buf_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
&buf_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
} else {
r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.add_scalar::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.a_sub_b_mul_c_scalar_barrett::<2, ONCE>(
&buf_qi_scaling[0],
a.at(i),
&rescaling_constants.0[i],
b.at_mut(i),
);
}
}
}
/// Updates a to round(a / q[b.level()]).
/// Expects a to be in the NTT domain.
pub fn div_round_by_last_modulus_inplace<const NTT: bool>(
&self,
buf: &mut PolyRNS<u64>,
a: &mut PolyRNS<u64>,
) {
debug_assert!(
self.level() <= a.level(),
"invalid input a: self.level()={} > a.level()={}",
self.level(),
a.level()
);
let level = self.level();
let r_last: &Ring<u64> = &self.0[level];
let q_level_half: u64 = r_last.modulus.q >> 1;
let rescaling_constants: ScalarRNS<Barrett<u64>> = self.rescaling_constant();
let (buf_q_scaling, buf_qi_scaling) = buf.0.split_at_mut(1);
if NTT {
r_last.intt::<false>(a.at(level), &mut buf_q_scaling[0]);
r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.add_scalar::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.ntt_inplace::<false>(&mut buf_qi_scaling[0]);
r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
&buf_qi_scaling[0],
&rescaling_constants.0[i],
a.at_mut(i),
);
}
} else {
r_last.add_scalar_inplace::<ONCE>(&q_level_half, &mut buf_q_scaling[0]);
for (i, r) in self.0[0..level].iter().enumerate() {
r_last.add_scalar::<NONE>(
&buf_q_scaling[0],
&(r.modulus.q - r_last.modulus.barrett.reduce::<BARRETT>(&q_level_half)),
&mut buf_qi_scaling[0],
);
r.a_sub_b_mul_c_scalar_barrett_inplace::<2, ONCE>(
&buf_qi_scaling[0],
&rescaling_constants.0[i],
a.at_mut(i),
);
}
}
}
/// Updates b to round(a / prod_{level - nb_moduli}^{level} q[i])
pub fn div_round_by_last_moduli<const NTT: bool>(
&self,
nb_moduli: usize,
a: &PolyRNS<u64>,
buf: &mut PolyRNS<u64>,
c: &mut PolyRNS<u64>,
) {
debug_assert!(
self.level() <= a.level(),
"invalid input a: self.level()={} > a.level()={}",
self.level(),
a.level()
);
debug_assert!(
c.level() >= a.level() - 1,
"invalid input b: b.level()={} < a.level()-1={}",
c.level(),
a.level() - 1
);
debug_assert!(
nb_moduli <= a.level(),
"invalid input nb_moduli: nb_moduli={} > a.level()={}",
nb_moduli,
a.level()
);
if nb_moduli == 0 {
if a != c {
c.copy(a);
}
} else {
if NTT {
self.intt::<false>(a, buf);
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_round_by_last_modulus_inplace::<false>(
&mut PolyRNS::<u64>::default(),
buf,
)
});
self.at_level(self.level() - nb_moduli).ntt::<false>(buf, c);
} else {
self.div_round_by_last_modulus::<false>(a, buf, c);
(1..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_round_by_last_modulus_inplace::<false>(buf, c)
});
}
}
}
/// Updates a to round(a / prod_{level - nb_moduli}^{level} q[i])
pub fn div_round_by_last_moduli_inplace<const NTT: bool>(
&self,
nb_moduli: usize,
buf: &mut PolyRNS<u64>,
a: &mut PolyRNS<u64>,
) {
debug_assert!(
self.level() <= a.level(),
"invalid input a: self.level()={} > a.level()={}",
self.level(),
a.level()
);
debug_assert!(
nb_moduli <= a.level(),
"invalid input nb_moduli: nb_moduli={} > a.level()={}",
nb_moduli,
a.level()
);
if NTT {
self.intt::<false>(a, buf);
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_round_by_last_modulus_inplace::<false>(&mut PolyRNS::<u64>::default(), buf)
.div_by_last_modulus_inplace::<ROUND, false>(&mut PolyRNS::<u64>::default(), buf)
});
self.at_level(self.level() - nb_moduli).ntt::<false>(buf, a);
} else {
(0..nb_moduli).for_each(|i| {
self.at_level(self.level() - i)
.div_round_by_last_modulus_inplace::<false>(buf, a)
.div_by_last_modulus_inplace::<ROUND, false>(buf, a)
});
}
}

View File

@@ -75,7 +75,7 @@ impl Ring<u64> {
impl Ring<u64> {
#[inline(always)]
pub fn add_inplace<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
pub fn a_add_b_into_b<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
@@ -83,7 +83,7 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn add<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
pub fn a_add_b_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
@@ -92,13 +92,13 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn add_scalar_inplace<const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut Poly<u64>) {
pub fn a_add_b_scalar_into_a<const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
self.modulus.va_add_sb_into_va::<CHUNK, REDUCE>(b, &mut a.0);
}
#[inline(always)]
pub fn add_scalar<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
pub fn a_add_b_scalar_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
self.modulus
@@ -106,7 +106,7 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn add_scalar_then_mul_scalar_barrett_inplace<const REDUCE: REDUCEMOD>(&self, b: &u64, c: &Barrett<u64>, a: &mut Poly<u64>) {
pub fn a_add_scalar_b_mul_c_scalar_barrett_into_a<const REDUCE: REDUCEMOD>(&self, b: &u64, c: &Barrett<u64>, a: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
self.modulus.va_add_sb_mul_sc_into_va::<CHUNK, REDUCE>(b, c, &mut a.0);
}
@@ -120,7 +120,7 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn sub_inplace<const BRANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
pub fn a_sub_b_into_b<const BRANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
@@ -128,7 +128,15 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn sub<const BRANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
pub fn a_sub_b_into_a<const BRANGE:u8, const REDUCE: REDUCEMOD>(&self, b: &Poly<u64>, a: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
.va_sub_vb_into_va::<CHUNK, BRANGE, REDUCE>(&b.0, &mut a.0);
}
#[inline(always)]
pub fn a_sub_b_into_c<const BRANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &Poly<u64>, c: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
@@ -137,20 +145,20 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn neg<const ARANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
pub fn a_neg_into_b<const ARANGE:u8, const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus.va_neg_into_vb::<CHUNK, ARANGE, REDUCE>(&a.0, &mut b.0);
}
#[inline(always)]
pub fn neg_inplace<const ARANGE:u8,const REDUCE: REDUCEMOD>(&self, a: &mut Poly<u64>) {
pub fn a_neg_into_a<const ARANGE:u8,const REDUCE: REDUCEMOD>(&self, a: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
self.modulus.va_neg_into_va::<CHUNK, ARANGE, REDUCE>(&mut a.0);
}
#[inline(always)]
pub fn mul_montgomery_external<const REDUCE: REDUCEMOD>(
pub fn a_mul_b_montgomery_into_c<const REDUCE: REDUCEMOD>(
&self,
a: &Poly<Montgomery<u64>>,
b: &Poly<u64>,
@@ -164,20 +172,20 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn mul_montgomery_external_inplace<const REDUCE: REDUCEMOD>(
pub fn a_mul_b_montgomery_into_a<const REDUCE: REDUCEMOD>(
&self,
a: &Poly<Montgomery<u64>>,
b: &mut Poly<u64>,
b: &Poly<Montgomery<u64>>,
a: &mut Poly<u64>,
) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
.va_mont_mul_vb_into_vb::<CHUNK, REDUCE>(&a.0, &mut b.0);
.va_mont_mul_vb_into_vb::<CHUNK, REDUCE>(&b.0, &mut a.0);
}
#[inline(always)]
pub fn mul_scalar<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "b.n()={} != n={}", a.n(), self.n());
pub fn a_mul_b_scalar_into_c<const REDUCE: REDUCEMOD>(&self, a: &Poly<u64>, b: &u64, c: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(c.n() == self.n(), "c.n()={} != n={}", c.n(), self.n());
self.modulus.sa_barrett_mul_vb_into_vc::<CHUNK, REDUCE>(
&self.modulus.barrett.prepare(*b),
@@ -187,30 +195,30 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn mul_scalar_inplace<const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut Poly<u64>) {
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
pub fn a_mul_b_scalar_into_a<const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut Poly<u64>) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
self.modulus.sa_barrett_mul_vb_into_vb::<CHUNK, REDUCE>(
&self
.modulus
.barrett
.prepare(self.modulus.barrett.reduce::<BARRETT>(a)),
&mut b.0,
.prepare(self.modulus.barrett.reduce::<BARRETT>(b)),
&mut a.0,
);
}
#[inline(always)]
pub fn mul_scalar_barrett_inplace<const REDUCE: REDUCEMOD>(
pub fn a_mul_b_scalar_barrett_into_a<const REDUCE: REDUCEMOD>(
&self,
a: &Barrett<u64>,
b: &mut Poly<u64>,
b: &Barrett<u64>,
a: &mut Poly<u64>,
) {
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
self.modulus
.sa_barrett_mul_vb_into_vb::<CHUNK, REDUCE>(a, &mut b.0);
.sa_barrett_mul_vb_into_vb::<CHUNK, REDUCE>(b, &mut a.0);
}
#[inline(always)]
pub fn mul_scalar_barrett<const REDUCE: REDUCEMOD>(
pub fn a_mul_b_scalar_barrett_into_c<const REDUCE: REDUCEMOD>(
&self,
a: &Barrett<u64>,
b: &Poly<u64>,
@@ -222,7 +230,7 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn a_sub_b_mul_c_scalar_barrett<const VBRANGE: u8, const REDUCE: REDUCEMOD>(
pub fn a_sub_b_mul_c_scalar_barrett_into_d<const VBRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
a: &Poly<u64>,
b: &Poly<u64>,
@@ -237,15 +245,46 @@ impl Ring<u64> {
}
#[inline(always)]
pub fn a_sub_b_mul_c_scalar_barrett_inplace<const BRANGE: u8, const REDUCE: REDUCEMOD>(
pub fn b_sub_a_mul_c_scalar_barrett_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
a: &Poly<u64>,
b: &Poly<u64>,
c: &Barrett<u64>,
b: &mut Poly<u64>,
a: &mut Poly<u64>,
) {
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
.va_sub_vb_mul_sc_into_vb::<CHUNK, BRANGE, REDUCE>(&a.0, c, &mut b.0);
.va_sub_vb_mul_sc_into_vb::<CHUNK, BRANGE, REDUCE>(&b.0, c, &mut a.0);
}
#[inline(always)]
pub fn a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e<const BRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
a: &Poly<u64>,
b: &Poly<u64>,
c: &u64,
d: &Barrett<u64>,
e: &mut Poly<u64>,
){
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
debug_assert!(e.n() == self.n(), "e.n()={} != n={}", e.n(), self.n());
self.modulus
.vb_sub_va_add_sc_mul_sd_into_ve::<CHUNK, BRANGE, REDUCE>(&a.0, &b.0, c, d, &mut e.0);
}
#[inline(always)]
pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
b: &Poly<u64>,
c: &u64,
d: &Barrett<u64>,
a: &mut Poly<u64>,
){
debug_assert!(a.n() == self.n(), "a.n()={} != n={}", a.n(), self.n());
debug_assert!(b.n() == self.n(), "b.n()={} != n={}", b.n(), self.n());
self.modulus
.vb_sub_va_add_sc_mul_sd_into_va::<CHUNK, BRANGE, REDUCE>(&b.0, c, d, &mut a.0);
}
}

View File

@@ -7,6 +7,8 @@ use crate::scalar::ScalarRNS;
use num_bigint::BigInt;
use std::sync::Arc;
impl RingRNS<u64> {
pub fn new(n: usize, moduli: Vec<u64>) -> Self {
assert!(!moduli.is_empty(), "moduli cannot be empty");
@@ -121,7 +123,7 @@ impl RingRNS<u64> {
impl RingRNS<u64> {
#[inline(always)]
pub fn add<const REDUCE: REDUCEMOD>(
pub fn a_add_b_into_c<const REDUCE: REDUCEMOD>(
&self,
a: &PolyRNS<u64>,
b: &PolyRNS<u64>,
@@ -148,11 +150,11 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.add::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
.for_each(|(i, ring)| ring.a_add_b_into_c::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
}
#[inline(always)]
pub fn add_inplace<const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
pub fn a_add_b_into_b<const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
@@ -168,11 +170,11 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.add_inplace::<REDUCE>(&a.0[i], &mut b.0[i]));
.for_each(|(i, ring)| ring.a_add_b_into_b::<REDUCE>(&a.0[i], &mut b.0[i]));
}
#[inline(always)]
pub fn sub<const BRANGE: u8, const REDUCE: REDUCEMOD>(
pub fn a_sub_b_into_c<const BRANGE: u8, const REDUCE: REDUCEMOD>(
&self,
a: &PolyRNS<u64>,
b: &PolyRNS<u64>,
@@ -199,11 +201,11 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.sub::<BRANGE, REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
.for_each(|(i, ring)| ring.a_sub_b_into_c::<BRANGE, REDUCE>(&a.0[i], &b.0[i], &mut c.0[i]));
}
#[inline(always)]
pub fn sub_inplace<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
pub fn a_sub_b_into_b<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
@@ -219,11 +221,11 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.sub_inplace::<BRANGE, REDUCE>(&a.0[i], &mut b.0[i]));
.for_each(|(i, ring)| ring.a_sub_b_into_b::<BRANGE, REDUCE>(&a.0[i], &mut b.0[i]));
}
#[inline(always)]
pub fn neg<const ARANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
pub fn a_sub_b_into_a<const BRANGE: u8, const REDUCE: REDUCEMOD>(&self, b: &PolyRNS<u64>, a: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
@@ -239,11 +241,31 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.neg::<ARANGE, REDUCE>(&a.0[i], &mut b.0[i]));
.for_each(|(i, ring)| ring.a_sub_b_into_a::<BRANGE, REDUCE>(&b.0[i], &mut a.0[i]));
}
#[inline(always)]
pub fn neg_inplace<const ARANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut PolyRNS<u64>) {
pub fn a_neg_into_b<const ARANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
a.level(),
self.level()
);
debug_assert!(
b.level() >= self.level(),
"b.level()={} < self.level()={}",
b.level(),
self.level()
);
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.a_neg_into_b::<ARANGE, REDUCE>(&a.0[i], &mut b.0[i]));
}
#[inline(always)]
pub fn a_neg_into_a<const ARANGE: u8, const REDUCE: REDUCEMOD>(&self, a: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
@@ -253,7 +275,7 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.neg_inplace::<ARANGE, REDUCE>(&mut a.0[i]));
.for_each(|(i, ring)| ring.a_neg_into_a::<ARANGE, REDUCE>(&mut a.0[i]));
}
#[inline(always)]
@@ -282,7 +304,7 @@ impl RingRNS<u64> {
self.level()
);
self.0.iter().enumerate().for_each(|(i, ring)| {
ring.mul_montgomery_external::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i])
ring.a_mul_b_montgomery_into_c::<REDUCE>(&a.0[i], &b.0[i], &mut c.0[i])
});
}
@@ -305,7 +327,7 @@ impl RingRNS<u64> {
self.level()
);
self.0.iter().enumerate().for_each(|(i, ring)| {
ring.mul_montgomery_external_inplace::<REDUCE>(&a.0[i], &mut b.0[i])
ring.a_mul_b_montgomery_into_a::<REDUCE>(&a.0[i], &mut b.0[i])
});
}
@@ -331,11 +353,57 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.mul_scalar::<REDUCE>(&a.0[i], b, &mut c.0[i]));
.for_each(|(i, ring)| ring.a_mul_b_scalar_into_c::<REDUCE>(&a.0[i], b, &mut c.0[i]));
}
#[inline(always)]
pub fn mul_scalar_inplace<const REDUCE: REDUCEMOD>(&self, a: &u64, b: &mut PolyRNS<u64>) {
pub fn mul_scalar_inplace<const REDUCE: REDUCEMOD>(&self, b: &u64, a: &mut PolyRNS<u64>) {
debug_assert!(
a.level() >= self.level(),
"b.level()={} < self.level()={}",
a.level(),
self.level()
);
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.a_mul_b_scalar_into_a::<REDUCE>(b, &mut a.0[i]));
}
#[inline(always)]
pub fn a_sub_b_add_scalar_mul_scalar_barrett_into_e<const BRANGE:u8, const REDUCE:REDUCEMOD>(&self, a: &PolyRNS<u64>, b: &PolyRNS<u64>, c: &u64, d: &Barrett<u64>, e: &mut PolyRNS<u64>){
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
a.level(),
self.level()
);
debug_assert!(
b.level() >= self.level(),
"b.level()={} < self.level()={}",
b.level(),
self.level()
);
debug_assert!(
e.level() >= self.level(),
"e.level()={} < self.level()={}",
e.level(),
self.level()
);
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.a_sub_b_add_c_scalar_mul_d_scalar_barrett_into_e::<BRANGE, REDUCE>(&a.0[i], &b.0[i], c, d, &mut e.0[i]));
}
#[inline(always)]
pub fn b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a<const BRANGE:u8, const REDUCE:REDUCEMOD>(&self, b: &PolyRNS<u64>, c: &u64, d: &Barrett<u64>, a: &mut PolyRNS<u64>){
debug_assert!(
a.level() >= self.level(),
"a.level()={} < self.level()={}",
a.level(),
self.level()
);
debug_assert!(
b.level() >= self.level(),
"b.level()={} < self.level()={}",
@@ -345,6 +413,6 @@ impl RingRNS<u64> {
self.0
.iter()
.enumerate()
.for_each(|(i, ring)| ring.mul_scalar_inplace::<REDUCE>(a, &mut b.0[i]));
.for_each(|(i, ring)| ring.b_sub_a_add_c_scalar_mul_d_scalar_barrett_into_a::<BRANGE, REDUCE>(&b.0[i], c, d, &mut a.0[i]));
}
}

View File

@@ -1,6 +1,8 @@
use crate::modulus::WordOps;
use crate::poly::{Poly, PolyRNS};
use crate::ring::{Ring, RingRNS};
use num::ToPrimitive;
use rand_distr::{Normal, Distribution};
use sampling::source::Source;
impl Ring<u64> {
@@ -10,6 +12,24 @@ impl Ring<u64> {
a.0.iter_mut()
.for_each(|a| *a = source.next_u64n(max, mask));
}
pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut Poly<u64>) {
let max: u64 = self.modulus.q;
a.0.iter_mut()
.for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound{
dist_f64 = dist.sample(source)
}
// Round the magnitude, then map the signed sample into [0, q):
// x for x >= 0, q - |x| for x < 0 (with -0 kept as 0).
let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
*a = if dist_f64.is_sign_negative() && dist_u64 != 0 {
    max - dist_u64
} else {
    dist_u64
}
});
}
}
impl RingRNS<u64> {
@@ -19,4 +39,21 @@ impl RingRNS<u64> {
.enumerate()
.for_each(|(i, r)| r.fill_uniform(source, a.at_mut(i)));
}
pub fn fill_dist_f64<T: Distribution<f64>>(&self, source: &mut Source, dist: T, bound: f64, a: &mut PolyRNS<u64>) {
(0..a.n()).for_each(|j|{
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound{
dist_f64 = dist.sample(source)
}
// Round the magnitude once, then store the representative of the signed
// sample in each residue ring: x for x >= 0, q_i - |x| for x < 0.
let dist_u64: u64 = (dist_f64.abs() + 0.5).to_u64().unwrap();
let negative: bool = dist_f64.is_sign_negative() && dist_u64 != 0;
self.0.iter().enumerate().for_each(|(i, r)|{
    a.at_mut(i).0[j] = if negative { r.modulus.q - dist_u64 } else { dist_u64 };
})
})
}
}
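A standalone sketch of the signed-sample-to-Z_q mapping used by fill_dist_f64 above (toy modulus, hypothetical helper name): the magnitude is rounded to the nearest integer and a negative sample x is stored as q - |x|, the representative of x in [0, q):
// Maps a rounded signed sample into Z_q: x >= 0 stays |x|, x < 0 becomes q - |x|.
fn signed_to_zq(x: f64, q: u64) -> u64 {
    let mag = (x.abs() + 0.5) as u64; // round-half-up on the magnitude
    if x.is_sign_negative() && mag != 0 { q - mag } else { mag }
}
fn main() {
    let q = 97u64;
    assert_eq!(signed_to_zq(2.3, q), 2);
    assert_eq!(signed_to_zq(-2.3, q), 95); // -2 mod 97
    assert_eq!(signed_to_zq(-0.2, q), 0);
}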

View File

@@ -1,8 +1,9 @@
use math::poly::PolyRNS;
use math::ring::RingRNS;
use num_bigint::BigInt;
use num_bigint::Sign;
use math::num_bigint::Div;
use sampling::source::Source;
use itertools::izip;
#[test]
fn rescaling_rns_u64() {
@@ -10,17 +11,31 @@ fn rescaling_rns_u64() {
let moduli: Vec<u64> = vec![0x1fffffffffc80001u64, 0x1fffffffffe00001u64, 0x1fffffffffb40001, 0x1fffffffff500001];
let ring_rns: RingRNS<u64> = RingRNS::new(n, moduli);
test_div_floor_by_last_modulus::<false>(&ring_rns);
test_div_floor_by_last_modulus::<true>(&ring_rns);
test_div_floor_by_last_modulus_inplace::<false>(&ring_rns);
test_div_floor_by_last_modulus_inplace::<true>(&ring_rns);
test_div_floor_by_last_moduli::<false>(&ring_rns);
test_div_floor_by_last_moduli::<true>(&ring_rns);
test_div_floor_by_last_moduli_inplace::<false>(&ring_rns);
test_div_floor_by_last_moduli_inplace::<true>(&ring_rns);
sub_test("test_div_by_last_modulus::<ROUND:false, NTT:false>", ||{test_div_by_last_modulus::<false, false>(&ring_rns)});
sub_test("test_div_by_last_modulus::<ROUND:false, NTT:true>", ||{test_div_by_last_modulus::<false, true>(&ring_rns)});
sub_test("test_div_by_last_modulus::<ROUND:true, NTT:false>", ||{test_div_by_last_modulus::<true, false>(&ring_rns)});
sub_test("test_div_by_last_modulus::<ROUND:true, NTT:true>", ||{test_div_by_last_modulus::<true, true>(&ring_rns)});
sub_test("test_div_by_last_modulus_inplace::<ROUND:false, NTT:false>", ||{test_div_by_last_modulus_inplace::<false, false>(&ring_rns)});
sub_test("test_div_by_last_modulus_inplace::<ROUND:false, NTT:true>", ||{test_div_by_last_modulus_inplace::<false, true>(&ring_rns)});
sub_test("test_div_by_last_modulus_inplace::<ROUND:true, NTT:true>", ||{test_div_by_last_modulus_inplace::<true, true>(&ring_rns)});
sub_test("test_div_by_last_modulus_inplace::<ROUND:true, NTT:false>", ||{test_div_by_last_modulus_inplace::<true, false>(&ring_rns)});
//sub_test("test_div_by_last_moduli::<ROUND:false, NTT:false>", ||{test_div_by_last_moduli::<false, false>(&ring_rns)});
}
fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
fn sub_test<F: FnOnce()>(name: &str, f: F) {
println!("Running {}", name);
f();
}
fn test_div_by_last_modulus<const ROUND:bool, const NTT:bool>(ring_rns: &RingRNS<u64>){
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
@@ -42,7 +57,8 @@ fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
ring_rns.ntt_inplace::<false>(&mut a);
}
ring_rns.div_floor_by_last_modulus::<NTT>(&a, &mut b, &mut c);
ring_rns.div_by_last_modulus::<ROUND,NTT>(&a, &mut b, &mut c);
if NTT {
ring_rns.at_level(c.level()).intt_inplace::<false>(&mut c);
@@ -57,22 +73,23 @@ fn test_div_floor_by_last_modulus<const NTT: bool>(ring_rns: &RingRNS<u64>) {
// Performs floor or rounded division on a, depending on ROUND
let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q);
coeffs_a.iter_mut().for_each(|a| {
// Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
*a /= &scalar_big;
if a.sign() == Sign::Minus {
*a -= 1;
}
if ROUND{
*a = a.div_round(&scalar_big);
}else{
*a = a.div_floor(&scalar_big);
}
});
assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_modulus");
izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
}
fn test_div_floor_by_last_modulus_inplace<const NTT: bool>(ring_rns: &RingRNS<u64>) {
fn test_div_by_last_modulus_inplace<const ROUND:bool, const NTT:bool>(ring_rns: &RingRNS<u64>) {
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
let mut a: PolyRNS<u64> = ring_rns.new_polyrns();
let mut b: PolyRNS<u64> = ring_rns.new_polyrns();
let mut buf: PolyRNS<u64> = ring_rns.new_polyrns();
// Allocates a random PolyRNS
ring_rns.fill_uniform(&mut source, &mut a);
@@ -88,7 +105,7 @@ fn test_div_floor_by_last_modulus_inplace<const NTT: bool>(ring_rns: &RingRNS<u6
ring_rns.ntt_inplace::<false>(&mut a);
}
ring_rns.div_floor_by_last_modulus_inplace::<NTT>(&mut b, &mut a);
ring_rns.div_by_last_modulus_inplace::<ROUND,NTT>(&mut buf, &mut a);
if NTT {
ring_rns.at_level(a.level()-1).intt_inplace::<false>(&mut a);
@@ -103,24 +120,26 @@ fn test_div_floor_by_last_modulus_inplace<const NTT: bool>(ring_rns: &RingRNS<u6
// Performs floor or rounded division on a, depending on ROUND
let scalar_big = BigInt::from(ring_rns.0[ring_rns.level()].modulus.q);
coeffs_a.iter_mut().for_each(|a| {
// Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
*a /= &scalar_big;
if a.sign() == Sign::Minus {
*a -= 1;
}
if ROUND{
*a = a.div_round(&scalar_big);
}else{
*a = a.div_floor(&scalar_big);
}
});
assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_modulus_inplace");
izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
}
fn test_div_floor_by_last_moduli<const NTT: bool>(ring_rns: &RingRNS<u64>) {
fn test_div_by_last_moduli<const ROUND:bool, const NTT:bool>(ring_rns: &RingRNS<u64>){
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
let nb_moduli: usize = ring_rns.level();
let mut a: PolyRNS<u64> = ring_rns.new_polyrns();
let mut b: PolyRNS<u64> = ring_rns.new_polyrns();
let mut buf: PolyRNS<u64> = ring_rns.new_polyrns();
let mut c: PolyRNS<u64> = ring_rns.at_level(ring_rns.level() - nb_moduli).new_polyrns();
// Allocates a random PolyRNS
@@ -137,14 +156,14 @@ fn test_div_floor_by_last_moduli<const NTT: bool>(ring_rns: &RingRNS<u64>) {
ring_rns.ntt_inplace::<false>(&mut a);
}
ring_rns.div_floor_by_last_moduli::<NTT>(nb_moduli, &a, &mut b, &mut c);
ring_rns.div_by_last_moduli::<ROUND,NTT>(nb_moduli, &a, &mut buf, &mut c);
if NTT {
ring_rns.at_level(c.level()).intt_inplace::<false>(&mut c);
}
// Exports c to coeffs_c
let mut coeffs_c = vec![BigInt::from(0); c.n()];
let mut coeffs_c = vec![BigInt::from(0); a.n()];
ring_rns
.at_level(c.level())
.to_bigint_inplace(&c, 1, &mut coeffs_c);
@@ -152,18 +171,18 @@ fn test_div_floor_by_last_moduli<const NTT: bool>(ring_rns: &RingRNS<u64>) {
// Performs floor or rounded division on a, depending on ROUND
let mut scalar_big = BigInt::from(1);
(0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)});
coeffs_a.iter_mut().for_each(|a| {
// Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
*a /= &scalar_big;
if a.sign() == Sign::Minus {
*a -= 1;
}
if ROUND{
*a = a.div_round(&scalar_big);
}else{
*a = a.div_floor(&scalar_big);
}
});
assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli");
izip!(coeffs_a, coeffs_c).for_each(|(a, b)| assert_eq!(a, b));
}
/*
fn test_div_floor_by_last_moduli_inplace<const NTT: bool>(ring_rns: &RingRNS<u64>) {
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
@@ -202,13 +221,8 @@ fn test_div_floor_by_last_moduli_inplace<const NTT: bool>(ring_rns: &RingRNS<u64
// Performs floor division on a
let mut scalar_big = BigInt::from(1);
(0..nb_moduli).for_each(|i|{scalar_big *= BigInt::from(ring_rns.0[ring_rns.level()-i].modulus.q)});
coeffs_a.iter_mut().for_each(|a| {
// Emulates floor division in [0, q-1] and maps to [-(q-1)/2, (q-1)/2-1]
*a /= &scalar_big;
if a.sign() == Sign::Minus {
*a -= 1;
}
});
coeffs_a.iter_mut().for_each(|a| {a.div_floor(&scalar_big)});
assert!(coeffs_a == coeffs_c, "test_div_floor_by_last_moduli_inplace");
}
}
*/

View File

@@ -21,11 +21,6 @@ impl Source {
seed
}
#[inline(always)]
pub fn next_u64(&mut self) -> u64 {
self.source.next_u64()
}
#[inline(always)]
pub fn next_u64n(&mut self, max: u64, mask: u64) -> u64 {
let mut x: u64 = self.next_u64() & mask;
@@ -39,9 +34,26 @@ impl Source {
pub fn next_f64(&mut self, min: f64, max: f64) -> f64 {
min + ((self.next_u64() << 11 >> 11) as f64) / MAXF64 * (max - min)
}
}
impl RngCore for Source{
#[inline(always)]
fn next_u32(&mut self) -> u32 {
self.source.next_u32()
}
#[inline(always)]
pub fn fill_bytes(&mut self, bytes: &mut [u8]) {
fn next_u64(&mut self) -> u64 {
self.source.next_u64()
}
#[inline(always)]
fn fill_bytes(&mut self, bytes: &mut [u8]) {
self.source.fill_bytes(bytes)
}
}
#[inline(always)]
fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), rand_core::Error>{
self.source.try_fill_bytes(bytes)
}
}