Browse Source

Optimize Poseidon migrating from *big.Int to goff

Optimize Poseidon migrating from *big.Int to goff generated finite field
operations.

Benchmarks:
Tested on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM.

- Before the optimizations:
```
BenchmarkPoseidon-4                  470           2489678 ns/op
BenchmarkPoseidonLarge-4             476           2530568 ns/op
```

- With the optimizations of #12:
```
BenchmarkPoseidon-4                  766           1550013 ns/op
BenchmarkPoseidonLarge-4             782           1547572 ns/op
```

- With the changes of this PR, where uses goff generated code instead of *big.Int:
```
BenchmarkPoseidon-4                 9638            121651 ns/op
BenchmarkPoseidonLarge-4            9781            119921 ns/op
```
feature/poseidon-opt-goff
arnaucube 4 years ago
parent
commit
b45d8a582b
2 changed files with 80 additions and 83 deletions
  1. +44
    -60
      poseidon/poseidon.go
  2. +36
    -23
      poseidon/poseidon_test.go

+ 44
- 60
poseidon/poseidon.go

@ -1,12 +1,11 @@
package poseidon
import (
"bytes"
"errors"
"math/big"
"strconv"
"github.com/iden3/go-iden3-crypto/constants"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils"
"golang.org/x/crypto/blake2b"
)
@ -16,15 +15,11 @@ const NROUNDSF = 8
const NROUNDSP = 57
const T = 6
var constC []*big.Int
var constM [T][T]*big.Int
var constC []*ff.Element
var constM [T][T]*ff.Element
func Zero() *big.Int {
return new(big.Int)
}
func modQ(v *big.Int) {
v.Mod(v, constants.Q)
func Zero() *ff.Element {
return utils.NewElement().SetZero()
}
func init() {
@ -32,22 +27,12 @@ func init() {
constM = getMDS()
}
func leByteArrayToBigInt(b []byte) *big.Int {
res := big.NewInt(0)
for i := 0; i < len(b); i++ {
n := big.NewInt(int64(b[i]))
res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8)))
}
return res
}
func getPseudoRandom(seed string, n int) []*big.Int {
res := make([]*big.Int, n)
func getPseudoRandom(seed string, n int) []*ff.Element {
res := make([]*ff.Element, n)
hash := blake2b.Sum256([]byte(seed))
for i := 0; i < n; i++ {
hashBigInt := Zero()
res[i] = utils.SetBigIntFromLEBytes(hashBigInt, hash[:])
modQ(res[i])
hashBigInt := big.NewInt(int64(0))
res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:]))
hash = blake2b.Sum256(hash[:])
}
return res
@ -62,31 +47,30 @@ func nonceToString(n int) string {
}
// https://eprint.iacr.org/2019/458.pdf pag.8
func getMDS() [T][T]*big.Int {
func getMDS() [T][T]*ff.Element {
nonce := 0
cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
for !checkAllDifferent(cauchyMatrix) {
nonce += 1
cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
}
var m [T][T]*big.Int
var m [T][T]*ff.Element
for i := 0; i < T; i++ {
// var mi []*big.Int
for j := 0; j < T; j++ {
m[i][j] = new(big.Int).Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].ModInverse(m[i][j], constants.Q)
m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].Inverse(m[i][j])
}
}
return m
}
func checkAllDifferent(v []*big.Int) bool {
func checkAllDifferent(v []*ff.Element) bool {
for i := 0; i < len(v); i++ {
if bytes.Equal(v[i].Bytes(), big.NewInt(int64(0)).Bytes()) {
if v[i].Equal(utils.NewElement().SetZero()) {
return false
}
for j := i + 1; j < len(v); j++ {
if bytes.Equal(v[i].Bytes(), v[j].Bytes()) {
if v[i].Equal(v[j]) {
return false
}
}
@ -95,22 +79,22 @@ func checkAllDifferent(v []*big.Int) bool {
}
// ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf
func ark(state [T]*big.Int, c *big.Int) {
func ark(state [T]*ff.Element, c *ff.Element) {
for i := 0; i < T; i++ {
modQ(state[i].Add(state[i], c))
state[i].Add(state[i], c)
}
}
// cubic performs x^5 mod p
// https://eprint.iacr.org/2019/458.pdf page 8
var five = big.NewInt(5)
// var five = big.NewInt(5)
func cubic(a *big.Int) {
a.Exp(a, five, constants.Q)
func cubic(a *ff.Element) {
a.Exp(*a, 5)
}
// sbox https://eprint.iacr.org/2019/458.pdf page 6
func sbox(state [T]*big.Int, i int) {
func sbox(state [T]*ff.Element, i int) {
if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) {
for j := 0; j < T; j++ {
cubic(state[j])
@ -121,30 +105,29 @@ func sbox(state [T]*big.Int, i int) {
}
// mix returns [[matrix]] * [vector]
func mix(state [T]*big.Int, newState [T]*big.Int, m [T][T]*big.Int) {
func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) {
mul := Zero()
for i := 0; i < T; i++ {
newState[i].SetInt64(0)
newState[i].SetUint64(0)
for j := 0; j < T; j++ {
modQ(mul.Mul(m[i][j], state[j]))
mul.Mul(m[i][j], state[j])
newState[i].Add(newState[i], mul)
}
modQ(newState[i])
}
}
// PoseidonHash computes the Poseidon hash for the given inputs
func PoseidonHash(inp [T]*big.Int) (*big.Int, error) {
if !utils.CheckBigIntArrayInField(inp[:], constants.Q) {
func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(inp[:]) {
return nil, errors.New("inputs values not inside Finite Field")
}
state := [T]*big.Int{}
state := [T]*ff.Element{}
for i := 0; i < T; i++ {
state[i] = new(big.Int).Set(inp[i])
state[i] = utils.NewElement().Set(inp[i])
}
// ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5
var newState [T]*big.Int
var newState [T]*ff.Element
for i := 0; i < T; i++ {
newState[i] = Zero()
}
@ -157,16 +140,16 @@ func PoseidonHash(inp [T]*big.Int) (*big.Int, error) {
return state[0], nil
}
// Hash performs the Poseidon hash over a *big.Int array
// Hash performs the Poseidon hash over a ff.Element array
// in chunks of 5 elements
func Hash(arr []*big.Int) (*big.Int, error) {
if !utils.CheckBigIntArrayInField(arr, constants.Q) {
func Hash(arr []*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(arr) {
return nil, errors.New("inputs values not inside Finite Field")
}
r := big.NewInt(1)
r := utils.NewElement().SetOne()
for i := 0; i < len(arr); i = i + T - 1 {
var toHash [T]*big.Int
var toHash [T]*ff.Element
j := 0
for ; j < T-1; j++ {
if i+j >= len(arr) {
@ -177,14 +160,14 @@ func Hash(arr []*big.Int) (*big.Int, error) {
toHash[j] = r
j++
for ; j < T; j++ {
toHash[j] = constants.Zero
toHash[j] = Zero()
}
ph, err := PoseidonHash(toHash)
if err != nil {
return nil, err
}
modQ(r.Add(r, ph))
r.Add(r, ph)
}
return r, nil
@ -192,18 +175,19 @@ func Hash(arr []*big.Int) (*big.Int, error) {
// HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as
// little-endian
func HashBytes(b []byte) (*big.Int, error) {
func HashBytes(b []byte) (*ff.Element, error) {
n := 31
bElems := make([]*big.Int, 0, len(b)/n+1)
bElems := make([]*ff.Element, 0, len(b)/n+1)
for i := 0; i < len(b)/n; i++ {
v := Zero()
v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)])
bElems = append(bElems, v)
bElems = append(bElems, utils.NewElement().SetBigInt(v))
}
if len(b)%n != 0 {
v := Zero()
v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:])
bElems = append(bElems, v)
bElems = append(bElems, utils.NewElement().SetBigInt(v))
}
return Hash(bElems)
}

+ 36
- 23
poseidon/poseidon_test.go

@ -5,6 +5,7 @@ import (
"math/big"
"testing"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/blake2b"
@ -16,46 +17,46 @@ func TestBlake2bVersion(t *testing.T) {
}
func TestPoseidon(t *testing.T) {
b1 := big.NewInt(int64(1))
b2 := big.NewInt(int64(2))
h, err := Hash([]*big.Int{b1, b2})
b1 := utils.NewElement().SetUint64(1)
b2 := utils.NewElement().SetUint64(2)
h, err := Hash([]*ff.Element{b1, b2})
assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String())
b3 := big.NewInt(int64(3))
b4 := big.NewInt(int64(4))
h, err = Hash([]*big.Int{b3, b4})
b3 := utils.NewElement().SetUint64(3)
b4 := utils.NewElement().SetUint64(4)
h, err = Hash([]*ff.Element{b3, b4})
assert.Nil(t, err)
assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String())
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.")
n := 31
msgElems := make([]*big.Int, 0, len(msg)/n+1)
msgElems := make([]*ff.Element, 0, len(msg)/n+1)
for i := 0; i < len(msg)/n; i++ {
v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)])
msgElems = append(msgElems, v)
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
}
if len(msg)%n != 0 {
v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:])
msgElems = append(msgElems, v)
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
}
hmsg, err := Hash(msgElems)
assert.Nil(t, err)
assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String())
msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.")
msg2Elems := make([]*big.Int, 0, len(msg2)/n+1)
msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1)
for i := 0; i < len(msg2)/n; i++ {
v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)])
msg2Elems = append(msg2Elems, v)
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
}
if len(msg2)%n != 0 {
v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:])
msg2Elems = append(msg2Elems, v)
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
}
hmsg2, err := Hash(msg2Elems)
assert.Nil(t, err)
@ -67,29 +68,41 @@ func TestPoseidon(t *testing.T) {
}
func TestPoseidonBrokenChunks(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4),
big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)})
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4),
utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)})
assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9),
big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)})
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9),
utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)})
assert.Nil(t, err)
assert.NotEqual(t, h1, h2)
}
func TestPoseidonBrokenPadding(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(1)})
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)})
assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(1), big.NewInt(0)})
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)})
assert.Nil(t, err)
assert.NotEqual(t, h1, h2)
}
func BenchmarkPoseidon(b *testing.B) {
b12 := big.NewInt(int64(12))
b45 := big.NewInt(int64(45))
b78 := big.NewInt(int64(78))
b41 := big.NewInt(int64(41))
bigArray4 := []*big.Int{b12, b45, b78, b41}
b12 := utils.NewElement().SetUint64(12)
b45 := utils.NewElement().SetUint64(45)
b78 := utils.NewElement().SetUint64(78)
b41 := utils.NewElement().SetUint64(41)
bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ {
Hash(bigArray4)
}
}
func BenchmarkPoseidonLarge(b *testing.B) {
b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ {
Hash(bigArray4)

Loading…
Cancel
Save