Added base for Montgomery arithmetic

This commit is contained in:
Jean-Philippe Bossuat
2024-12-04 12:53:13 +01:00
parent a957701614
commit ee96c2f904
9 changed files with 353 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/target

16
Cargo.lock generated Normal file
View File

@@ -0,0 +1,16 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "math"
version = "0.1.0"
dependencies = [
"primality-test",
]
[[package]]
name = "primality-test"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98439e9658b9548a33abdab8c82532554dc08e49ddc5398a9262222fb360ae24"

7
Cargo.toml Normal file
View File

@@ -0,0 +1,7 @@
[package]
name = "math"
version = "0.1.0"
edition = "2021"
[dependencies]
primality-test = "0.3.0"

3
src/lib.rs Normal file
View File

@@ -0,0 +1,3 @@
#![feature(bigint_helper_methods)]
pub mod modulus;

24
src/modulus.rs Normal file
View File

@@ -0,0 +1,24 @@
pub(crate) mod prime;
pub(crate) mod montgomery;
pub(crate) mod barrett;
trait ReduceOnce<O>{
fn reduce_once_assign(&mut self, q: O);
fn reduce_once(&self, q:O) -> O;
}
impl ReduceOnce<u64> for u64{
fn reduce_once_assign(&mut self, q: u64){
if *self >= q{
*self -= q
}
}
fn reduce_once(&self, q:u64) -> u64{
if *self >= q {
*self - q
} else {
*self
}
}
}

19
src/modulus/barrett.rs Normal file
View File

@@ -0,0 +1,19 @@
pub struct BarrettPrecomp<O>(O, O);
impl<O> BarrettPrecomp<O>{
#[inline(always)]
pub fn new(a:O, b: O) -> Self{
Self(a, b)
}
#[inline(always)]
pub fn value_hi(&self) -> &O{
&self.0
}
#[inline(always)]
pub fn value_lo(&self) -> &O{
&self.1
}
}

113
src/modulus/montgomery.rs Normal file
View File

@@ -0,0 +1,113 @@
use crate::modulus::barrett::BarrettPrecomp;
use crate::modulus::ReduceOnce;
pub struct Montgomery<O>(O);
impl<O> Montgomery<O>{
#[inline(always)]
pub fn new(lhs: O) -> Self{
Self(lhs)
}
#[inline(always)]
pub fn value(&self) -> &O{
&self.0
}
pub fn value_mut(&mut self) -> &mut O{
&mut self.0
}
}
pub struct MontgomeryPrecomp<O>{
q: O,
q_barrett: BarrettPrecomp<O>,
q_inv: O,
}
impl MontgomeryPrecomp<u64>{
#[inline(always)]
fn new(&self, q: u64) -> MontgomeryPrecomp<u64>{
let mut r: u64 = 1;
let mut q_pow = q;
for _i in 0..63{
r = r.wrapping_mul(r);
q_pow = q_pow.wrapping_mul(q_pow)
}
Self{ q: q, q_barrett: BarrettPrecomp::new(q, q), q_inv: q_pow}
}
#[inline(always)]
fn prepare(&self, lhs: u64) -> Montgomery<u64>{
let mut rhs = Montgomery(0);
self.prepare_assign(lhs, &mut rhs);
rhs
}
fn prepare_assign(&self, lhs: u64, rhs: &mut Montgomery<u64>){
self.prepare_lazy_assign(lhs, rhs);
rhs.value_mut().reduce_once_assign(self.q);
}
#[inline(always)]
fn prepare_lazy(&self, lhs: u64) -> Montgomery<u64>{
let mut rhs = Montgomery(0);
self.prepare_lazy_assign(lhs, &mut rhs);
rhs
}
fn prepare_lazy_assign(&self, lhs: u64, rhs: &mut Montgomery<u64>){
let (mhi, _) = lhs.widening_mul(*self.q_barrett.value_lo());
*rhs = Montgomery((lhs.wrapping_mul(*self.q_barrett.value_hi()).wrapping_add(mhi)).wrapping_mul(self.q).wrapping_neg());
}
#[inline(always)]
fn mul_external(&self, lhs: Montgomery<u64>, rhs: u64) -> u64{
let mut r = self.mul_external_lazy(lhs, rhs);
r.reduce_once_assign(self.q);
r
}
#[inline(always)]
fn mul_external_assign(&self, lhs: Montgomery<u64>, rhs: &mut u64){
self.mul_external_lazy_assign(lhs, rhs);
rhs.reduce_once_assign(self.q);
}
#[inline(always)]
fn mul_external_lazy(&self, lhs: Montgomery<u64>, rhs: u64) -> u64{
let mut result = rhs;
self.mul_external_lazy_assign(lhs, &mut result);
result
}
#[inline(always)]
fn mul_external_lazy_assign(&self, lhs: Montgomery<u64>, rhs: &mut u64){
let (mhi, mlo) = lhs.value().widening_mul(*rhs);
let (hhi, _) = self.q.widening_mul(mlo * self.q_inv);
*rhs = mhi - hhi + self.q
}
#[inline(always)]
fn mul_internal(&self, lhs: Montgomery<u64>, rhs: Montgomery<u64>) -> Montgomery<u64>{
Montgomery(self.mul_external(lhs, *rhs.value()))
}
#[inline(always)]
fn mul_internal_assign(&self, lhs: Montgomery<u64>, rhs: &mut Montgomery<u64>){
self.mul_external_assign(lhs, rhs.value_mut());
}
#[inline(always)]
fn mul_internal_lazy(&self, lhs: Montgomery<u64>, rhs: Montgomery<u64>) -> Montgomery<u64>{
Montgomery(self.mul_external_lazy(lhs, *rhs.value()))
}
#[inline(always)]
fn mul_internal_lazy_assign(&self, lhs: Montgomery<u64>, rhs: &mut Montgomery<u64>){
self.mul_external_lazy_assign(lhs, rhs.value_mut());
}
}

