Optimize Poseidon migrating from *big.Int to goff

Optimize Poseidon migrating from *big.Int to goff generated finite field
operations.

Benchmarks:
Tested on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM.

- Before the optimizations:
```
BenchmarkPoseidon-4                  470           2489678 ns/op
BenchmarkPoseidonLarge-4             476           2530568 ns/op
```

- With the optimizations of #12:
```
BenchmarkPoseidon-4                  766           1550013 ns/op
BenchmarkPoseidonLarge-4             782           1547572 ns/op
```

- With the changes of this PR, where uses goff generated code instead of *big.Int:
```
BenchmarkPoseidon-4                 9638            121651 ns/op
BenchmarkPoseidonLarge-4            9781            119921 ns/op
```
This commit is contained in:
arnaucube
2020-03-03 16:31:40 +01:00
parent 83f87bfa46
commit b45d8a582b
2 changed files with 80 additions and 83 deletions

View File

@@ -1,12 +1,11 @@
package poseidon package poseidon
import ( import (
"bytes"
"errors" "errors"
"math/big" "math/big"
"strconv" "strconv"
"github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
) )
@@ -16,15 +15,11 @@ const NROUNDSF = 8
const NROUNDSP = 57 const NROUNDSP = 57
const T = 6 const T = 6
var constC []*big.Int var constC []*ff.Element
var constM [T][T]*big.Int var constM [T][T]*ff.Element
func Zero() *big.Int { func Zero() *ff.Element {
return new(big.Int) return utils.NewElement().SetZero()
}
func modQ(v *big.Int) {
v.Mod(v, constants.Q)
} }
func init() { func init() {
@@ -32,22 +27,12 @@ func init() {
constM = getMDS() constM = getMDS()
} }
func leByteArrayToBigInt(b []byte) *big.Int { func getPseudoRandom(seed string, n int) []*ff.Element {
res := big.NewInt(0) res := make([]*ff.Element, n)
for i := 0; i < len(b); i++ {
n := big.NewInt(int64(b[i]))
res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8)))
}
return res
}
func getPseudoRandom(seed string, n int) []*big.Int {
res := make([]*big.Int, n)
hash := blake2b.Sum256([]byte(seed)) hash := blake2b.Sum256([]byte(seed))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
hashBigInt := Zero() hashBigInt := big.NewInt(int64(0))
res[i] = utils.SetBigIntFromLEBytes(hashBigInt, hash[:]) res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:]))
modQ(res[i])
hash = blake2b.Sum256(hash[:]) hash = blake2b.Sum256(hash[:])
} }
return res return res
@@ -62,31 +47,30 @@ func nonceToString(n int) string {
} }
// https://eprint.iacr.org/2019/458.pdf pag.8 // https://eprint.iacr.org/2019/458.pdf pag.8
func getMDS() [T][T]*big.Int { func getMDS() [T][T]*ff.Element {
nonce := 0 nonce := 0
cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
for !checkAllDifferent(cauchyMatrix) { for !checkAllDifferent(cauchyMatrix) {
nonce += 1 nonce += 1
cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
} }
var m [T][T]*big.Int var m [T][T]*ff.Element
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
// var mi []*big.Int
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
m[i][j] = new(big.Int).Sub(cauchyMatrix[i], cauchyMatrix[T+j]) m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].ModInverse(m[i][j], constants.Q) m[i][j].Inverse(m[i][j])
} }
} }
return m return m
} }
func checkAllDifferent(v []*big.Int) bool { func checkAllDifferent(v []*ff.Element) bool {
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {
if bytes.Equal(v[i].Bytes(), big.NewInt(int64(0)).Bytes()) { if v[i].Equal(utils.NewElement().SetZero()) {
return false return false
} }
for j := i + 1; j < len(v); j++ { for j := i + 1; j < len(v); j++ {
if bytes.Equal(v[i].Bytes(), v[j].Bytes()) { if v[i].Equal(v[j]) {
return false return false
} }
} }
@@ -95,22 +79,22 @@ func checkAllDifferent(v []*big.Int) bool {
} }
// ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf // ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf
func ark(state [T]*big.Int, c *big.Int) { func ark(state [T]*ff.Element, c *ff.Element) {
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
modQ(state[i].Add(state[i], c)) state[i].Add(state[i], c)
} }
} }
// cubic performs x^5 mod p // cubic performs x^5 mod p
// https://eprint.iacr.org/2019/458.pdf page 8 // https://eprint.iacr.org/2019/458.pdf page 8
var five = big.NewInt(5) // var five = big.NewInt(5)
func cubic(a *big.Int) { func cubic(a *ff.Element) {
a.Exp(a, five, constants.Q) a.Exp(*a, 5)
} }
// sbox https://eprint.iacr.org/2019/458.pdf page 6 // sbox https://eprint.iacr.org/2019/458.pdf page 6
func sbox(state [T]*big.Int, i int) { func sbox(state [T]*ff.Element, i int) {
if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) { if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) {
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
cubic(state[j]) cubic(state[j])
@@ -121,30 +105,29 @@ func sbox(state [T]*big.Int, i int) {
} }
// mix returns [[matrix]] * [vector] // mix returns [[matrix]] * [vector]
func mix(state [T]*big.Int, newState [T]*big.Int, m [T][T]*big.Int) { func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) {
mul := Zero() mul := Zero()
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
newState[i].SetInt64(0) newState[i].SetUint64(0)
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
modQ(mul.Mul(m[i][j], state[j])) mul.Mul(m[i][j], state[j])
newState[i].Add(newState[i], mul) newState[i].Add(newState[i], mul)
} }
modQ(newState[i])
} }
} }
// PoseidonHash computes the Poseidon hash for the given inputs // PoseidonHash computes the Poseidon hash for the given inputs
func PoseidonHash(inp [T]*big.Int) (*big.Int, error) { func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) {
if !utils.CheckBigIntArrayInField(inp[:], constants.Q) { if !utils.CheckElementArrayInField(inp[:]) {
return nil, errors.New("inputs values not inside Finite Field") return nil, errors.New("inputs values not inside Finite Field")
} }
state := [T]*big.Int{} state := [T]*ff.Element{}
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
state[i] = new(big.Int).Set(inp[i]) state[i] = utils.NewElement().Set(inp[i])
} }
// ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5
var newState [T]*big.Int var newState [T]*ff.Element
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
newState[i] = Zero() newState[i] = Zero()
} }
@@ -157,16 +140,16 @@ func PoseidonHash(inp [T]*big.Int) (*big.Int, error) {
return state[0], nil return state[0], nil
} }
// Hash performs the Poseidon hash over a *big.Int array // Hash performs the Poseidon hash over a ff.Element array
// in chunks of 5 elements // in chunks of 5 elements
func Hash(arr []*big.Int) (*big.Int, error) { func Hash(arr []*ff.Element) (*ff.Element, error) {
if !utils.CheckBigIntArrayInField(arr, constants.Q) { if !utils.CheckElementArrayInField(arr) {
return nil, errors.New("inputs values not inside Finite Field") return nil, errors.New("inputs values not inside Finite Field")
} }
r := big.NewInt(1) r := utils.NewElement().SetOne()
for i := 0; i < len(arr); i = i + T - 1 { for i := 0; i < len(arr); i = i + T - 1 {
var toHash [T]*big.Int var toHash [T]*ff.Element
j := 0 j := 0
for ; j < T-1; j++ { for ; j < T-1; j++ {
if i+j >= len(arr) { if i+j >= len(arr) {
@@ -177,14 +160,14 @@ func Hash(arr []*big.Int) (*big.Int, error) {
toHash[j] = r toHash[j] = r
j++ j++
for ; j < T; j++ { for ; j < T; j++ {
toHash[j] = constants.Zero toHash[j] = Zero()
} }
ph, err := PoseidonHash(toHash) ph, err := PoseidonHash(toHash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modQ(r.Add(r, ph)) r.Add(r, ph)
} }
return r, nil return r, nil
@@ -192,18 +175,19 @@ func Hash(arr []*big.Int) (*big.Int, error) {
// HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as
// little-endian // little-endian
func HashBytes(b []byte) (*big.Int, error) { func HashBytes(b []byte) (*ff.Element, error) {
n := 31 n := 31
bElems := make([]*big.Int, 0, len(b)/n+1) bElems := make([]*ff.Element, 0, len(b)/n+1)
for i := 0; i < len(b)/n; i++ { for i := 0; i < len(b)/n; i++ {
v := Zero() v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)])
bElems = append(bElems, v) bElems = append(bElems, utils.NewElement().SetBigInt(v))
} }
if len(b)%n != 0 { if len(b)%n != 0 {
v := Zero() v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:])
bElems = append(bElems, v) bElems = append(bElems, utils.NewElement().SetBigInt(v))
} }
return Hash(bElems) return Hash(bElems)
} }

View File

@@ -5,6 +5,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
@@ -16,46 +17,46 @@ func TestBlake2bVersion(t *testing.T) {
} }
func TestPoseidon(t *testing.T) { func TestPoseidon(t *testing.T) {
b1 := big.NewInt(int64(1)) b1 := utils.NewElement().SetUint64(1)
b2 := big.NewInt(int64(2)) b2 := utils.NewElement().SetUint64(2)
h, err := Hash([]*big.Int{b1, b2}) h, err := Hash([]*ff.Element{b1, b2})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String())
b3 := big.NewInt(int64(3)) b3 := utils.NewElement().SetUint64(3)
b4 := big.NewInt(int64(4)) b4 := utils.NewElement().SetUint64(4)
h, err = Hash([]*big.Int{b3, b4}) h, err = Hash([]*ff.Element{b3, b4})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String())
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.")
n := 31 n := 31
msgElems := make([]*big.Int, 0, len(msg)/n+1) msgElems := make([]*ff.Element, 0, len(msg)/n+1)
for i := 0; i < len(msg)/n; i++ { for i := 0; i < len(msg)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)])
msgElems = append(msgElems, v) msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
} }
if len(msg)%n != 0 { if len(msg)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:])
msgElems = append(msgElems, v) msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
} }
hmsg, err := Hash(msgElems) hmsg, err := Hash(msgElems)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String())
msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.")
msg2Elems := make([]*big.Int, 0, len(msg2)/n+1) msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1)
for i := 0; i < len(msg2)/n; i++ { for i := 0; i < len(msg2)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)])
msg2Elems = append(msg2Elems, v) msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
} }
if len(msg2)%n != 0 { if len(msg2)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:])
msg2Elems = append(msg2Elems, v) msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
} }
hmsg2, err := Hash(msg2Elems) hmsg2, err := Hash(msg2Elems)
assert.Nil(t, err) assert.Nil(t, err)
@@ -67,29 +68,41 @@ func TestPoseidon(t *testing.T) {
} }
func TestPoseidonBrokenChunks(t *testing.T) { func TestPoseidonBrokenChunks(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4),
big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)}) utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9), h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9),
big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}) utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func TestPoseidonBrokenPadding(t *testing.T) { func TestPoseidonBrokenPadding(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(1)}) h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(1), big.NewInt(0)}) h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func BenchmarkPoseidon(b *testing.B) { func BenchmarkPoseidon(b *testing.B) {
b12 := big.NewInt(int64(12)) b12 := utils.NewElement().SetUint64(12)
b45 := big.NewInt(int64(45)) b45 := utils.NewElement().SetUint64(45)
b78 := big.NewInt(int64(78)) b78 := utils.NewElement().SetUint64(78)
b41 := big.NewInt(int64(41)) b41 := utils.NewElement().SetUint64(41)
bigArray4 := []*big.Int{b12, b45, b78, b41} bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ {
Hash(bigArray4)
}
}
func BenchmarkPoseidonLarge(b *testing.B) {
b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Hash(bigArray4) Hash(bigArray4)