From b45d8a582b2ec4bb798d7fd61698b3a1722c08ea Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 3 Mar 2020 16:31:40 +0100 Subject: [PATCH] Optimize Poseidon migrating from *big.Int to goff Optimize Poseidon migrating from *big.Int to goff generated finite field operations. Benchmarks: Tested on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM. - Before the optimizations: ``` BenchmarkPoseidon-4 470 2489678 ns/op BenchmarkPoseidonLarge-4 476 2530568 ns/op ``` - With the optimizations of #12: ``` BenchmarkPoseidon-4 766 1550013 ns/op BenchmarkPoseidonLarge-4 782 1547572 ns/op ``` - With the changes of this PR, where uses goff generated code instead of *big.Int: ``` BenchmarkPoseidon-4 9638 121651 ns/op BenchmarkPoseidonLarge-4 9781 119921 ns/op ``` --- poseidon/poseidon.go | 104 ++++++++++++++++---------------------- poseidon/poseidon_test.go | 59 ++++++++++++--------- 2 files changed, 80 insertions(+), 83 deletions(-) diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index d4e2b8b..79cd651 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -1,12 +1,11 @@ package poseidon import ( - "bytes" "errors" "math/big" "strconv" - "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "golang.org/x/crypto/blake2b" ) @@ -16,15 +15,11 @@ const NROUNDSF = 8 const NROUNDSP = 57 const T = 6 -var constC []*big.Int -var constM [T][T]*big.Int +var constC []*ff.Element +var constM [T][T]*ff.Element -func Zero() *big.Int { - return new(big.Int) -} - -func modQ(v *big.Int) { - v.Mod(v, constants.Q) +func Zero() *ff.Element { + return utils.NewElement().SetZero() } func init() { @@ -32,22 +27,12 @@ func init() { constM = getMDS() } -func leByteArrayToBigInt(b []byte) *big.Int { - res := big.NewInt(0) - for i := 0; i < len(b); i++ { - n := big.NewInt(int64(b[i])) - res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8))) - } - return res -} - -func getPseudoRandom(seed string, n int) []*big.Int { - res := make([]*big.Int, n) +func getPseudoRandom(seed string, n int) []*ff.Element { + res := make([]*ff.Element, n) hash := blake2b.Sum256([]byte(seed)) for i := 0; i < n; i++ { - hashBigInt := Zero() - res[i] = utils.SetBigIntFromLEBytes(hashBigInt, hash[:]) - modQ(res[i]) + hashBigInt := big.NewInt(int64(0)) + res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:])) hash = blake2b.Sum256(hash[:]) } return res @@ -62,31 +47,30 @@ func nonceToString(n int) string { } // https://eprint.iacr.org/2019/458.pdf pag.8 -func getMDS() [T][T]*big.Int { +func getMDS() [T][T]*ff.Element { nonce := 0 cauchyMatrix := getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) for !checkAllDifferent(cauchyMatrix) { nonce += 1 cauchyMatrix = getPseudoRandom(SEED+"_matrix_"+nonceToString(nonce), T*2) } - var m [T][T]*big.Int + var m [T][T]*ff.Element for i := 0; i < T; i++ { - // var mi []*big.Int for j := 0; j < T; j++ { - m[i][j] = new(big.Int).Sub(cauchyMatrix[i], cauchyMatrix[T+j]) - m[i][j].ModInverse(m[i][j], constants.Q) + m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j]) + m[i][j].Inverse(m[i][j]) } } return m } -func checkAllDifferent(v []*big.Int) bool { +func checkAllDifferent(v []*ff.Element) bool { for i := 0; i < len(v); i++ { - if bytes.Equal(v[i].Bytes(), big.NewInt(int64(0)).Bytes()) { + if v[i].Equal(utils.NewElement().SetZero()) { return false } for j := i + 1; j < len(v); j++ { - if bytes.Equal(v[i].Bytes(), v[j].Bytes()) { + if v[i].Equal(v[j]) { return false } } @@ -95,22 +79,22 @@ func checkAllDifferent(v []*big.Int) bool { } // ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf -func ark(state [T]*big.Int, c *big.Int) { +func ark(state [T]*ff.Element, c *ff.Element) { for i := 0; i < T; i++ { - modQ(state[i].Add(state[i], c)) + state[i].Add(state[i], c) } } // cubic performs x^5 mod p // https://eprint.iacr.org/2019/458.pdf page 8 -var five = big.NewInt(5) +// var five = big.NewInt(5) -func cubic(a *big.Int) { - a.Exp(a, five, constants.Q) +func cubic(a *ff.Element) { + a.Exp(*a, 5) } // sbox https://eprint.iacr.org/2019/458.pdf page 6 -func sbox(state [T]*big.Int, i int) { +func sbox(state [T]*ff.Element, i int) { if (i < NROUNDSF/2) || (i >= NROUNDSF/2+NROUNDSP) { for j := 0; j < T; j++ { cubic(state[j]) @@ -121,30 +105,29 @@ func sbox(state [T]*big.Int, i int) { } // mix returns [[matrix]] * [vector] -func mix(state [T]*big.Int, newState [T]*big.Int, m [T][T]*big.Int) { +func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) { mul := Zero() for i := 0; i < T; i++ { - newState[i].SetInt64(0) + newState[i].SetUint64(0) for j := 0; j < T; j++ { - modQ(mul.Mul(m[i][j], state[j])) + mul.Mul(m[i][j], state[j]) newState[i].Add(newState[i], mul) } - modQ(newState[i]) } } // PoseidonHash computes the Poseidon hash for the given inputs -func PoseidonHash(inp [T]*big.Int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(inp[:], constants.Q) { +func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) { + if !utils.CheckElementArrayInField(inp[:]) { return nil, errors.New("inputs values not inside Finite Field") } - state := [T]*big.Int{} + state := [T]*ff.Element{} for i := 0; i < T; i++ { - state[i] = new(big.Int).Set(inp[i]) + state[i] = utils.NewElement().Set(inp[i]) } // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 - var newState [T]*big.Int + var newState [T]*ff.Element for i := 0; i < T; i++ { newState[i] = Zero() } @@ -157,16 +140,16 @@ func PoseidonHash(inp [T]*big.Int) (*big.Int, error) { return state[0], nil } -// Hash performs the Poseidon hash over a *big.Int array +// Hash performs the Poseidon hash over a ff.Element array // in chunks of 5 elements -func Hash(arr []*big.Int) (*big.Int, error) { - if !utils.CheckBigIntArrayInField(arr, constants.Q) { +func Hash(arr []*ff.Element) (*ff.Element, error) { + if !utils.CheckElementArrayInField(arr) { return nil, errors.New("inputs values not inside Finite Field") } - r := big.NewInt(1) + r := utils.NewElement().SetOne() for i := 0; i < len(arr); i = i + T - 1 { - var toHash [T]*big.Int + var toHash [T]*ff.Element j := 0 for ; j < T-1; j++ { if i+j >= len(arr) { @@ -177,14 +160,14 @@ func Hash(arr []*big.Int) (*big.Int, error) { toHash[j] = r j++ for ; j < T; j++ { - toHash[j] = constants.Zero + toHash[j] = Zero() } ph, err := PoseidonHash(toHash) if err != nil { return nil, err } - modQ(r.Add(r, ph)) + r.Add(r, ph) } return r, nil @@ -192,18 +175,19 @@ func Hash(arr []*big.Int) (*big.Int, error) { // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // little-endian -func HashBytes(b []byte) (*big.Int, error) { +func HashBytes(b []byte) (*ff.Element, error) { n := 31 - bElems := make([]*big.Int, 0, len(b)/n+1) + bElems := make([]*ff.Element, 0, len(b)/n+1) for i := 0; i < len(b)/n; i++ { - v := Zero() + v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) - bElems = append(bElems, v) + bElems = append(bElems, utils.NewElement().SetBigInt(v)) + } if len(b)%n != 0 { - v := Zero() + v := big.NewInt(int64(0)) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) - bElems = append(bElems, v) + bElems = append(bElems, utils.NewElement().SetBigInt(v)) } return Hash(bElems) } diff --git a/poseidon/poseidon_test.go b/poseidon/poseidon_test.go index de13104..b6791ef 100644 --- a/poseidon/poseidon_test.go +++ b/poseidon/poseidon_test.go @@ -5,6 +5,7 @@ import ( "math/big" "testing" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" "golang.org/x/crypto/blake2b" @@ -16,46 +17,46 @@ func TestBlake2bVersion(t *testing.T) { } func TestPoseidon(t *testing.T) { - b1 := big.NewInt(int64(1)) - b2 := big.NewInt(int64(2)) - h, err := Hash([]*big.Int{b1, b2}) + b1 := utils.NewElement().SetUint64(1) + b2 := utils.NewElement().SetUint64(2) + h, err := Hash([]*ff.Element{b1, b2}) assert.Nil(t, err) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) - b3 := big.NewInt(int64(3)) - b4 := big.NewInt(int64(4)) - h, err = Hash([]*big.Int{b3, b4}) + b3 := utils.NewElement().SetUint64(3) + b4 := utils.NewElement().SetUint64(4) + h, err = Hash([]*ff.Element{b3, b4}) assert.Nil(t, err) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") n := 31 - msgElems := make([]*big.Int, 0, len(msg)/n+1) + msgElems := make([]*ff.Element, 0, len(msg)/n+1) for i := 0; i < len(msg)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) - msgElems = append(msgElems, v) + msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) } if len(msg)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) - msgElems = append(msgElems, v) + msgElems = append(msgElems, utils.NewElement().SetBigInt(v)) } hmsg, err := Hash(msgElems) assert.Nil(t, err) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") - msg2Elems := make([]*big.Int, 0, len(msg2)/n+1) + msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1) for i := 0; i < len(msg2)/n; i++ { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) - msg2Elems = append(msg2Elems, v) + msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) } if len(msg2)%n != 0 { v := new(big.Int) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) - msg2Elems = append(msg2Elems, v) + msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v)) } hmsg2, err := Hash(msg2Elems) assert.Nil(t, err) @@ -67,29 +68,41 @@ func TestPoseidon(t *testing.T) { } func TestPoseidonBrokenChunks(t *testing.T) { - h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), - big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)}) + h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4), + utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)}) assert.Nil(t, err) - h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9), - big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}) + h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9), + utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func TestPoseidonBrokenPadding(t *testing.T) { - h1, err := Hash([]*big.Int{big.NewInt(1)}) + h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)}) assert.Nil(t, err) - h2, err := Hash([]*big.Int{big.NewInt(1), big.NewInt(0)}) + h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)}) assert.Nil(t, err) assert.NotEqual(t, h1, h2) } func BenchmarkPoseidon(b *testing.B) { - b12 := big.NewInt(int64(12)) - b45 := big.NewInt(int64(45)) - b78 := big.NewInt(int64(78)) - b41 := big.NewInt(int64(41)) - bigArray4 := []*big.Int{b12, b45, b78, b41} + b12 := utils.NewElement().SetUint64(12) + b45 := utils.NewElement().SetUint64(45) + b78 := utils.NewElement().SetUint64(78) + b41 := utils.NewElement().SetUint64(41) + bigArray4 := []*ff.Element{b12, b45, b78, b41} + + for i := 0; i < b.N; i++ { + Hash(bigArray4) + } +} + +func BenchmarkPoseidonLarge(b *testing.B) { + b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733") + bigArray4 := []*ff.Element{b12, b45, b78, b41} for i := 0; i < b.N; i++ { Hash(bigArray4)