19
src/modulus/prime.rs Normal file
View File

@@ -0,0 +1,19 @@
use primality_test::is_prime;
pub struct Prime {
q: u64,
}
impl Prime {
pub fn new(q: u64) -> Self{
assert!(is_prime(q) && q > 2);
Self::new_unchecked(q)
}
pub fn new_unchecked(q: u64) -> Self {
assert!(q.next_power_of_two().ilog2() <= 61);
Self {
q,
}
}
}

View File

@@ -0,0 +1,151 @@
use crate::modulus::prime;
use prime::Prime;
use primality_test::is_prime;
pub struct NTTFriendlyPrimesGenerator{
size: f64,
next_prime: u64,
prev_prime: u64,
nth_root: u64,
check_next_prime: bool,
check_prev_prime: bool,
}
impl NTTFriendlyPrimesGenerator {
pub fn new(bit_size: u64, nth_root: u64) -> Self{
let mut check_next_prime: bool = true;
let mut check_prev_prime: bool = true;
let next_prime = (1<<bit_size) + 1;
let mut prev_prime = next_prime;
if next_prime > 0xffff_ffff_ffff_ffff-nth_root{
check_next_prime = false;
}
if prev_prime < nth_root{
check_prev_prime = false
}
prev_prime -= nth_root;
Self{
size: bit_size as f64,
check_next_prime,
check_prev_prime,
nth_root,
next_prime,
prev_prime,
}
}
pub fn next_upstream_primes(&mut self, k: usize) -> Vec<Prime>{
let mut primes: Vec<Prime> = Vec::with_capacity(k);
for i in 0..k{
primes.push(self.next_upstream_prime())
}
primes
}
pub fn next_downstream_primes(&mut self, k: usize) -> Vec<Prime>{
let mut primes: Vec<Prime> = Vec::with_capacity(k);
for i in 0..k{
primes.push(self.next_downstream_prime())
}
primes
}
pub fn next_alternating_primes(&mut self, k: usize) -> Vec<Prime>{
let mut primes: Vec<Prime> = Vec::with_capacity(k);
for i in 0..k{
primes.push(self.next_alternating_prime())
}
primes
}
pub fn next_upstream_prime(&mut self) -> Prime{
loop {
if self.check_next_prime{
if (self.next_prime as f64).log2() - self.size >= 0.5 || self.next_prime > 0xffff_ffff_ffff_ffff-self.nth_root{
self.check_next_prime = false;
panic!("prime list for upstream primes is exhausted (overlap with next bit-size or prime > 2^64)");
}
}else{
if is_prime(self.next_prime) {
let prime = Prime::new_unchecked(self.next_prime);
self.next_prime += self.nth_root;
return prime
}
self.next_prime += self.nth_root;
}
}
}
pub fn next_downstream_prime(&mut self) -> Prime{
loop {
if self.size - (self.prev_prime as f64).log2() >= 0.5 || self.prev_prime < self.nth_root{
self.check_next_prime = false;
panic!("prime list for downstream primes is exhausted (overlap with previous bit-size or prime < nth_root)")
}else{
if is_prime(self.prev_prime){
let prime = Prime::new_unchecked(self.next_prime);
self.prev_prime -= self.nth_root;
return prime
}
self.prev_prime -= self.nth_root;
}
}
}
pub fn next_alternating_prime(&mut self) -> Prime{
loop {
if !(self.check_next_prime || self.check_prev_prime){
panic!("prime list for upstream and downstream prime is exhausted for the (overlap with previous/next bit-size or NthRoot > prime > 2^64)")
}
if self.check_next_prime{
if (self.next_prime as f64).log2() - self.size >= 0.5 || self.next_prime > 0xffff_ffff_ffff_ffff-self.nth_root{
self.check_next_prime = false;
}else{
if is_prime(self.next_prime){
let prime = Prime::new_unchecked(self.next_prime);
self.next_prime += self.nth_root;
return prime
}
self.next_prime += self.nth_root;
}
}
if self.check_prev_prime {
if self.size - (self.prev_prime as f64).log2() >= 0.5 || self.prev_prime < self.nth_root{
self.check_prev_prime = false;
}else{
if is_prime(self.prev_prime){
let prime = Prime::new_unchecked(self.prev_prime);
self.prev_prime -= self.nth_root;
return prime
}
self.prev_prime -= self.nth_root;
}
}
}
}
}
#[cfg(test)]
mod test {
use crate::modulus::prime_generator;
#[test]
fn prime_generation() {
let nth_root: u64 = 1<<16 ;
let mut g: prime_generator::NTTFriendlyPrimesGenerator = prime_generator::NTTFriendlyPrimesGenerator::new(30, nth_root);
let primes = g.next_alternating_primes(10);
println!("{:?}", primes);
for prime in primes.iter(){
assert!(prime.q() % nth_root == 1);
}
}
}