From 17bad7585348f5a3e3643276eb8e99404c5cabbb Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 3 Mar 2020 16:30:00 +0100 Subject: [PATCH 1/4] Add goff generated finite field arithmetic code for used field --- ff/arith.go | 122 ++++++++ ff/element.go | 764 +++++++++++++++++++++++++++++++++++++++++++++ ff/element_test.go | 234 ++++++++++++++ 3 files changed, 1120 insertions(+) create mode 100644 ff/arith.go create mode 100644 ff/element.go create mode 100644 ff/element_test.go diff --git a/ff/arith.go b/ff/arith.go new file mode 100644 index 0000000..938c87a --- /dev/null +++ b/ff/arith.go @@ -0,0 +1,122 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by goff DO NOT EDIT + +package ff + +import ( + "math/bits" +) + +func madd(a, b, t, u, v uint64) (uint64, uint64, uint64) { + var carry uint64 + hi, lo := bits.Mul64(a, b) + v, carry = bits.Add64(lo, v, 0) + u, carry = bits.Add64(hi, u, carry) + t, _ = bits.Add64(t, 0, carry) + return t, u, v +} + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2s superhi, hi, lo = 2*a*b + c + d + e +func madd2s(a, b, c, d, e uint64) (superhi, hi, lo uint64) { + var carry, sum uint64 + + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, lo, 0) + hi, superhi = bits.Add64(hi, hi, carry) + + sum, carry = bits.Add64(c, e, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, sum, 0) + hi, _ = bits.Add64(hi, 0, carry) + hi, _ = bits.Add64(hi, 0, d) + return +} + +func madd1s(a, b, d, e uint64) (superhi, hi, lo uint64) { + var carry uint64 + + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, lo, 0) + hi, superhi = bits.Add64(hi, hi, carry) + lo, carry = bits.Add64(lo, e, 0) + hi, _ = bits.Add64(hi, 0, carry) + hi, _ = bits.Add64(hi, 0, d) + return +} + +func madd2sb(a, b, c, e uint64) (superhi, hi, lo uint64) { + var carry, sum uint64 + + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, lo, 0) + hi, superhi = bits.Add64(hi, hi, carry) + + sum, carry = bits.Add64(c, e, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, sum, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd1sb(a, b, e uint64) (superhi, hi, lo uint64) { + var carry uint64 + + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, lo, 0) + hi, superhi = bits.Add64(hi, hi, carry) + lo, carry = bits.Add64(lo, e, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/ff/element.go b/ff/element.go new file mode 100644 index 0000000..95334b0 --- /dev/null +++ b/ff/element.go @@ -0,0 +1,764 @@ +// Copyright 2020 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// field modulus q = +// +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// Code generated by goff DO NOT EDIT +// goff version: - build: +// Element are assumed to be in Montgomery form in all methods + +// Package ff (generated by goff) contains field arithmetics operations +package ff + +import ( + "crypto/rand" + "encoding/binary" + "io" + "math/big" + "math/bits" + "sync" + + "unsafe" +) + +// Element represents a field element stored on 4 words (uint64) +// Element are assumed to be in Montgomery form in all methods +type Element [4]uint64 + +// ElementLimbs number of 64 bits words needed to represent Element +const ElementLimbs = 4 + +// ElementBits number bits needed to represent Element +const ElementBits = 254 + +// SetUint64 z = v, sets z LSB to v (non-Montgomery form) and convert z to Montgomery form +func (z *Element) SetUint64(v uint64) *Element { + z[0] = v + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z.ToMont() +} + +// Set z = x +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + z[1] = x[1] + z[2] = x[2] + z[3] = x[3] + return z +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 12436184717236109307 + z[1] = 3962172157175319849 + z[2] = 7381016538464732718 + z[3] = 1011752739694698287 + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + return z.SetZero() + } + var borrow uint64 + z[0], borrow = bits.Sub64(4891460686036598785, x[0], 0) + z[1], borrow = bits.Sub64(2896914383306846353, x[1], borrow) + z[2], borrow = bits.Sub64(13281191951274694749, x[2], borrow) + z[3], _ = bits.Sub64(3486998266802970665, x[3], borrow) + return z +} + +// Div z = x*y^-1 mod q +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Equal returns z == x +func (z *Element) Equal(x *Element) bool { + return (z[3] == x[3]) && (z[2] == x[2]) && (z[1] == x[1]) && (z[0] == x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[3] | z[2] | z[1] | z[0]) == 0 +} + +// field modulus stored as big.Int +var _elementModulusBigInt big.Int +var onceelementModulus sync.Once + +func elementModulusBigInt() *big.Int { + onceelementModulus.Do(func() { + _elementModulusBigInt.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + }) + return &_elementModulusBigInt +} + +// Inverse z = x^-1 mod q +// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" +// if x == 0, sets and returns z = x +func (z *Element) Inverse(x *Element) *Element { + if x.IsZero() { + return z.Set(x) + } + + // initialize u = q + var u = Element{ + 4891460686036598785, + 2896914383306846353, + 13281191951274694749, + 3486998266802970665, + } + + // initialize s = r^2 + var s = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + + // r = 0 + r := Element{} + + v := *x + + var carry, borrow, t, t2 uint64 + var bigger, uIsOne, vIsOne bool + + for !uIsOne && !vIsOne { + for v[0]&1 == 0 { + + // v = v >> 1 + t2 = v[3] << 63 + v[3] >>= 1 + t = t2 + t2 = v[2] << 63 + v[2] = (v[2] >> 1) | t + t = t2 + t2 = v[1] << 63 + v[1] = (v[1] >> 1) | t + t = t2 + v[0] = (v[0] >> 1) | t + + if s[0]&1 == 1 { + + // s = s + q + s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) + s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) + s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) + s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + + } + + // s = s >> 1 + t2 = s[3] << 63 + s[3] >>= 1 + t = t2 + t2 = s[2] << 63 + s[2] = (s[2] >> 1) | t + t = t2 + t2 = s[1] << 63 + s[1] = (s[1] >> 1) | t + t = t2 + s[0] = (s[0] >> 1) | t + + } + for u[0]&1 == 0 { + + // u = u >> 1 + t2 = u[3] << 63 + u[3] >>= 1 + t = t2 + t2 = u[2] << 63 + u[2] = (u[2] >> 1) | t + t = t2 + t2 = u[1] << 63 + u[1] = (u[1] >> 1) | t + t = t2 + u[0] = (u[0] >> 1) | t + + if r[0]&1 == 1 { + + // r = r + q + r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) + r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) + r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) + r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + + } + + // r = r >> 1 + t2 = r[3] << 63 + r[3] >>= 1 + t = t2 + t2 = r[2] << 63 + r[2] = (r[2] >> 1) | t + t = t2 + t2 = r[1] << 63 + r[1] = (r[1] >> 1) | t + t = t2 + r[0] = (r[0] >> 1) | t + + } + + // v >= u + bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + + if bigger { + + // v = v - u + v[0], borrow = bits.Sub64(v[0], u[0], 0) + v[1], borrow = bits.Sub64(v[1], u[1], borrow) + v[2], borrow = bits.Sub64(v[2], u[2], borrow) + v[3], _ = bits.Sub64(v[3], u[3], borrow) + + // r >= s + bigger = !(r[3] < s[3] || (r[3] == s[3] && (r[2] < s[2] || (r[2] == s[2] && (r[1] < s[1] || (r[1] == s[1] && (r[0] < s[0]))))))) + + if bigger { + + // s = s + q + s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) + s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) + s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) + s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + + } + + // s = s - r + s[0], borrow = bits.Sub64(s[0], r[0], 0) + s[1], borrow = bits.Sub64(s[1], r[1], borrow) + s[2], borrow = bits.Sub64(s[2], r[2], borrow) + s[3], _ = bits.Sub64(s[3], r[3], borrow) + + } else { + + // u = u - v + u[0], borrow = bits.Sub64(u[0], v[0], 0) + u[1], borrow = bits.Sub64(u[1], v[1], borrow) + u[2], borrow = bits.Sub64(u[2], v[2], borrow) + u[3], _ = bits.Sub64(u[3], v[3], borrow) + + // s >= r + bigger = !(s[3] < r[3] || (s[3] == r[3] && (s[2] < r[2] || (s[2] == r[2] && (s[1] < r[1] || (s[1] == r[1] && (s[0] < r[0]))))))) + + if bigger { + + // r = r + q + r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) + r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) + r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) + r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + + } + + // r = r - s + r[0], borrow = bits.Sub64(r[0], s[0], 0) + r[1], borrow = bits.Sub64(r[1], s[1], borrow) + r[2], borrow = bits.Sub64(r[2], s[2], borrow) + r[3], _ = bits.Sub64(r[3], s[3], borrow) + + } + uIsOne = (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 + vIsOne = (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 + } + + if uIsOne { + z.Set(&r) + } else { + z.Set(&s) + } + + return z +} + +// SetRandom sets z to a random element < q +func (z *Element) SetRandom() *Element { + bytes := make([]byte, 32) + io.ReadFull(rand.Reader, bytes) + z[0] = binary.BigEndian.Uint64(bytes[0:8]) + z[1] = binary.BigEndian.Uint64(bytes[8:16]) + z[2] = binary.BigEndian.Uint64(bytes[16:24]) + z[3] = binary.BigEndian.Uint64(bytes[24:32]) + z[3] %= 3486998266802970665 + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + + return z +} + +// Add z = x + y mod q +func (z *Element) Add(x, y *Element) *Element { + var carry uint64 + + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// AddAssign z = z + x mod q +func (z *Element) AddAssign(x *Element) *Element { + var carry uint64 + + z[0], carry = bits.Add64(z[0], x[0], 0) + z[1], carry = bits.Add64(z[1], x[1], carry) + z[2], carry = bits.Add64(z[2], x[2], carry) + z[3], _ = bits.Add64(z[3], x[3], carry) + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// Double z = x + x mod q, aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + var carry uint64 + + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// Sub z = x - y mod q +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], 4891460686036598785, 0) + z[1], c = bits.Add64(z[1], 2896914383306846353, c) + z[2], c = bits.Add64(z[2], 13281191951274694749, c) + z[3], _ = bits.Add64(z[3], 3486998266802970665, c) + } + return z +} + +// SubAssign z = z - x mod q +func (z *Element) SubAssign(x *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(z[0], x[0], 0) + z[1], b = bits.Sub64(z[1], x[1], b) + z[2], b = bits.Sub64(z[2], x[2], b) + z[3], b = bits.Sub64(z[3], x[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], 4891460686036598785, 0) + z[1], c = bits.Add64(z[1], 2896914383306846353, c) + z[2], c = bits.Add64(z[2], 13281191951274694749, c) + z[3], _ = bits.Add64(z[3], 3486998266802970665, c) + } + return z +} + +// Exp z = x^e mod q +func (z *Element) Exp(x Element, e uint64) *Element { + if e == 0 { + return z.SetOne() + } + + z.Set(&x) + + l := bits.Len64(e) - 2 + for i := l; i >= 0; i-- { + z.Square(z) + if e&(1< q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// ToMont converts z to Montgomery form +// sets and returns z = z * r^2 +func (z *Element) ToMont() *Element { + var rSquare = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + return z.MulAssign(&rSquare) +} + +// ToRegular returns z in regular form (doesn't mutate z) +func (z Element) ToRegular() Element { + return *z.FromMont() +} + +// String returns the string form of an Element in Montgomery form +func (z *Element) String() string { + var _z big.Int + return z.ToBigIntRegular(&_z).String() +} + +// ToBigInt returns z as a big.Int in Montgomery form +func (z *Element) ToBigInt(res *big.Int) *big.Int { + bits := (*[4]big.Word)(unsafe.Pointer(z)) + return res.SetBits(bits[:]) +} + +// ToBigIntRegular returns z as a big.Int in regular form +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.FromMont() + bits := (*[4]big.Word)(unsafe.Pointer(&z)) + return res.SetBits(bits[:]) +} + +// SetBigInt sets z to v (regular form) and returns z in Montgomery form +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + zero := big.NewInt(0) + q := elementModulusBigInt() + + // copy input + vv := new(big.Int).Set(v) + + // while v < 0, v+=q + for vv.Cmp(zero) == -1 { + vv.Add(vv, q) + } + // while v > q, v-=q + for vv.Cmp(q) == 1 { + vv.Sub(vv, q) + } + // if v == q, return 0 + if vv.Cmp(q) == 0 { + return z + } + // v should + vBits := vv.Bits() + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + return z.ToMont() +} + +// SetString creates a big.Int with s (in base 10) and calls SetBigInt on z +func (z *Element) SetString(s string) *Element { + x, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("Element.SetString failed -> can't parse number in base10 into a big.Int") + } + return z.SetBigInt(x) +} + +// Mul z = x * y mod q +func (z *Element) Mul(x, y *Element) *Element { + + var t [4]uint64 + var c [3]uint64 + { + // round 0 + v := x[0] + c[1], c[0] = bits.Mul64(v, y[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd1(v, y[1], c[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd1(v, y[2], c[1]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd1(v, y[3], c[1]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 1 + v := x[1] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 2 + v := x[2] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 3 + v := x[3] + c[1], c[0] = madd1(v, y[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, y[1], c[1], t[1]) + c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, y[2], c[1], t[2]) + c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, y[3], c[1], t[3]) + z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// MulAssign z = z * x mod q +func (z *Element) MulAssign(x *Element) *Element { + + var t [4]uint64 + var c [3]uint64 + { + // round 0 + v := z[0] + c[1], c[0] = bits.Mul64(v, x[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd1(v, x[1], c[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd1(v, x[2], c[1]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd1(v, x[3], c[1]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 1 + v := z[1] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 2 + v := z[2] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + { + // round 3 + v := z[3] + c[1], c[0] = madd1(v, x[0], t[0]) + m := c[0] * 14042775128853446655 + c[2] = madd0(m, 4891460686036598785, c[0]) + c[1], c[0] = madd2(v, x[1], c[1], t[1]) + c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) + c[1], c[0] = madd2(v, x[2], c[1], t[2]) + c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) + c[1], c[0] = madd2(v, x[3], c[1], t[3]) + z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + } + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} + +// Square z = x * x mod q +func (z *Element) Square(x *Element) *Element { + + var p [4]uint64 + + var u, v uint64 + { + // round 0 + u, p[0] = bits.Mul64(x[0], x[0]) + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + var t uint64 + t, u, v = madd1sb(x[0], x[1], u) + C, p[0] = madd2(m, 2896914383306846353, v, C) + t, u, v = madd1s(x[0], x[2], t, u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd1s(x[0], x[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 1 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + u, v = madd1(x[1], x[1], p[1]) + C, p[0] = madd2(m, 2896914383306846353, v, C) + var t uint64 + t, u, v = madd2sb(x[1], x[2], p[2], u) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2s(x[1], x[3], p[3], t, u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 2 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, p[0] = madd2(m, 2896914383306846353, p[1], C) + u, v = madd1(x[2], x[2], p[2]) + C, p[1] = madd2(m, 13281191951274694749, v, C) + _, u, v = madd2sb(x[2], x[3], p[3], u) + p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + } + { + // round 3 + m := p[0] * 14042775128853446655 + C := madd0(m, 4891460686036598785, p[0]) + C, z[0] = madd2(m, 2896914383306846353, p[1], C) + C, z[1] = madd2(m, 13281191951274694749, p[2], C) + u, v = madd1(x[3], x[3], p[3]) + z[3], z[2] = madd3(m, 3486998266802970665, v, C, u) + } + + // if z > q --> z -= q + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + return z +} diff --git a/ff/element_test.go b/ff/element_test.go new file mode 100644 index 0000000..090313f --- /dev/null +++ b/ff/element_test.go @@ -0,0 +1,234 @@ +// Code generated by goff DO NOT EDIT +package ff + +import ( + "crypto/rand" + "math/big" + mrand "math/rand" + "testing" +) + +func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) { + modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + cmpEandB := func(e *Element, b *big.Int, name string) { + var _e big.Int + if e.FromMont().ToBigInt(&_e).Cmp(b) != 0 { + t.Fatal(name, "failed") + } + } + var modulusMinusOne, one big.Int + one.SetUint64(1) + + modulusMinusOne.Sub(modulus, &one) + + for i := 0; i < 1000; i++ { + + // sample 2 random big int + b1, _ := rand.Int(rand.Reader, modulus) + b2, _ := rand.Int(rand.Reader, modulus) + rExp := mrand.Uint64() + + // adding edge cases + // TODO need more edge cases + switch i { + case 0: + rExp = 0 + b1.SetUint64(0) + case 1: + b2.SetUint64(0) + case 2: + b1.SetUint64(0) + b2.SetUint64(0) + case 3: + rExp = 0 + case 4: + rExp = 1 + case 5: + rExp = ^uint64(0) // max uint + case 6: + rExp = 2 + b1.Set(&modulusMinusOne) + case 7: + b2.Set(&modulusMinusOne) + case 8: + b1.Set(&modulusMinusOne) + b2.Set(&modulusMinusOne) + } + + rbExp := new(big.Int).SetUint64(rExp) + + var bMul, bAdd, bSub, bDiv, bNeg, bLsh, bInv, bExp, bSquare big.Int + + // e1 = mont(b1), e2 = mont(b2) + var e1, e2, eMul, eAdd, eSub, eDiv, eNeg, eLsh, eInv, eExp, eSquare, eMulAssign, eSubAssign, eAddAssign Element + e1.SetBigInt(b1) + e2.SetBigInt(b2) + + // (e1*e2).FromMont() === b1*b2 mod q ... etc + eSquare.Square(&e1) + eMul.Mul(&e1, &e2) + eMulAssign.Set(&e1) + eMulAssign.MulAssign(&e2) + eAdd.Add(&e1, &e2) + eAddAssign.Set(&e1) + eAddAssign.AddAssign(&e2) + eSub.Sub(&e1, &e2) + eSubAssign.Set(&e1) + eSubAssign.SubAssign(&e2) + eDiv.Div(&e1, &e2) + eNeg.Neg(&e1) + eInv.Inverse(&e1) + eExp.Exp(e1, rExp) + eLsh.Double(&e1) + + // same operations with big int + bAdd.Add(b1, b2).Mod(&bAdd, modulus) + bMul.Mul(b1, b2).Mod(&bMul, modulus) + bSquare.Mul(b1, b1).Mod(&bSquare, modulus) + bSub.Sub(b1, b2).Mod(&bSub, modulus) + bDiv.ModInverse(b2, modulus) + bDiv.Mul(&bDiv, b1). + Mod(&bDiv, modulus) + bNeg.Neg(b1).Mod(&bNeg, modulus) + + bInv.ModInverse(b1, modulus) + bExp.Exp(b1, rbExp, modulus) + bLsh.Lsh(b1, 1).Mod(&bLsh, modulus) + + cmpEandB(&eSquare, &bSquare, "Square") + cmpEandB(&eMul, &bMul, "Mul") + cmpEandB(&eMulAssign, &bMul, "MulAssign") + cmpEandB(&eAdd, &bAdd, "Add") + cmpEandB(&eAddAssign, &bAdd, "AddAssign") + cmpEandB(&eSub, &bSub, "Sub") + cmpEandB(&eSubAssign, &bSub, "SubAssign") + cmpEandB(&eDiv, &bDiv, "Div") + cmpEandB(&eNeg, &bNeg, "Neg") + cmpEandB(&eInv, &bInv, "Inv") + cmpEandB(&eExp, &bExp, "Exp") + cmpEandB(&eLsh, &bLsh, "Lsh") + } +} + +func TestELEMENTIsRandom(t *testing.T) { + for i := 0; i < 1000; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function +// TODO: clean up and push benchmarking branch + +var benchResElement Element + +func BenchmarkInverseELEMENT(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} +func BenchmarkExpELEMENT(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, mrand.Uint64()) + } +} + +func BenchmarkDoubleELEMENT(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkAddELEMENT(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkSubELEMENT(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkNegELEMENT(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkDivELEMENT(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkFromMontELEMENT(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.FromMont() + } +} + +func BenchmarkToMontELEMENT(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.ToMont() + } +} +func BenchmarkSquareELEMENT(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkMulAssignELEMENT(b *testing.B) { + x := Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.MulAssign(&x) + } +} From 83f87bfa46c4e64ea66d732c2bda3b4dd8a7e88e Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 3 Mar 2020 16:31:09 +0100 Subject: [PATCH 2/4] Resolve #4 --- constants/constants.go | 14 +++++++++++--- mimc7/mimc7.go | 4 ++-- utils/utils.go | 31 +++++++++++++++++++++++++------ 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/constants/constants.go b/constants/constants.go index 986ceb8..16933e3 100644 --- a/constants/constants.go +++ b/constants/constants.go @@ -1,12 +1,15 @@ package constants import ( - "github.com/iden3/go-iden3-crypto/utils" + "fmt" "math/big" + + "github.com/iden3/go-iden3-crypto/ff" ) // Q is the order of the integer field (Zq) that fits inside the SNARK. var Q *big.Int +var QE *ff.Element // Zero is 0. var Zero *big.Int @@ -21,6 +24,11 @@ func init() { Zero = big.NewInt(0) One = big.NewInt(1) MinusOne = big.NewInt(-1) - Q = utils.NewIntFromString( - "21888242871839275222246405745257275088548364400416034343698204186575808495617") + + qString := "21888242871839275222246405745257275088548364400416034343698204186575808495617" + var ok bool + Q, ok = new(big.Int).SetString(qString, 10) + if !ok { + panic(fmt.Sprintf("Bad base 10 string %s", qString)) + } } diff --git a/mimc7/mimc7.go b/mimc7/mimc7.go index f2b618f..f323b33 100644 --- a/mimc7/mimc7.go +++ b/mimc7/mimc7.go @@ -75,7 +75,7 @@ func MIMC7HashGeneric(fqR field.Fq, xIn, k *big.Int, nRounds int) *big.Int { // HashGeneric performs the MIMC7 hash over a *big.Int array, in a generic way, where it can be specified the Finite Field over R, and the number of rounds func HashGeneric(iv *big.Int, arr []*big.Int, fqR field.Fq, nRounds int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(arr, constants.fqR.Q) { + if !utils.CheckBigIntArrayInField(arr) { return nil, errors.New("inputs values not inside Finite Field") } r := iv @@ -108,7 +108,7 @@ func MIMC7Hash(xIn, k *big.Int) *big.Int { // Hash performs the MIMC7 hash over a *big.Int array func Hash(arr []*big.Int, key *big.Int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(arr, constants.fqR.Q) { + if !utils.CheckBigIntArrayInField(arr) { return nil, errors.New("inputs values not inside Finite Field") } var r *big.Int diff --git a/utils/utils.go b/utils/utils.go index 0f2d639..188d0ff 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,6 +6,9 @@ import ( "fmt" "math/big" "strings" + + "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" ) // NewIntFromString creates a new big.Int from a decimal integer encoded as a @@ -87,20 +90,36 @@ func HexDecodeInto(dst []byte, h []byte) error { return nil } -// CheckBigIntInField checks if given big.Int fits in a Field Q element -func CheckBigIntInField(a *big.Int, q *big.Int) bool { - if a.Cmp(q) != -1 { +// CheckBigIntInField checks if given *big.Int fits in a Field Q element +func CheckBigIntInField(a *big.Int) bool { + if a.Cmp(constants.Q) != -1 { return false } return true } -// CheckBigIntArrayInField checks if given big.Int fits in a Field Q element -func CheckBigIntArrayInField(arr []*big.Int, q *big.Int) bool { +// CheckBigIntArrayInField checks if given *big.Int fits in a Field Q element +func CheckBigIntArrayInField(arr []*big.Int) bool { for _, a := range arr { - if !CheckBigIntInField(a, q) { + if !CheckBigIntInField(a) { return false } } return true } + +// CheckElementArrayInField checks if given *ff.Element fits in a Field Q element +func CheckElementArrayInField(arr []*ff.Element) bool { + for _, aE := range arr { + a := big.NewInt(0) + aE.ToBigIntRegular(a) + if !CheckBigIntInField(a) { + return false + } + } + return true +} + +func NewElement() *ff.Element { + return &ff.Element{0, 0, 0, 0} +} From b45d8a582b2ec4bb798d7fd61698b3a1722c08ea Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 3 Mar 2020 16:31:40 +0100 Subject: [PATCH 3/4] Optimize Poseidon migrating from *big.Int to goff Optimize Poseidon migrating from *big.Int to goff generated finite field operations. Benchmarks: Tested on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM. - Before the optimizations: ``` BenchmarkPoseidon-4 470 2489678 ns/op BenchmarkPoseidonLarge-4 476 2530568 ns/op ``` - With the optimizations of #12: ``` BenchmarkPoseidon-4 766 1550013 ns/op BenchmarkPoseidonLarge-4 782 1547572 ns/op ``` - With the changes of this PR, where uses goff generated code instead of *big.Int: ``` BenchmarkPoseidon-4 9638 121651 ns/op BenchmarkPoseidonLarge-4 9781 119921 ns/op ``` --- poseidon/poseidon.go | 104 ++++++++++++++++---------------------- poseidon/poseidon_test.go | 59 ++++++++++++--------- 2 files changed, 80 insertions(+), 83 deletions(-) diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index d4e2b8b..79cd651 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -1,12 +1,11 @@ package poseidon import ( - "bytes" "errors" "math/big" "strconv" - "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "golang.org/x/crypto/blake2b" ) @@ -16,15 +15,11 @@ const NROUNDSF = 8 const NROUNDSP = 57 const T = 6 -var constC []*big.Int -var constM [T][T]*big.Int +var constC []*ff.Element +var constM [T][T]*ff.Element -func Zero() *big.Int { - return new(big.Int) -} - -func modQ(v *big.Int) { - v.Mod(v, constants.Q) +func Zero() *ff.Element { + return utils.NewElement().SetZero() } func init() { @@ -32,22 +27,12 @@ func init() { constM = getMDS() } -func leByteArrayToBigInt(b []byte) *big.Int { - res := big.NewInt(0) - for i := 0; i < len(b); i++ { - n := big.NewInt(int64(b[i])) - res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8))) - } - return res -} - -func getPseudoRandom(seed string, n int) []*big.Int { - res := make([]*big.Int, n) +func getPseudoRandom(seed string, n int) []*ff.Element { + res := make([]*ff.Element, n) hash := blake2b.Sum256([]byte(seed)) for i := 0; i < n; i++ { - hashBigInt := Zero() - res[i] = utils.SetBigIntFromLEBytes(hashBigInt, hash[:]) - modQ(res[i]) + hashBigInt := big.NewInt(int64(0)) + res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:])) hash = blake2b.Sum256(hash[:]) } return res @@ -62,31 +47,30 @@ func nonceToString(n int) string { } // https://eprint.iacr.org/2019/458.pdf pag.8 -func getMDS() [T][T]*big.Int { +func getMDS() [T][T]*ff.Element { nonce := 0 cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) for !checkAllDifferent(cauchyMatrix) { nonce += 1 cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) } - var m [T][T]*big.Int + var m [T][T]*ff.Element for i := 0; i < T; i++ { - // var mi []*big.Int for j := 0; j < T; j++ { - m[i][j] = new(big.Int).Sub(cauchyMatrix[i], cauchyMatrix[T+j]) - m[i][j].ModInverse(m[i][j], constants.Q) + m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j]) + m[i][j].Inverse(m[i][j]) } } return m } -func checkAllDifferent(v []*big.Int) bool { +func checkAllDifferent(v []*ff.Element) bool { for i := 0; i < len(v); i++ { - if bytes.Equal(v[i].Bytes(), big.NewInt(int64(0)).Bytes()) { + if v[i].Equal(utils.NewElement().SetZero()) { return false } for j := i + 1; j < len(v); j++ { - if bytes.Equal(v[i].Bytes(), v[j].Bytes()) { + if v[i].Equal(v[j]) { return false } } @@ -95,22 +79,22 @@ func checkAllDifferent(v []*big.Int) bool { } // ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf -func ark(state [T]*big.Int, c *big.Int) { +func ark(state [T]*ff.Element, c *ff.Element) { for i := 0; i < T; i++ { - modQ(state[i].Add(state[i], c)) + state[i].Add(state[i], c) } } // cubic performs x^5 mod p // https://eprint.iacr.org/2019/458.pdf page 8 -var five = big.NewInt(5) +// var five = big.NewInt(5) -func cubic(a *big.Int) { - a.Exp(a, five, constants.Q) +func cubic(a *ff.Element) { + a.Exp(*a, 5) } // sbox https://eprint.iacr.org/2019/458.pdf page 6 -func sbox(state [T]*big.Int, i int) { +func sbox(state [T]*ff.Element, i int) { if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) { for j := 0; j < T; j++ { cubic(state[j]) @@ -121,30 +105,29 @@ func sbox(state [T]*big.Int, i int) { } // mix returns [[matrix]] * [vector] -func mix(state [T]*big.Int, newState [T]*big.Int, m [T][T]*big.Int) { +func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) { mul := Zero() for i := 0; i < T; i++ { - newState[i].SetInt64(0) + newState[i].SetUint64(0) for j := 0; j < T; j++ { - modQ(mul.Mul(m[i][j], state[j])) + mul.Mul(m[i][j], state[j]) newState[i].Add(newState[i], mul) } - modQ(newState[i]) } } // PoseidonHash computes the Poseidon hash for the given inputs -func PoseidonHash(inp [T]*big.Int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(inp[:], constants.Q) { +func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) { + if !utils.CheckElementArrayInField(inp[:]) { return nil, errors.New("inputs values not inside Finite Field") } - state := [T]*big.Int{} + state := [T]*ff.Element{} for i := 0; i < T; i++ { - state[i] = new(big.Int).Set(inp[i]) + state[i] = utils.NewElement().Set(inp[i]) } // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 - var newState [T]*big.Int + var newState [T]*ff.Element for i := 0; i < T; i++ { newState[i] = Zero() } @@ -157,16 +140,16 @@ func PoseidonHash(inp [T]*big.Int) (*big.Int, error) { return state[0], nil } -// Hash performs the Poseidon hash over a *big.Int array +// Hash performs the Poseidon hash over a ff.Element array // in chunks of 5 elements -func Hash(arr []*big.Int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(arr, constants.Q) { +func Hash(arr []*ff.Element) (*ff.Element, error) { + if !utils.CheckElementArrayInField(arr) { return nil, errors.New("inputs values not inside Finite Field") } - r := big.NewInt(1) + r := utils.NewElement().SetOne() for i := 0; i < len(arr); i = i + T - 1 { - var toHash [T]*big.Int + var toHash [T]*ff.Element j := 0 for ; j < T-1; j++ { if i+j >= len(arr) { @@ -177,14 +160,14 @@ func Hash(arr []*big.Int) (*big.Int, error) { toHash[j] = r j++ for ; j < T; j++ { - toHash[j] = constants.Zero + toHash[j] = Zero() } ph, err := PoseidonHash(toHash) if err != nil { return nil, err } - modQ(r.Add(r, ph)) + r.Add(r, ph) } return r, nil @@ -192,18 +175,19 @@ func Hash(arr []*big.Int) (*big.Int, error) { // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // little-endian -func HashBytes(b []byte) (*big.Int, error) { +func HashBytes(b []byte) (*ff.Element, error) { n := 31 - bElems := make([]*big.Int, 0, len(b)/n+1) + bElems := make([]*ff.Element, 0, len(b)/n+1) for i := 0; i < len(b)/n; i++ { - v := Zero() + v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) - bElems = append(bElems, v) + bElems = append(bElems, utils.NewElement().SetBigInt(v)) + } if len(b)%n != 0 { - v := Zero() + v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) - bElems = append(bElems, v) + bElems = append(bElems, utils.NewElement().SetBigInt(v)) } return Hash(bElems) } diff --git a/poseidon/poseidon_test.go b/poseidon/poseidon_test.go index de13104..b6791ef 100644 --- a/poseidon/poseidon_test.go +++ b/poseidon/poseidon_test.go @@ -5,6 +5,7 @@ import ( "math/big" "testing" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" "golang.org/x/crypto/blake2b" @@ -16,46 +17,46 @@ func TestBlake2bVersion(t *testing.T) { } func TestPoseidon(t *testing.T) { - b1 := big.NewInt(int64(1)) - b2 := big.NewInt(int64(2)) - h, err := Hash([]*big.Int{b1, b2}) + b1 := utils.NewElement().SetUint64(1) + b2 := utils.NewElement().SetUint64(2) + h, err := Hash([]*ff.Element{b1, b2}) assert.Nil(t, err) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) - b3 := big.NewInt(int64(3)) - b4 := big.NewInt(int64(4)) - h, err = Hash([]*big.Int{b3, b4}) + b3 := utils.NewElement().SetUint64(3) + b4 := utils.NewElement().SetUint64(4) + h, err = Hash([]*ff.Element{b3, b4}) assert.Nil(t, err) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") n := 31 - msgElems := make([]*big.Int, 0, len(msg)/n+1) + msgElems := make([]*ff.Element, 0, len(msg)/n+1) for i := 0; i < len(msg)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) - msgElems = append(msgElems, v) + msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) } if len(msg)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) - msgElems = append(msgElems, v) + msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) } hmsg, err := Hash(msgElems) assert.Nil(t, err) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") - msg2Elems := make([]*big.Int, 0, len(msg2)/n+1) + msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1) for i := 0; i < len(msg2)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) - msg2Elems = append(msg2Elems, v) + msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) } if len(msg2)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) - msg2Elems = append(msg2Elems, v) + msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) } hmsg2, err := Hash(msg2Elems) assert.Nil(t, err) @@ -67,29 +68,41 @@ func TestPoseidon(t *testing.T) { } func TestPoseidonBrokenChunks(t *testing.T) { - h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), - big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)}) + h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4), + utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)}) assert.Nil(t, err) - h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9), - big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}) + h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9), + utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func TestPoseidonBrokenPadding(t *testing.T) { - h1, err := Hash([]*big.Int{big.NewInt(1)}) + h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)}) assert.Nil(t, err) - h2, err := Hash([]*big.Int{big.NewInt(1), big.NewInt(0)}) + h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func BenchmarkPoseidon(b *testing.B) { - b12 := big.NewInt(int64(12)) - b45 := big.NewInt(int64(45)) - b78 := big.NewInt(int64(78)) - b41 := big.NewInt(int64(41)) - bigArray4 := []*big.Int{b12, b45, b78, b41} + b12 := utils.NewElement().SetUint64(12) + b45 := utils.NewElement().SetUint64(45) + b78 := utils.NewElement().SetUint64(78) + b41 := utils.NewElement().SetUint64(41) + bigArray4 := []*ff.Element{b12, b45, b78, b41} + + for i := 0; i < b.N; i++ { + Hash(bigArray4) + } +} + +func BenchmarkPoseidonLarge(b *testing.B) { + b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + bigArray4 := []*ff.Element{b12, b45, b78, b41} for i := 0; i < b.N; i++ { Hash(bigArray4) From 2a3f0d9ed58364ce0936739ea936c3bf763941fa Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 3 Mar 2020 16:32:49 +0100 Subject: [PATCH 4/4] Adapt babyjub/eddsa to new Poseidon methods --- babyjub/eddsa.go | 2 ++ ff/util.go | 5 +++ poseidon/poseidon.go | 47 ++++++++++++++------------ poseidon/poseidon_test.go | 70 +++++++++++++++++++++++---------------- utils/utils.go | 18 +++------- 5 files changed, 79 insertions(+), 63 deletions(-) create mode 100644 ff/util.go diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index 5870213..3093a4a 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -222,11 +222,13 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { r.Mod(r, SubOrder) R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B A := k.Public().Point() + hmInput := [poseidon.T]*big.Int{R8.X, R8.Y, A.X, A.Y, msg, big.NewInt(int64(0))} hm, err := poseidon.PoseidonHash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { panic(err) } + S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) S = S.Mul(hm, S) S.Add(r, S) diff --git a/ff/util.go b/ff/util.go new file mode 100644 index 0000000..501b6ae --- /dev/null +++ b/ff/util.go @@ -0,0 +1,5 @@ +package ff + +func NewElement() *Element { + return &Element{} +} diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index 79cd651..a5dc7b0 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -5,6 +5,7 @@ import ( "math/big" "strconv" + "github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "golang.org/x/crypto/blake2b" @@ -19,7 +20,11 @@ var constC []*ff.Element var constM [T][T]*ff.Element func Zero() *ff.Element { - return utils.NewElement().SetZero() + return ff.NewElement().SetZero() +} + +func modQ(v *big.Int) { + v.Mod(v, constants.Q) } func init() { @@ -32,7 +37,7 @@ func getPseudoRandom(seed string, n int) []*ff.Element { hash := blake2b.Sum256([]byte(seed)) for i := 0; i < n; i++ { hashBigInt := big.NewInt(int64(0)) - res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:])) + res[i] = ff.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:])) hash = blake2b.Sum256(hash[:]) } return res @@ -57,7 +62,7 @@ func getMDS() [T][T]*ff.Element { var m [T][T]*ff.Element for i := 0; i < T; i++ { for j := 0; j < T; j++ { - m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j]) + m[i][j] = ff.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j]) m[i][j].Inverse(m[i][j]) } } @@ -66,7 +71,7 @@ func getMDS() [T][T]*ff.Element { func checkAllDifferent(v []*ff.Element) bool { for i := 0; i < len(v); i++ { - if v[i].Equal(utils.NewElement().SetZero()) { + if v[i].Equal(ff.NewElement().SetZero()) { return false } for j := i + 1; j < len(v); j++ { @@ -117,13 +122,14 @@ func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) { } // PoseidonHash computes the Poseidon hash for the given inputs -func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) { - if !utils.CheckElementArrayInField(inp[:]) { +func PoseidonHash(inpBI [T]*big.Int) (*big.Int, error) { + if !utils.CheckBigIntArrayInField(inpBI[:]) { return nil, errors.New("inputs values not inside Finite Field") } + inp := utils.BigIntArrayToElementArray(inpBI[:]) state := [T]*ff.Element{} for i := 0; i < T; i++ { - state[i] = utils.NewElement().Set(inp[i]) + state[i] = ff.NewElement().Set(inp[i]) } // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 @@ -137,19 +143,18 @@ func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) { mix(state, newState, constM) state, newState = newState, state } - return state[0], nil + rE := state[0] + r := big.NewInt(0) + rE.ToBigIntRegular(r) + return r, nil } // Hash performs the Poseidon hash over a ff.Element array // in chunks of 5 elements -func Hash(arr []*ff.Element) (*ff.Element, error) { - if !utils.CheckElementArrayInField(arr) { - return nil, errors.New("inputs values not inside Finite Field") - } - - r := utils.NewElement().SetOne() +func Hash(arr []*big.Int) (*big.Int, error) { + r := big.NewInt(int64(1)) for i := 0; i < len(arr); i = i + T - 1 { - var toHash [T]*ff.Element + var toHash [T]*big.Int j := 0 for ; j < T-1; j++ { if i+j >= len(arr) { @@ -160,14 +165,14 @@ func Hash(arr []*ff.Element) (*ff.Element, error) { toHash[j] = r j++ for ; j < T; j++ { - toHash[j] = Zero() + toHash[j] = big.NewInt(0) } ph, err := PoseidonHash(toHash) if err != nil { return nil, err } - r.Add(r, ph) + modQ(r.Add(r, ph)) } return r, nil @@ -175,19 +180,19 @@ func Hash(arr []*ff.Element) (*ff.Element, error) { // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // little-endian -func HashBytes(b []byte) (*ff.Element, error) { +func HashBytes(b []byte) (*big.Int, error) { n := 31 - bElems := make([]*ff.Element, 0, len(b)/n+1) + bElems := make([]*big.Int, 0, len(b)/n+1) for i := 0; i < len(b)/n; i++ { v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) - bElems = append(bElems, utils.NewElement().SetBigInt(v)) + bElems = append(bElems, v) } if len(b)%n != 0 { v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) - bElems = append(bElems, utils.NewElement().SetBigInt(v)) + bElems = append(bElems, v) } return Hash(bElems) } diff --git a/poseidon/poseidon_test.go b/poseidon/poseidon_test.go index b6791ef..59016cd 100644 --- a/poseidon/poseidon_test.go +++ b/poseidon/poseidon_test.go @@ -5,7 +5,6 @@ import ( "math/big" "testing" - "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" "golang.org/x/crypto/blake2b" @@ -17,46 +16,58 @@ func TestBlake2bVersion(t *testing.T) { } func TestPoseidon(t *testing.T) { - b1 := utils.NewElement().SetUint64(1) - b2 := utils.NewElement().SetUint64(2) - h, err := Hash([]*ff.Element{b1, b2}) + b1 := big.NewInt(1) + b2 := big.NewInt(2) + h, err := Hash([]*big.Int{b1, b2}) assert.Nil(t, err) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) - b3 := utils.NewElement().SetUint64(3) - b4 := utils.NewElement().SetUint64(4) - h, err = Hash([]*ff.Element{b3, b4}) + b3 := big.NewInt(3) + b4 := big.NewInt(4) + h, err = Hash([]*big.Int{b3, b4}) assert.Nil(t, err) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) + b5 := big.NewInt(5) + b6 := big.NewInt(6) + b7 := big.NewInt(7) + b8 := big.NewInt(8) + b9 := big.NewInt(9) + b10 := big.NewInt(10) + b11 := big.NewInt(11) + b12 := big.NewInt(12) + h, err = Hash([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12}) + assert.Nil(t, err) + assert.Equal(t, "15278801138972282646981503374384603641625274360649669926363020545395022098027", h.String()) + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") n := 31 - msgElems := make([]*ff.Element, 0, len(msg)/n+1) + msgElems := make([]*big.Int, 0, len(msg)/n+1) for i := 0; i < len(msg)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) - msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) + msgElems = append(msgElems, v) } if len(msg)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) - msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) + msgElems = append(msgElems, v) } hmsg, err := Hash(msgElems) assert.Nil(t, err) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") - msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1) + msg2Elems := make([]*big.Int, 0, len(msg2)/n+1) for i := 0; i < len(msg2)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) - msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) + msg2Elems = append(msg2Elems, v) } if len(msg2)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) - msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) + msg2Elems = append(msg2Elems, v) } hmsg2, err := Hash(msg2Elems) assert.Nil(t, err) @@ -68,29 +79,29 @@ func TestPoseidon(t *testing.T) { } func TestPoseidonBrokenChunks(t *testing.T) { - h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4), - utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)}) + h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), + big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)}) assert.Nil(t, err) - h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9), - utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)}) + h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9), + big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func TestPoseidonBrokenPadding(t *testing.T) { - h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)}) + h1, err := Hash([]*big.Int{big.NewInt(int64(1))}) assert.Nil(t, err) - h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)}) + h2, err := Hash([]*big.Int{big.NewInt(int64(1)), big.NewInt(int64(0))}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func BenchmarkPoseidon(b *testing.B) { - b12 := utils.NewElement().SetUint64(12) - b45 := utils.NewElement().SetUint64(45) - b78 := utils.NewElement().SetUint64(78) - b41 := utils.NewElement().SetUint64(41) - bigArray4 := []*ff.Element{b12, b45, b78, b41} + b12 := big.NewInt(int64(12)) + b45 := big.NewInt(int64(45)) + b78 := big.NewInt(int64(78)) + b41 := big.NewInt(int64(41)) + bigArray4 := []*big.Int{b12, b45, b78, b41} for i := 0; i < b.N; i++ { Hash(bigArray4) @@ -98,11 +109,12 @@ func BenchmarkPoseidon(b *testing.B) { } func BenchmarkPoseidonLarge(b *testing.B) { - b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") - b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") - b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") - b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") - bigArray4 := []*ff.Element{b12, b45, b78, b41} + b12 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b45 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b78 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b41 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + + bigArray4 := []*big.Int{b12, b45, b78, b41} for i := 0; i < b.N; i++ { Hash(bigArray4) diff --git a/utils/utils.go b/utils/utils.go index 188d0ff..84628cc 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -108,18 +108,10 @@ func CheckBigIntArrayInField(arr []*big.Int) bool { return true } -// CheckElementArrayInField checks if given *ff.Element fits in a Field Q element -func CheckElementArrayInField(arr []*ff.Element) bool { - for _, aE := range arr { - a := big.NewInt(0) - aE.ToBigIntRegular(a) - if !CheckBigIntInField(a) { - return false - } +func BigIntArrayToElementArray(bi []*big.Int) []*ff.Element { + var o []*ff.Element + for i := range bi { + o = append(o, ff.NewElement().SetBigInt(bi[i])) } - return true -} - -func NewElement() *ff.Element { - return &ff.Element{0, 0, 0, 0} + return o }