Browse Source

Adapt babyjub/eddsa to new Poseidon methods

feature/poseidon-opt-goff
arnaucube 4 years ago
parent
commit
2a3f0d9ed5
5 changed files with 79 additions and 63 deletions
  1. +2
    -0
      babyjub/eddsa.go
  2. +5
    -0
      ff/util.go
  3. +26
    -21
      poseidon/poseidon.go
  4. +41
    -29
      poseidon/poseidon_test.go
  5. +5
    -13
      utils/utils.go

+ 2
- 0
babyjub/eddsa.go

@ -222,11 +222,13 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature {
r.Mod(r, SubOrder) r.Mod(r, SubOrder)
R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B
A := k.Public().Point() A := k.Public().Point()
hmInput := [poseidon.T]*big.Int{R8.X, R8.Y, A.X, A.Y, msg, big.NewInt(int64(0))} hmInput := [poseidon.T]*big.Int{R8.X, R8.Y, A.X, A.Y, msg, big.NewInt(int64(0))}
hm, err := poseidon.PoseidonHash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) hm, err := poseidon.PoseidonHash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg)
if err != nil { if err != nil {
panic(err) panic(err)
} }
S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) S := new(big.Int).Lsh(k.Scalar().BigInt(), 3)
S = S.Mul(hm, S) S = S.Mul(hm, S)
S.Add(r, S) S.Add(r, S)

+ 5
- 0
ff/util.go

@ -0,0 +1,5 @@
package ff
func NewElement() *Element {
return &Element{}
}

+ 26
- 21
poseidon/poseidon.go

