Browse Source

Optimize Poseidon migrating from *big.Int to goff

Optimize Poseidon migrating from *big.Int to goff generated finite field
operations.

Benchmarks:
Tested on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM.

- Before the optimizations:
```
BenchmarkPoseidon-4                  470           2489678 ns/op
BenchmarkPoseidonLarge-4             476           2530568 ns/op
```

- With the optimizations of #12:
```
BenchmarkPoseidon-4                  766           1550013 ns/op
BenchmarkPoseidonLarge-4             782           1547572 ns/op
```

- With the changes of this PR, where uses goff generated code instead of *big.Int:
```
BenchmarkPoseidon-4                 9638            121651 ns/op
BenchmarkPoseidonLarge-4            9781            119921 ns/op
```
feature/poseidon-opt-goff
arnaucube 4 years ago
parent
commit
b45d8a582b
2 changed files with 80 additions and 83 deletions
  1. +44
    -60
      poseidon/poseidon.go
  2. +36
    -23
      poseidon/poseidon_test.go

+ 44
- 60
poseidon/poseidon.go

@ -1,12 +1,11 @@
package poseidon package poseidon
import ( import (
"bytes"
"errors" "errors"
"math/big" "math/big"
"strconv" "strconv"
"github.com/iden3/go-iden3-crypto/constants"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
) )
@ -16,15 +15,11 @@ const NROUNDSF = 8
const NROUNDSP = 57 const NROUNDSP = 57
const T = 6 const T = 6
var constC []*big.Int
var constM [T][T]*big.Int
var constC []*ff.Element
var constM [T][T]*ff.Element
func Zero() *big.Int {
return new(big.Int)
}
func modQ(v *big.Int) {
v.Mod(v, constants.Q)
func Zero() *ff.Element {
return utils.NewElement().SetZero()
} }
func init() { func init() {
@ -32,22 +27,12 @@ func init() {
constM = getMDS() constM = getMDS()
} }
func leByteArrayToBigInt(b []byte) *big.Int {
res := big.NewInt(0)
for i := 0; i < len(b); i++ {
n := big.NewInt(int64(b[i]))
res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8)))
}
return res
}
func getPseudoRandom(seed string, n int) []*big.Int {
res := make([]*big.Int, n)
func getPseudoRandom(seed string, n int) []*ff.Element {
res := make([]*ff.Element, n)
hash := blake2b.Sum256([]byte(seed)) hash := blake2b.Sum256([]byte(seed))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
hashBigInt := Zero()
res[i] = utils.SetBigIntFromLEBytes(hashBigInt, hash[:])
modQ(res[i])
hashBigInt := big.NewInt(int64(0))
res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:]))
hash = blake2b.Sum256(hash[:]) hash = blake2b.Sum256(hash[:])
} }
return res return res
@ -62,31 +47,30 @@ func nonceToString(n int) string {
} }
// https://eprint.iacr.org/2019/458.pdf pag.8 // https://eprint.iacr.org/2019/458.pdf pag.8
func getMDS() [T][T]*big.Int {
func getMDS() [T][T]*ff.Element {
nonce := 0 nonce := 0
cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
for !checkAllDifferent(cauchyMatrix) { for !checkAllDifferent(cauchyMatrix) {
nonce += 1 nonce += 1
cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2)
} }
var m [T][T]*big.Int
var m [T][T]*ff.Element
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
// var mi []*big.Int
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
m[i][j] = new(big.Int).Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].ModInverse(m[i][j], constants.Q)
m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].Inverse(m[i][j])
} }
} }
return m return m
} }
func checkAllDifferent(v []*big.Int) bool {
func checkAllDifferent(v []*ff.Element) bool {
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {
if bytes.Equal(v[i].Bytes(), big.NewInt(int64(0)).Bytes()) {
if v[i].Equal(utils.NewElement().SetZero()) {
return false return false
} }
for j := i + 1; j < len(v); j++ { for j := i + 1; j < len(v); j++ {
if bytes.Equal(v[i].Bytes(), v[j].Bytes()) {
if v[i].Equal(v[j]) {
return false return false
} }
} }
@ -95,22 +79,22 @@ func checkAllDifferent(v []*big.Int) bool {
} }
// ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf // ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf
func ark(state [T]*big.Int, c *big.Int) {
func ark(state [T]*ff.Element, c *ff.Element) {
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
modQ(state[i].Add(state[i], c))
state[i].Add(state[i], c)
} }
} }
// cubic performs x^5 mod p // cubic performs x^5 mod p
// https://eprint.iacr.org/2019/458.pdf page 8 // https://eprint.iacr.org/2019/458.pdf page 8
var five = big.NewInt(5)
// var five = big.NewInt(5)
func cubic(a *big.Int) {
a.Exp(a, five, constants.Q)
func cubic(a *ff.Element) {
a.Exp(*a, 5)
} }
// sbox https://eprint.iacr.org/2019/458.pdf page 6 // sbox https://eprint.iacr.org/2019/458.pdf page 6
func sbox(state [T]*big.Int, i int) {
func sbox(state [T]*ff.Element, i int) {
if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) { if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) {
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
cubic(state[j]) cubic(state[j])
@ -121,30 +105,29 @@ func sbox(state [T]*big.Int, i int) {
} }
// mix returns [[matrix]] * [vector] // mix returns [[matrix]] * [vector]
func mix(state [T]*big.Int, newState [T]*big.Int, m [T][T]*big.Int) {
func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) {
mul := Zero() mul := Zero()
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
newState[i].SetInt64(0)
newState[i].SetUint64(0)
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
modQ(mul.Mul(m[i][j], state[j]))
mul.Mul(m[i][j], state[j])
newState[i].Add(newState[i], mul) newState[i].Add(newState[i], mul)
} }
modQ(newState[i])
} }
} }
// PoseidonHash computes the Poseidon hash for the given inputs // PoseidonHash computes the Poseidon hash for the given inputs
func PoseidonHash(inp [T]*big.Int) (*big.Int, error) {
if !utils.CheckBigIntArrayInField(inp[:], constants.Q) {
func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(inp[:]) {
return nil, errors.New("inputs values not inside Finite Field") return nil, errors.New("inputs values not inside Finite Field")
} }
state := [T]*big.Int{}
state := [T]*ff.Element{}
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
state[i] = new(big.Int).Set(inp[i])
state[i] = utils.NewElement().Set(inp[i])
} }
// ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5
var newState [T]*big.Int
var newState [T]*ff.Element
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
newState[i] = Zero() newState[i] = Zero()
} }
@ -157,16 +140,16 @@ func PoseidonHash(inp [T]*big.Int) (*big.Int, error) {
return state[0], nil return state[0], nil
} }
// Hash performs the Poseidon hash over a *big.Int array
// Hash performs the Poseidon hash over a ff.Element array
// in chunks of 5 elements // in chunks of 5 elements
func Hash(arr []*big.Int) (*big.Int, error) {
if !utils.CheckBigIntArrayInField(arr, constants.Q) {
func Hash(arr []*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(arr) {
return nil, errors.New("inputs values not inside Finite Field") return nil, errors.New("inputs values not inside Finite Field")
} }
r := big.NewInt(1)
r := utils.NewElement().SetOne()
for i := 0; i < len(arr); i = i + T - 1 { for i := 0; i < len(arr); i = i + T - 1 {
var toHash [T]*big.Int
var toHash [T]*ff.Element
j := 0 j := 0
for ; j < T-1; j++ { for ; j < T-1; j++ {
if i+j >= len(arr) { if i+j >= len(arr) {
@ -177,14 +160,14 @@ func Hash(arr []*big.Int) (*big.Int, error) {
toHash[j] = r toHash[j] = r
j++ j++
for ; j < T; j++ { for ; j < T; j++ {
toHash[j] = constants.Zero
toHash[j] = Zero()
} }
ph, err := PoseidonHash(toHash) ph, err := PoseidonHash(toHash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modQ(r.Add(r, ph))
r.Add(r, ph)
} }
return r, nil return r, nil
@ -192,18 +175,19 @@ func Hash(arr []*big.Int) (*big.Int, error) {
// HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as
// little-endian // little-endian
func HashBytes(b []byte) (*big.Int, error) {
func HashBytes(b []byte) (*ff.Element, error) {
n := 31 n := 31
bElems := make([]*big.Int, 0, len(b)/n+1)
bElems := make([]*ff.Element, 0, len(b)/n+1)
for i := 0; i < len(b)/n; i++ { for i := 0; i < len(b)/n; i++ {
v := Zero()
v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)])
bElems = append(bElems, v)
bElems = append(bElems, utils.NewElement().SetBigInt(v))
} }
if len(b)%n != 0 { if len(b)%n != 0 {
v := Zero()
v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:])
bElems = append(bElems, v)
bElems = append(bElems, utils.NewElement().SetBigInt(v))
} }
return Hash(bElems) return Hash(bElems)
} }

+ 36
- 23
poseidon/poseidon_test.go

@ -5,6 +5,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
@ -16,46 +17,46 @@ func TestBlake2bVersion(t *testing.T) {
} }
func TestPoseidon(t *testing.T) { func TestPoseidon(t *testing.T) {
b1 := big.NewInt(int64(1))
b2 := big.NewInt(int64(2))
h, err := Hash([]*big.Int{b1, b2})
b1 := utils.NewElement().SetUint64(1)
b2 := utils.NewElement().SetUint64(2)
h, err := Hash([]*ff.Element{b1, b2})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String())
b3 := big.NewInt(int64(3))
b4 := big.NewInt(int64(4))
h, err = Hash([]*big.Int{b3, b4})
b3 := utils.NewElement().SetUint64(3)
b4 := utils.NewElement().SetUint64(4)
h, err = Hash([]*ff.Element{b3, b4})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String())
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.")
n := 31 n := 31
msgElems := make([]*big.Int, 0, len(msg)/n+1)
msgElems := make([]*ff.Element, 0, len(msg)/n+1)
for i := 0; i < len(msg)/n; i++ { for i := 0; i < len(msg)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)])
msgElems = append(msgElems, v)
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
} }
if len(msg)%n != 0 { if len(msg)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:])
msgElems = append(msgElems, v)
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
} }
hmsg, err := Hash(msgElems) hmsg, err := Hash(msgElems)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String())
msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.")
msg2Elems := make([]*big.Int, 0, len(msg2)/n+1)
msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1)
for i := 0; i < len(msg2)/n; i++ { for i := 0; i < len(msg2)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)])
msg2Elems = append(msg2Elems, v)
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
} }
if len(msg2)%n != 0 { if len(msg2)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:])
msg2Elems = append(msg2Elems, v)
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
} }
hmsg2, err := Hash(msg2Elems) hmsg2, err := Hash(msg2Elems)
assert.Nil(t, err) assert.Nil(t, err)
@ -67,29 +68,41 @@ func TestPoseidon(t *testing.T) {
} }
func TestPoseidonBrokenChunks(t *testing.T) { func TestPoseidonBrokenChunks(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4),
big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)})
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4),
utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9),
big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)})
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9),
utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func TestPoseidonBrokenPadding(t *testing.T) { func TestPoseidonBrokenPadding(t *testing.T) {
h1, err := Hash([]*big.Int{big.NewInt(1)})
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*big.Int{big.NewInt(1), big.NewInt(0)})
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func BenchmarkPoseidon(b *testing.B) { func BenchmarkPoseidon(b *testing.B) {
b12 := big.NewInt(int64(12))
b45 := big.NewInt(int64(45))
b78 := big.NewInt(int64(78))
b41 := big.NewInt(int64(41))
bigArray4 := []*big.Int{b12, b45, b78, b41}
b12 := utils.NewElement().SetUint64(12)
b45 := utils.NewElement().SetUint64(45)
b78 := utils.NewElement().SetUint64(78)
b41 := utils.NewElement().SetUint64(41)
bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ {
Hash(bigArray4)
}
}
func BenchmarkPoseidonLarge(b *testing.B) {
b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
bigArray4 := []*ff.Element{b12, b45, b78, b41}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Hash(bigArray4) Hash(bigArray4)

Loading…
Cancel
Save