diff --git a/ff/arith.go b/ff/arith.go index 938c87a..86f0348 100644 --- a/ff/arith.go +++ b/ff/arith.go @@ -12,14 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by goff DO NOT EDIT +// Code generated by goff (v0.2.0) DO NOT EDIT +// Package ff contains field arithmetic operations package ff import ( "math/bits" + + "golang.org/x/sys/cpu" ) +var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 + func madd(a, b, t, u, v uint64) (uint64, uint64, uint64) { var carry uint64 hi, lo := bits.Mul64(a, b) diff --git a/ff/element.go b/ff/element.go index 60b4e6b..7f016e9 100644 --- a/ff/element.go +++ b/ff/element.go @@ -12,29 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -// field modulus q = -// -// 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// Code generated by goff DO NOT EDIT -// goff version: - build: -// Element are assumed to be in Montgomery form in all methods +// Code generated by goff (v0.2.0) DO NOT EDIT -// Package ff (generated by goff) contains field arithmetics operations +// Package ff contains field arithmetic operations package ff +// /!\ WARNING /!\ +// this code has not been audited and is provided as-is. In particular, +// there is no security guarantees such as constant time implementation +// or side-channel attack resistance +// /!\ WARNING /!\ + import ( "crypto/rand" "encoding/binary" "io" "math/big" "math/bits" + "strconv" "sync" - "unsafe" ) // Element represents a field element stored on 4 words (uint64) // Element are assumed to be in Montgomery form in all methods +// field modulus q = +// +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 type Element [4]uint64 // ElementLimbs number of 64 bits words needed to represent Element @@ -311,6 +315,7 @@ func (z *Element) SetRandom() *Element { z[3] %= 3486998266802970665 // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -322,6 +327,38 @@ func (z *Element) SetRandom() *Element { return z } +// One returns 1 (in montgommery form) +func One() Element { + var one Element + one.SetOne() + return one +} + +// FromInterface converts i1 from uint64, int, string, or Element, big.Int into Element +// panic if provided type is not supported +func FromInterface(i1 interface{}) Element { + var val Element + + switch c1 := i1.(type) { + case uint64: + val.SetUint64(c1) + case int: + val.SetString(strconv.Itoa(c1)) + case string: + val.SetString(c1) + case big.Int: + val.SetBigInt(&c1) + case Element: + val = c1 + case *Element: + val.Set(c1) + default: + panic("invalid type") + } + + return val +} + // Add z = x + y mod q func (z *Element) Add(x, y *Element) *Element { var carry uint64 @@ -332,6 +369,7 @@ func (z *Element) Add(x, y *Element) *Element { z[3], _ = bits.Add64(x[3], y[3], carry) // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -352,6 +390,7 @@ func (z *Element) AddAssign(x *Element) *Element { z[3], _ = bits.Add64(z[3], x[3], carry) // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -372,6 +411,7 @@ func (z *Element) Double(x *Element) *Element { z[3], _ = bits.Add64(x[3], x[3], carry) // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -416,18 +456,31 @@ func (z *Element) SubAssign(x *Element) *Element { return z } -// Exp z = x^e mod q -func (z *Element) Exp(x Element, e uint64) *Element { - if e == 0 { +// Exp z = x^exponent mod q +// (not optimized) +// exponent (non-montgomery form) is ordered from least significant word to most significant word +func (z *Element) Exp(x Element, exponent ...uint64) *Element { + r := 0 + msb := 0 + for i := len(exponent) - 1; i >= 0; i-- { + if exponent[i] == 0 { + r++ + } else { + msb = (i * 64) + bits.Len64(exponent[i]) + break + } + } + exponent = exponent[:len(exponent)-r] + if len(exponent) == 0 { return z.SetOne() } z.Set(&x) - l := bits.Len64(e) - 2 + l := msb - 2 for i := l; i >= 0; i-- { z.Square(z) - if e&(1< q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -549,6 +603,19 @@ func (z *Element) SetBigInt(v *big.Int) *Element { zero := big.NewInt(0) q := elementModulusBigInt() + // fast path + c := v.Cmp(q) + if c == 0 { + return z + } else if c != 1 && v.Cmp(zero) != -1 { + // v should + vBits := v.Bits() + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + return z.ToMont() + } + // copy input vv := new(big.Int).Set(v) @@ -591,202 +658,97 @@ func (z *Element) SetString(s string) *Element { return z.SetBigInt(x) } -// Mul z = x * y mod q -func (z *Element) Mul(x, y *Element) *Element { +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.Exp(*z, + 11669102379873075200, + 10671829228508198984, + 15863968012492123182, + 1743499133401485332, + ) - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + if l.IsZero() { + return 0 } - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + // if l == 1 + if (l[3] == 1011752739694698287) && (l[2] == 7381016538464732718) && (l[1] == 3962172157175319849) && (l[0] == 12436184717236109307) { + return 1 } - return z + return -1 } -// MulAssign z = z * x mod q -func (z *Element) MulAssign(x *Element) *Element { - - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := z[0] - c[1], c[0] = bits.Mul64(v, x[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd1(v, x[1], c[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd1(v, x[2], c[1]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd1(v, x[3], c[1]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 1 - v := z[1] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 2 - v := z[2] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) +// Sqrt z = √x mod q +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.Exp(*x, + 14829091926808964255, + 867720185306366531, + 688207751544974772, + 6495040407, + ) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = x^s = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 7164790868263648668, + 11685701338293206998, + 6216421865291908056, + 1756667274303109607, + } + r := uint64(28) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of x^s + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() } - { - // round 3 - v := z[3] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + if !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + // t != 1, we don't have a square root + return nil } + for { + var m uint64 + t = b - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } - return z -} - -// Square z = x * x mod q -func (z *Element) Square(x *Element) *Element { - - var p [4]uint64 + // for t != 1 + for !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + t.Square(&t) + m++ + } - var u, v uint64 - { - // round 0 - u, p[0] = bits.Mul64(x[0], x[0]) - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - var t uint64 - t, u, v = madd1sb(x[0], x[1], u) - C, p[0] = madd2(m, 2896914383306846353, v, C) - t, u, v = madd1s(x[0], x[2], t, u) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd1s(x[0], x[3], t, u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) - } - { - // round 1 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - u, v = madd1(x[1], x[1], p[1]) - C, p[0] = madd2(m, 2896914383306846353, v, C) - var t uint64 - t, u, v = madd2sb(x[1], x[2], p[2], u) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd2s(x[1], x[3], p[3], t, u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) - } - { - // round 2 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - C, p[0] = madd2(m, 2896914383306846353, p[1], C) - u, v = madd1(x[2], x[2], p[2]) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd2sb(x[2], x[3], p[3], u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) - } - { - // round 3 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - C, z[0] = madd2(m, 2896914383306846353, p[1], C) - C, z[1] = madd2(m, 13281191951274694749, p[2], C) - u, v = madd1(x[3], x[3], p[3]) - z[3], z[2] = madd3(m, 3486998266802970665, v, C, u) - } + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) mod q + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + g.Square(&t) + y.MulAssign(&t) + b.MulAssign(&g) + r = m } - return z } diff --git a/ff/element_mul.go b/ff/element_mul.go new file mode 100644 index 0000000..fbc6ef7 --- /dev/null +++ b/ff/element_mul.go @@ -0,0 +1,170 @@ +// +build !amd64 + +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff (v0.2.0) DO NOT EDIT + +// Package ff contains field arithmetic operations +package ff + +// /!\ WARNING /!\ +// this code has not been audited and is provided as-is. In particular, +// there is no security guarantees such as constant time implementation +// or side-channel attack resistance +// /!\ WARNING /!\ + +import "math/bits" + +// Mul z = x * y mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Mul(x, y *Element) *Element { + + var t [4]uint64 + var c [3]uint64 + { + // round 0 + v := x[0] + c[1], c[0] = bits.Mul64(v, y[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd1(v, y[1], c[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd1(v, y[2], c[1]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd1(v, y[3], c[1]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 1 + v := x[1] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 2 + v := x[2] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 3 + v := x[3] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// MulAssign z = z * x mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) MulAssign(x *Element) *Element { + + var t [4]uint64 + var c [3]uint64 + { + // round 0 + v := z[0] + c[1], c[0] = bits.Mul64(v, x[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd1(v, x[1], c[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd1(v, x[2], c[1]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd1(v, x[3], c[1]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 1 + v := z[1] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 2 + v := z[2] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 3 + v := z[3] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} diff --git a/ff/element_mul_amd64.go b/ff/element_mul_amd64.go new file mode 100644 index 0000000..73a002f --- /dev/null +++ b/ff/element_mul_amd64.go @@ -0,0 +1,39 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff (v0.2.0) DO NOT EDIT + +// Package ff contains field arithmetic operations +package ff + +// MulAssignElement z = z * x mod q (constant time) +// calling this instead of z.MulAssign(x) is prefered for performance critical path +//go:noescape +func MulAssignElement(res, y *Element) + +// Mul z = x * y mod q (constant time) +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Mul(x, y *Element) *Element { + res := *x + MulAssignElement(&res, y) + z.Set(&res) + return z +} + +// MulAssign z = z * x mod q (constant time) +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) MulAssign(x *Element) *Element { + MulAssignElement(z, x) + return z +} diff --git a/ff/element_mul_amd64.s b/ff/element_mul_amd64.s new file mode 100644 index 0000000..d7e0be9 --- /dev/null +++ b/ff/element_mul_amd64.s @@ -0,0 +1,695 @@ +// Code generated by goff (v0.2.0) DO NOT EDIT + +#include "textflag.h" + +// func MulAssignElement(res,y *Element) +// montgomery multiplication of res by y +// stores the result in res +TEXT ·MulAssignElement(SB), NOSPLIT, $0-16 + + // dereference our parameters + MOVQ res+0(FP), DI + MOVQ y+8(FP), R8 + + // check if we support adx and mulx + CMPB ·supportAdx(SB), $1 + JNE no_adx + + // the algorithm is described here + // https://hackmd.io/@zkteam/modular_multiplication + // however, to benefit from the ADCX and ADOX carry chains + // we split the inner loops in 2: + // for i=0 to N-1 + // for j=0 to N-1 + // (A,t[j]) := t[j] + a[j]*b[i] + A + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + A + + // --------------------------------------------------------------------------------------------- + // outter loop 0 + + // clear up the carry flags + XORQ R9 , R9 + + // R12 = y[0] + MOVQ 0(R8), R12 + + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + + // DX = res[0] + MOVQ 0(DI), DX + MULXQ R12, CX , R9 + + // DX = res[1] + MOVQ 8(DI), DX + MOVQ R9, BX + MULXQ R12, AX, R9 + ADOXQ AX, BX + + // DX = res[2] + MOVQ 16(DI), DX + MOVQ R9, BP + MULXQ R12, AX, R9 + ADOXQ AX, BP + + // DX = res[3] + MOVQ 24(DI), DX + MOVQ R9, SI + MULXQ R12, AX, R9 + ADOXQ AX, SI + + // add the last carries to R9 + MOVQ $0, DX + ADCXQ DX, R9 + ADOXQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, DX + MULXQ CX,R11, DX + + // clear the carry flags + XORQ DX, DX + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, DX + MULXQ R11, AX, R10 + ADCXQ CX ,AX + + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + MOVQ $0x2833e84879b97091, DX + MULXQ R11, AX, DX + ADCXQ BX, R10 + ADOXQ AX, R10 + MOVQ R10, CX + MOVQ DX, R10 + + MOVQ $0xb85045b68181585d, DX + MULXQ R11, AX, DX + ADCXQ BP, R10 + ADOXQ AX, R10 + MOVQ R10, BX + MOVQ DX, R10 + + MOVQ $0x30644e72e131a029, DX + MULXQ R11, AX, DX + ADCXQ SI, R10 + ADOXQ AX, R10 + MOVQ R10, BP + MOVQ $0, AX + ADCXQ AX, DX + ADOXQ DX, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 1 + + // clear up the carry flags + XORQ R9 , R9 + + // R12 = y[1] + MOVQ 8(R8), R12 + + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + + // DX = res[0] + MOVQ 0(DI), DX + MULXQ R12, AX, R9 + ADOXQ AX, CX + + // DX = res[1] + MOVQ 8(DI), DX + ADCXQ R9, BX + MULXQ R12, AX, R9 + ADOXQ AX, BX + + // DX = res[2] + MOVQ 16(DI), DX + ADCXQ R9, BP + MULXQ R12, AX, R9 + ADOXQ AX, BP + + // DX = res[3] + MOVQ 24(DI), DX + ADCXQ R9, SI + MULXQ R12, AX, R9 + ADOXQ AX, SI + + // add the last carries to R9 + MOVQ $0, DX + ADCXQ DX, R9 + ADOXQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, DX + MULXQ CX,R11, DX + + // clear the carry flags + XORQ DX, DX + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, DX + MULXQ R11, AX, R10 + ADCXQ CX ,AX + + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + MOVQ $0x2833e84879b97091, DX + MULXQ R11, AX, DX + ADCXQ BX, R10 + ADOXQ AX, R10 + MOVQ R10, CX + MOVQ DX, R10 + + MOVQ $0xb85045b68181585d, DX + MULXQ R11, AX, DX + ADCXQ BP, R10 + ADOXQ AX, R10 + MOVQ R10, BX + MOVQ DX, R10 + + MOVQ $0x30644e72e131a029, DX + MULXQ R11, AX, DX + ADCXQ SI, R10 + ADOXQ AX, R10 + MOVQ R10, BP + MOVQ $0, AX + ADCXQ AX, DX + ADOXQ DX, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 2 + + // clear up the carry flags + XORQ R9 , R9 + + // R12 = y[2] + MOVQ 16(R8), R12 + + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + + // DX = res[0] + MOVQ 0(DI), DX + MULXQ R12, AX, R9 + ADOXQ AX, CX + + // DX = res[1] + MOVQ 8(DI), DX + ADCXQ R9, BX + MULXQ R12, AX, R9 + ADOXQ AX, BX + + // DX = res[2] + MOVQ 16(DI), DX + ADCXQ R9, BP + MULXQ R12, AX, R9 + ADOXQ AX, BP + + // DX = res[3] + MOVQ 24(DI), DX + ADCXQ R9, SI + MULXQ R12, AX, R9 + ADOXQ AX, SI + + // add the last carries to R9 + MOVQ $0, DX + ADCXQ DX, R9 + ADOXQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, DX + MULXQ CX,R11, DX + + // clear the carry flags + XORQ DX, DX + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, DX + MULXQ R11, AX, R10 + ADCXQ CX ,AX + + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + MOVQ $0x2833e84879b97091, DX + MULXQ R11, AX, DX + ADCXQ BX, R10 + ADOXQ AX, R10 + MOVQ R10, CX + MOVQ DX, R10 + + MOVQ $0xb85045b68181585d, DX + MULXQ R11, AX, DX + ADCXQ BP, R10 + ADOXQ AX, R10 + MOVQ R10, BX + MOVQ DX, R10 + + MOVQ $0x30644e72e131a029, DX + MULXQ R11, AX, DX + ADCXQ SI, R10 + ADOXQ AX, R10 + MOVQ R10, BP + MOVQ $0, AX + ADCXQ AX, DX + ADOXQ DX, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 3 + + // clear up the carry flags + XORQ R9 , R9 + + // R12 = y[3] + MOVQ 24(R8), R12 + + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + + // DX = res[0] + MOVQ 0(DI), DX + MULXQ R12, AX, R9 + ADOXQ AX, CX + + // DX = res[1] + MOVQ 8(DI), DX + ADCXQ R9, BX + MULXQ R12, AX, R9 + ADOXQ AX, BX + + // DX = res[2] + MOVQ 16(DI), DX + ADCXQ R9, BP + MULXQ R12, AX, R9 + ADOXQ AX, BP + + // DX = res[3] + MOVQ 24(DI), DX + ADCXQ R9, SI + MULXQ R12, AX, R9 + ADOXQ AX, SI + + // add the last carries to R9 + MOVQ $0, DX + ADCXQ DX, R9 + ADOXQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, DX + MULXQ CX,R11, DX + + // clear the carry flags + XORQ DX, DX + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, DX + MULXQ R11, AX, R10 + ADCXQ CX ,AX + + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + + MOVQ $0x2833e84879b97091, DX + MULXQ R11, AX, DX + ADCXQ BX, R10 + ADOXQ AX, R10 + MOVQ R10, CX + MOVQ DX, R10 + + MOVQ $0xb85045b68181585d, DX + MULXQ R11, AX, DX + ADCXQ BP, R10 + ADOXQ AX, R10 + MOVQ R10, BX + MOVQ DX, R10 + + MOVQ $0x30644e72e131a029, DX + MULXQ R11, AX, DX + ADCXQ SI, R10 + ADOXQ AX, R10 + MOVQ R10, BP + MOVQ $0, AX + ADCXQ AX, DX + ADOXQ DX, R9 + MOVQ R9, SI + + reduce: + // reduce, constant time version + // first we copy registers storing t in a separate set of registers + // as SUBQ modifies the 2nd operand + MOVQ CX, DX + MOVQ BX, R8 + MOVQ BP, R9 + MOVQ SI, R10 + MOVQ $0x43e1f593f0000001, R11 + SUBQ R11, DX + MOVQ $0x2833e84879b97091, R11 + SBBQ R11, R8 + MOVQ $0xb85045b68181585d, R11 + SBBQ R11, R9 + MOVQ $0x30644e72e131a029, R11 + SBBQ R11, R10 + JCS t_is_smaller // no borrow, we return t + + // borrow is set, we return u + MOVQ DX, (DI) + MOVQ R8, 8(DI) + MOVQ R9, 16(DI) + MOVQ R10, 24(DI) + RET + t_is_smaller: + MOVQ CX, 0(DI) + MOVQ BX, 8(DI) + MOVQ BP, 16(DI) + MOVQ SI, 24(DI) + RET + + no_adx: + + // --------------------------------------------------------------------------------------------- + // outter loop 0 + + // (A,t[0]) := t[0] + x[0]*y[0] + MOVQ (DI), AX // x[0] + MOVQ 0(R8), R12 + MULQ R12 // x[0] * y[0] + MOVQ DX, R9 + MOVQ AX, CX + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, R11 + IMULQ CX , R11 + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, AX + MULQ R11 + ADDQ CX ,AX + ADCQ $0, DX + MOVQ DX, R10 + + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + MOVQ 8(DI), AX + MULQ R12 // x[1] * y[0] + MOVQ R9, BX + ADDQ AX, BX + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x2833e84879b97091, AX + MULQ R11 + ADDQ BX, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, CX + MOVQ DX, R10 + MOVQ 16(DI), AX + MULQ R12 // x[2] * y[0] + MOVQ R9, BP + ADDQ AX, BP + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0xb85045b68181585d, AX + MULQ R11 + ADDQ BP, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BX + MOVQ DX, R10 + MOVQ 24(DI), AX + MULQ R12 // x[3] * y[0] + MOVQ R9, SI + ADDQ AX, SI + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x30644e72e131a029, AX + MULQ R11 + ADDQ SI, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BP + MOVQ DX, R10 + + ADDQ R10, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 1 + + // (A,t[0]) := t[0] + x[0]*y[1] + MOVQ (DI), AX // x[0] + MOVQ 8(R8), R12 + MULQ R12 // x[0] * y[1] + ADDQ AX, CX + ADCQ $0, DX + MOVQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, R11 + IMULQ CX , R11 + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, AX + MULQ R11 + ADDQ CX ,AX + ADCQ $0, DX + MOVQ DX, R10 + + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + MOVQ 8(DI), AX + MULQ R12 // x[1] * y[1] + ADDQ R9, BX + ADCQ $0, DX + ADDQ AX, BX + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x2833e84879b97091, AX + MULQ R11 + ADDQ BX, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, CX + MOVQ DX, R10 + MOVQ 16(DI), AX + MULQ R12 // x[2] * y[1] + ADDQ R9, BP + ADCQ $0, DX + ADDQ AX, BP + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0xb85045b68181585d, AX + MULQ R11 + ADDQ BP, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BX + MOVQ DX, R10 + MOVQ 24(DI), AX + MULQ R12 // x[3] * y[1] + ADDQ R9, SI + ADCQ $0, DX + ADDQ AX, SI + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x30644e72e131a029, AX + MULQ R11 + ADDQ SI, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BP + MOVQ DX, R10 + + ADDQ R10, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 2 + + // (A,t[0]) := t[0] + x[0]*y[2] + MOVQ (DI), AX // x[0] + MOVQ 16(R8), R12 + MULQ R12 // x[0] * y[2] + ADDQ AX, CX + ADCQ $0, DX + MOVQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, R11 + IMULQ CX , R11 + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, AX + MULQ R11 + ADDQ CX ,AX + ADCQ $0, DX + MOVQ DX, R10 + + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + MOVQ 8(DI), AX + MULQ R12 // x[1] * y[2] + ADDQ R9, BX + ADCQ $0, DX + ADDQ AX, BX + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x2833e84879b97091, AX + MULQ R11 + ADDQ BX, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, CX + MOVQ DX, R10 + MOVQ 16(DI), AX + MULQ R12 // x[2] * y[2] + ADDQ R9, BP + ADCQ $0, DX + ADDQ AX, BP + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0xb85045b68181585d, AX + MULQ R11 + ADDQ BP, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BX + MOVQ DX, R10 + MOVQ 24(DI), AX + MULQ R12 // x[3] * y[2] + ADDQ R9, SI + ADCQ $0, DX + ADDQ AX, SI + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x30644e72e131a029, AX + MULQ R11 + ADDQ SI, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BP + MOVQ DX, R10 + + ADDQ R10, R9 + MOVQ R9, SI + + // --------------------------------------------------------------------------------------------- + // outter loop 3 + + // (A,t[0]) := t[0] + x[0]*y[3] + MOVQ (DI), AX // x[0] + MOVQ 24(R8), R12 + MULQ R12 // x[0] * y[3] + ADDQ AX, CX + ADCQ $0, DX + MOVQ DX, R9 + + // m := t[0]*q'[0] mod W + MOVQ $0xc2e1f593efffffff, R11 + IMULQ CX , R11 + + // C,_ := t[0] + m*q[0] + MOVQ $0x43e1f593f0000001, AX + MULQ R11 + ADDQ CX ,AX + ADCQ $0, DX + MOVQ DX, R10 + + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + MOVQ 8(DI), AX + MULQ R12 // x[1] * y[3] + ADDQ R9, BX + ADCQ $0, DX + ADDQ AX, BX + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x2833e84879b97091, AX + MULQ R11 + ADDQ BX, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, CX + MOVQ DX, R10 + MOVQ 16(DI), AX + MULQ R12 // x[2] * y[3] + ADDQ R9, BP + ADCQ $0, DX + ADDQ AX, BP + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0xb85045b68181585d, AX + MULQ R11 + ADDQ BP, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BX + MOVQ DX, R10 + MOVQ 24(DI), AX + MULQ R12 // x[3] * y[3] + ADDQ R9, SI + ADCQ $0, DX + ADDQ AX, SI + ADCQ $0, DX + MOVQ DX, R9 + + MOVQ $0x30644e72e131a029, AX + MULQ R11 + ADDQ SI, R10 + ADCQ $0, DX + ADDQ AX, R10 + ADCQ $0, DX + + MOVQ R10, BP + MOVQ DX, R10 + + ADDQ R10, R9 + MOVQ R9, SI + + JMP reduce diff --git a/ff/element_square.go b/ff/element_square.go new file mode 100644 index 0000000..9e568ae --- /dev/null +++ b/ff/element_square.go @@ -0,0 +1,93 @@ +// +build !amd64 + +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff (v0.2.0) DO NOT EDIT + +// Package ff contains field arithmetic operations +package ff + +// /!\ WARNING /!\ +// this code has not been audited and is provided as-is. In particular, +// there is no security guarantees such as constant time implementation +// or side-channel attack resistance +// /!\ WARNING /!\ + +import "math/bits" + +// Square z = x * x mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Square(x *Element) *Element { + + var p [4]uint64 + + var u, v uint64 + { + // round 0 + u, p[0] = bits.Mul64(x[0], x[0]) + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + var t uint64 + t, u, v = madd1sb(x[0], x[1], u) + C, p[0] = madd2(m, 2896914383306846353, v, C) + t, u, v = madd1s(x[0], x[2], t, u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd1s(x[0], x[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 1 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + u, v = madd1(x[1], x[1], p[1]) + C, p[0] = madd2(m, 2896914383306846353, v, C) + var t uint64 + t, u, v = madd2sb(x[1], x[2], p[2], u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2s(x[1], x[3], p[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 2 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, p[0] = madd2(m, 2896914383306846353, p[1], C) + u, v = madd1(x[2], x[2], p[2]) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2sb(x[2], x[3], p[3], u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 3 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, z[0] = madd2(m, 2896914383306846353, p[1], C) + C, z[1] = madd2(m, 13281191951274694749, p[2], C) + u, v = madd1(x[3], x[3], p[3]) + z[3], z[2] = madd3(m, 3486998266802970665, v, C, u) + } + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z + +} diff --git a/ff/element_square_amd64.go b/ff/element_square_amd64.go new file mode 100644 index 0000000..af55d12 --- /dev/null +++ b/ff/element_square_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff (v0.2.0) DO NOT EDIT + +// Package ff contains field arithmetic operations +package ff + +// SquareElement z = x * x mod q +// calling this instead of z.Square(x) is prefered for performance critical path +// go - noescape +// func SquareElement(res,x *Element) + +// Square z = x * x mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Square(x *Element) *Element { + if z != x { + z.Set(x) + } + MulAssignElement(z, x) + // SquareElement(z, x) + return z +} diff --git a/ff/element_test.go b/ff/element_test.go index 090313f..be6f674 100644 --- a/ff/element_test.go +++ b/ff/element_test.go @@ -1,9 +1,26 @@ -// Code generated by goff DO NOT EDIT +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff (v0.2.0) DO NOT EDIT + +// Package ff contains field arithmetic operations package ff import ( "crypto/rand" "math/big" + "math/bits" mrand "math/rand" "testing" ) @@ -21,7 +38,14 @@ func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) { modulusMinusOne.Sub(modulus, &one) - for i := 0; i < 1000; i++ { + var n int + if testing.Short() { + n = 10 + } else { + n = 500 + } + + for i := 0; i < n; i++ { // sample 2 random big int b1, _ := rand.Int(rand.Reader, modulus) @@ -57,7 +81,7 @@ func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) { rbExp := new(big.Int).SetUint64(rExp) - var bMul, bAdd, bSub, bDiv, bNeg, bLsh, bInv, bExp, bSquare big.Int + var bMul, bAdd, bSub, bDiv, bNeg, bLsh, bInv, bExp, bExp2, bSquare big.Int // e1 = mont(b1), e2 = mont(b2) var e1, e2, eMul, eAdd, eSub, eDiv, eNeg, eLsh, eInv, eExp, eSquare, eMulAssign, eSubAssign, eAddAssign Element @@ -106,12 +130,40 @@ func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) { cmpEandB(&eNeg, &bNeg, "Neg") cmpEandB(&eInv, &bInv, "Inv") cmpEandB(&eExp, &bExp, "Exp") + cmpEandB(&eLsh, &bLsh, "Lsh") + + // legendre symbol + if e1.Legendre() != big.Jacobi(b1, modulus) { + t.Fatal("legendre symbol computation failed") + } + if e2.Legendre() != big.Jacobi(b2, modulus) { + t.Fatal("legendre symbol computation failed") + } + + // these are slow, killing circle ci + if n <= 5 { + // sqrt + var eSqrt, eExp2 Element + var bSqrt big.Int + bSqrt.ModSqrt(b1, modulus) + eSqrt.Sqrt(&e1) + cmpEandB(&eSqrt, &bSqrt, "Sqrt") + + bits := b2.Bits() + exponent := make([]uint64, len(bits)) + for k := 0; k < len(bits); k++ { + exponent[k] = uint64(bits[k]) + } + eExp2.Exp(e1, exponent...) + bExp2.Exp(b1, b2, modulus) + cmpEandB(&eExp2, &bExp2, "Exp multi words") + } } } func TestELEMENTIsRandom(t *testing.T) { - for i := 0; i < 1000; i++ { + for i := 0; i < 50; i++ { var x, y Element x.SetRandom() y.SetRandom() @@ -125,7 +177,6 @@ func TestELEMENTIsRandom(t *testing.T) { // benchmarks // most benchmarks are rudimentary and should sample a large number of random inputs // or be run multiple times to ensure it didn't measure the fastest path of the function -// TODO: clean up and push benchmarking branch var benchResElement Element @@ -219,6 +270,15 @@ func BenchmarkSquareELEMENT(b *testing.B) { } } +func BenchmarkSqrtELEMENT(b *testing.B) { + var a Element + a.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + func BenchmarkMulAssignELEMENT(b *testing.B) { x := Element{ 1997599621687373223, @@ -232,3 +292,183 @@ func BenchmarkMulAssignELEMENT(b *testing.B) { benchResElement.MulAssign(&x) } } + +func BenchmarkMulAssignASMELEMENT(b *testing.B) { + x := Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulAssignElement(&benchResElement, &x) + } +} + +func TestELEMENTAsm(t *testing.T) { + // ensure ASM implementations matches the ones using math/bits + modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + for i := 0; i < 500; i++ { + // sample 2 random big int + b1, _ := rand.Int(rand.Reader, modulus) + b2, _ := rand.Int(rand.Reader, modulus) + + // e1 = mont(b1), e2 = mont(b2) + var e1, e2, eTestMul, eMulAssign, eSquare, eTestSquare Element + e1.SetBigInt(b1) + e2.SetBigInt(b2) + + eTestMul = e1 + eTestMul.testMulAssign(&e2) + eMulAssign = e1 + eMulAssign.MulAssign(&e2) + + if !eTestMul.Equal(&eMulAssign) { + t.Fatal("inconsisntencies between MulAssign and testMulAssign --> check if MulAssign is calling ASM implementaiton on amd64") + } + + // square + eSquare.Square(&e1) + eTestSquare.testSquare(&e1) + + if !eTestSquare.Equal(&eSquare) { + t.Fatal("inconsisntencies between Square and testSquare --> check if Square is calling ASM implementaiton on amd64") + } + } +} + +// this is here for consistency purposes, to ensure MulAssign on AMD64 using asm implementation gives consistent results +func (z *Element) testMulAssign(x *Element) *Element { + + var t [4]uint64 + var c [3]uint64 + { + // round 0 + v := z[0] + c[1], c[0] = bits.Mul64(v, x[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd1(v, x[1], c[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd1(v, x[2], c[1]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd1(v, x[3], c[1]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 1 + v := z[1] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 2 + v := z[2] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 3 + v := z[3] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// this is here for consistency purposes, to ensure Square on AMD64 using asm implementation gives consistent results +func (z *Element) testSquare(x *Element) *Element { + + var p [4]uint64 + + var u, v uint64 + { + // round 0 + u, p[0] = bits.Mul64(x[0], x[0]) + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + var t uint64 + t, u, v = madd1sb(x[0], x[1], u) + C, p[0] = madd2(m, 2896914383306846353, v, C) + t, u, v = madd1s(x[0], x[2], t, u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd1s(x[0], x[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 1 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + u, v = madd1(x[1], x[1], p[1]) + C, p[0] = madd2(m, 2896914383306846353, v, C) + var t uint64 + t, u, v = madd2sb(x[1], x[2], p[2], u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2s(x[1], x[3], p[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 2 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, p[0] = madd2(m, 2896914383306846353, p[1], C) + u, v = madd1(x[2], x[2], p[2]) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2sb(x[2], x[3], p[3], u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 3 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, z[0] = madd2(m, 2896914383306846353, p[1], C) + C, z[1] = madd2(m, 13281191951274694749, p[2], C) + u, v = madd1(x[3], x[3], p[3]) + z[3], z[2] = madd3(m, 3486998266802970665, v, C, u) + } + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z + +} diff --git a/go.mod b/go.mod index 2cbfeef..842f51c 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,5 @@ require ( github.com/ethereum/go-ethereum v1.8.27 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 + golang.org/x/sys v0.0.0-20190412213103-97732733099d )