@ -5,6 +5,7 @@ import (
"math/big" "math/big"
"strconv" "strconv"
"github.com/iden3/go-iden3-crypto/constants"
"github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
@ -19,7 +20,11 @@ var constC []*ff.Element
var constM [T][T]*ff.Element var constM [T][T]*ff.Element
func Zero() *ff.Element { func Zero() *ff.Element {
return utils.NewElement().SetZero()
return ff.NewElement().SetZero()
}
func modQ(v *big.Int) {
v.Mod(v, constants.Q)
} }
func init() { func init() {
@ -32,7 +37,7 @@ func getPseudoRandom(seed string, n int) []*ff.Element {
hash := blake2b.Sum256([]byte(seed)) hash := blake2b.Sum256([]byte(seed))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
hashBigInt := big.NewInt(int64(0)) hashBigInt := big.NewInt(int64(0))
res[i] = utils.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:]))
res[i] = ff.NewElement().SetBigInt(utils.SetBigIntFromLEBytes(hashBigInt, hash[:]))
hash = blake2b.Sum256(hash[:]) hash = blake2b.Sum256(hash[:])
} }
return res return res
@ -57,7 +62,7 @@ func getMDS() [T][T]*ff.Element {
var m [T][T]*ff.Element var m [T][T]*ff.Element
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
for j := 0; j < T; j++ { for j := 0; j < T; j++ {
m[i][j] = utils.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j] = ff.NewElement().Sub(cauchyMatrix[i], cauchyMatrix[T+j])
m[i][j].Inverse(m[i][j]) m[i][j].Inverse(m[i][j])
} }
} }
@ -66,7 +71,7 @@ func getMDS() [T][T]*ff.Element {
func checkAllDifferent(v []*ff.Element) bool { func checkAllDifferent(v []*ff.Element) bool {
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {
if v[i].Equal(utils.NewElement().SetZero()) {
if v[i].Equal(ff.NewElement().SetZero()) {
return false return false
} }
for j := i + 1; j < len(v); j++ { for j := i + 1; j < len(v); j++ {
@ -117,13 +122,14 @@ func mix(state [T]*ff.Element, newState [T]*ff.Element, m [T][T]*ff.Element) {
} }
// PoseidonHash computes the Poseidon hash for the given inputs // PoseidonHash computes the Poseidon hash for the given inputs
func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(inp[:]) {
func PoseidonHash(inpBI [T]*big.Int) (*big.Int, error) {
if !utils.CheckBigIntArrayInField(inpBI[:]) {
return nil, errors.New("inputs values not inside Finite Field") return nil, errors.New("inputs values not inside Finite Field")
} }
inp := utils.BigIntArrayToElementArray(inpBI[:])
state := [T]*ff.Element{} state := [T]*ff.Element{}
for i := 0; i < T; i++ { for i := 0; i < T; i++ {
state[i] = utils.NewElement().Set(inp[i])
state[i] = ff.NewElement().Set(inp[i])
} }
// ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5 // ARK --> SBox --> M, https://eprint.iacr.org/2019/458.pdf pag.5
@ -137,19 +143,18 @@ func PoseidonHash(inp [T]*ff.Element) (*ff.Element, error) {
mix(state, newState, constM) mix(state, newState, constM)
state, newState = newState, state state, newState = newState, state
} }
return state[0], nil
rE := state[0]
r := big.NewInt(0)
rE.ToBigIntRegular(r)
return r, nil
} }
// Hash performs the Poseidon hash over a ff.Element array // Hash performs the Poseidon hash over a ff.Element array
// in chunks of 5 elements // in chunks of 5 elements
func Hash(arr []*ff.Element) (*ff.Element, error) {
if !utils.CheckElementArrayInField(arr) {
return nil, errors.New("inputs values not inside Finite Field")
}
r := utils.NewElement().SetOne()
func Hash(arr []*big.Int) (*big.Int, error) {
r := big.NewInt(int64(1))
for i := 0; i < len(arr); i = i + T - 1 { for i := 0; i < len(arr); i = i + T - 1 {
var toHash [T]*ff.Element
var toHash [T]*big.Int
j := 0 j := 0
for ; j < T-1; j++ { for ; j < T-1; j++ {
if i+j >= len(arr) { if i+j >= len(arr) {
@ -160,14 +165,14 @@ func Hash(arr []*ff.Element) (*ff.Element, error) {
toHash[j] = r toHash[j] = r
j++ j++
for ; j < T; j++ { for ; j < T; j++ {
toHash[j] = Zero()
toHash[j] = big.NewInt(0)
} }
ph, err := PoseidonHash(toHash) ph, err := PoseidonHash(toHash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r.Add(r, ph)
modQ(r.Add(r, ph))
} }
return r, nil return r, nil
@ -175,19 +180,19 @@ func Hash(arr []*ff.Element) (*ff.Element, error) {
// HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as
// little-endian // little-endian
func HashBytes(b []byte) (*ff.Element, error) {
func HashBytes(b []byte) (*big.Int, error) {
n := 31 n := 31
bElems := make([]*ff.Element, 0, len(b)/n+1)
bElems := make([]*big.Int, 0, len(b)/n+1)
for i := 0; i < len(b)/n; i++ { for i := 0; i < len(b)/n; i++ {
v := big.NewInt(int64(0)) v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, b[n*i:n*(i+1)])
bElems = append(bElems, utils.NewElement().SetBigInt(v))
bElems = append(bElems, v)
} }
if len(b)%n != 0 { if len(b)%n != 0 {
v := big.NewInt(int64(0)) v := big.NewInt(int64(0))
utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:]) utils.SetBigIntFromLEBytes(v, b[(len(b)/n)*n:])
bElems = append(bElems, utils.NewElement().SetBigInt(v))
bElems = append(bElems, v)
} }
return Hash(bElems) return Hash(bElems)
} }

+ 41
- 29
poseidon/poseidon_test.go

@ -5,7 +5,6 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-iden3-crypto/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
@ -17,46 +16,58 @@ func TestBlake2bVersion(t *testing.T) {
} }
func TestPoseidon(t *testing.T) { func TestPoseidon(t *testing.T) {
b1 := utils.NewElement().SetUint64(1)
b2 := utils.NewElement().SetUint64(2)
h, err := Hash([]*ff.Element{b1, b2})
b1 := big.NewInt(1)
b2 := big.NewInt(2)
h, err := Hash([]*big.Int{b1, b2})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String()) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154855", h.String())
b3 := utils.NewElement().SetUint64(3)
b4 := utils.NewElement().SetUint64(4)
h, err = Hash([]*ff.Element{b3, b4})
b3 := big.NewInt(3)
b4 := big.NewInt(4)
h, err = Hash([]*big.Int{b3, b4})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String()) assert.Equal(t, "4635491972858758537477743930622086396911540895966845494943021655521913507504", h.String())
b5 := big.NewInt(5)
b6 := big.NewInt(6)
b7 := big.NewInt(7)
b8 := big.NewInt(8)
b9 := big.NewInt(9)
b10 := big.NewInt(10)
b11 := big.NewInt(11)
b12 := big.NewInt(12)
h, err = Hash([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12})
assert.Nil(t, err)
assert.Equal(t, "15278801138972282646981503374384603641625274360649669926363020545395022098027", h.String())
msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.")
n := 31 n := 31
msgElems := make([]*ff.Element, 0, len(msg)/n+1)
msgElems := make([]*big.Int, 0, len(msg)/n+1)
for i := 0; i < len(msg)/n; i++ { for i := 0; i < len(msg)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg[n*i:n*(i+1)])
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
msgElems = append(msgElems, v)
} }
if len(msg)%n != 0 { if len(msg)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg[(len(msg)/n)*n:])
msgElems = append(msgElems, utils.NewElement().SetBigInt(v))
msgElems = append(msgElems, v)
} }
hmsg, err := Hash(msgElems) hmsg, err := Hash(msgElems)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String()) assert.Equal(t, "16019700159595764790637132363672701294192939959594423814006267756172551741065", hmsg.String())
msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.") msg2 := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Lorem ipsum dolor sit amet.")
msg2Elems := make([]*ff.Element, 0, len(msg2)/n+1)
msg2Elems := make([]*big.Int, 0, len(msg2)/n+1)
for i := 0; i < len(msg2)/n; i++ { for i := 0; i < len(msg2)/n; i++ {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)]) utils.SetBigIntFromLEBytes(v, msg2[n*i:n*(i+1)])
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
msg2Elems = append(msg2Elems, v)
} }
if len(msg2)%n != 0 { if len(msg2)%n != 0 {
v := new(big.Int) v := new(big.Int)
utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:]) utils.SetBigIntFromLEBytes(v, msg2[(len(msg2)/n)*n:])
msg2Elems = append(msg2Elems, utils.NewElement().SetBigInt(v))
msg2Elems = append(msg2Elems, v)
} }
hmsg2, err := Hash(msg2Elems) hmsg2, err := Hash(msg2Elems)
assert.Nil(t, err) assert.Nil(t, err)
@ -68,29 +79,29 @@ func TestPoseidon(t *testing.T) {
} }
func TestPoseidonBrokenChunks(t *testing.T) { func TestPoseidonBrokenChunks(t *testing.T) {
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4),
utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9)})
h1, err := Hash([]*big.Int{big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4),
big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9)})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(5), utils.NewElement().SetUint64(6), utils.NewElement().SetUint64(7), utils.NewElement().SetUint64(8), utils.NewElement().SetUint64(9),
utils.NewElement().SetUint64(0), utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(2), utils.NewElement().SetUint64(3), utils.NewElement().SetUint64(4)})
h2, err := Hash([]*big.Int{big.NewInt(5), big.NewInt(6), big.NewInt(7), big.NewInt(8), big.NewInt(9),
big.NewInt(0), big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func TestPoseidonBrokenPadding(t *testing.T) { func TestPoseidonBrokenPadding(t *testing.T) {
h1, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1)})
h1, err := Hash([]*big.Int{big.NewInt(int64(1))})
assert.Nil(t, err) assert.Nil(t, err)
h2, err := Hash([]*ff.Element{utils.NewElement().SetUint64(1), utils.NewElement().SetUint64(0)})
h2, err := Hash([]*big.Int{big.NewInt(int64(1)), big.NewInt(int64(0))})
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func BenchmarkPoseidon(b *testing.B) { func BenchmarkPoseidon(b *testing.B) {
b12 := utils.NewElement().SetUint64(12)
b45 := utils.NewElement().SetUint64(45)
b78 := utils.NewElement().SetUint64(78)
b41 := utils.NewElement().SetUint64(41)
bigArray4 := []*ff.Element{b12, b45, b78, b41}
b12 := big.NewInt(int64(12))
b45 := big.NewInt(int64(45))
b78 := big.NewInt(int64(78))
b41 := big.NewInt(int64(41))
bigArray4 := []*big.Int{b12, b45, b78, b41}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Hash(bigArray4) Hash(bigArray4)
@ -98,11 +109,12 @@ func BenchmarkPoseidon(b *testing.B) {
} }
func BenchmarkPoseidonLarge(b *testing.B) { func BenchmarkPoseidonLarge(b *testing.B) {
b12 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b45 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b78 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b41 := utils.NewElement().SetString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
bigArray4 := []*ff.Element{b12, b45, b78, b41}
b12 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b45 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b78 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
b41 := utils.NewIntFromString("11384336176656855268977457483345535180380036354188103142384839473266348197733")
bigArray4 := []*big.Int{b12, b45, b78, b41}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Hash(bigArray4) Hash(bigArray4)

+ 5
- 13
utils/utils.go

@ -108,18 +108,10 @@ func CheckBigIntArrayInField(arr []*big.Int) bool {
return true return true
} }
// CheckElementArrayInField checks if given *ff.Element fits in a Field Q element
func CheckElementArrayInField(arr []*ff.Element) bool {
for _, aE := range arr {
a := big.NewInt(0)
aE.ToBigIntRegular(a)
if !CheckBigIntInField(a) {
return false
}
func BigIntArrayToElementArray(bi []*big.Int) []*ff.Element {
var o []*ff.Element
for i := range bi {
o = append(o, ff.NewElement().SetBigInt(bi[i]))
} }
return true
}
func NewElement() *ff.Element {
return &ff.Element{0, 0, 0, 0}
return o
} }

Loading…
Cancel
Save