Browse Source

Fix Hash bytes & *big.int parsers

fix/hash-parsers
arnaucube 4 years ago
parent
commit
acf3838e06
2 changed files with 63 additions and 21 deletions
  1. +18
    -9
      merkletree.go
  2. +45
    -12
      merkletree_test.go

+ 18
- 9
merkletree.go

@ -71,7 +71,12 @@ func (h Hash) String() string {
// Hex returns the hexadecimal representation of the Hash // Hex returns the hexadecimal representation of the Hash
func (h Hash) Hex() string { func (h Hash) Hex() string {
return hex.EncodeToString(h.BigInt().Bytes())
return hex.EncodeToString(h[:])
// alternatively equivalent, but with too extra steps:
// bRaw := h.BigInt().Bytes()
// b := [32]byte{}
// copy(b[:], common.SwapEndianness(bRaw[:]))
// return hex.EncodeToString(b[:])
} }
// BigInt returns the *big.Int representation of the *Hash // BigInt returns the *big.Int representation of the *Hash
@ -85,21 +90,25 @@ func (h *Hash) BigInt() *big.Int {
// Bytes returns the []byte representation of the *Hash, which always is 32 // Bytes returns the []byte representation of the *Hash, which always is 32
// bytes length. // bytes length.
func (h *Hash) Bytes() []byte { func (h *Hash) Bytes() []byte {
bi := new(big.Int).SetBytes(common.SwapEndianness(h[:])).Bytes()
bi := new(big.Int).SetBytes(h[:]).Bytes()
b := [32]byte{} b := [32]byte{}
copy(b[:], bi[:])
copy(b[:], common.SwapEndianness(bi[:]))
return b[:] return b[:]
} }
// NewBigIntFromBytes returns a *big.Int from a byte array, swapping the
// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
// endianness in the process. This is the intended method to get a *big.Int // endianness in the process. This is the intended method to get a *big.Int
// from a byte array that previously has ben generated by the Hash.Bytes() // from a byte array that previously has ben generated by the Hash.Bytes()
// method. // method.
func NewBigIntFromBytes(b []byte) (*big.Int, error) {
func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
if len(b) != ElemBytesLen { if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b)) return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
} }
return new(big.Int).SetBytes(common.SwapEndianness(b[:ElemBytesLen])), nil
bi := new(big.Int).SetBytes(b[:ElemBytesLen])
if !cryptoUtils.CheckBigIntInField(bi) {
return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
}
return bi, nil
} }
// NewHashFromBigInt returns a *Hash representation of the given *big.Int // NewHashFromBigInt returns a *Hash representation of the given *big.Int
@ -128,7 +137,7 @@ func NewHashFromHex(h string) (*Hash, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewHashFromBytes(b)
return NewHashFromBytes(common.SwapEndianness(b[:]))
} }
// MerkleTree is the struct with the main elements of the MerkleTree // MerkleTree is the struct with the main elements of the MerkleTree
@ -1130,11 +1139,11 @@ func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error { func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
for i := 0; i < len(b); i += 64 { for i := 0; i < len(b); i += 64 {
lr := b[i : i+64] lr := b[i : i+64]
lB, err := NewBigIntFromBytes(lr[:32])
lB, err := NewBigIntFromHashBytes(lr[:32])
if err != nil { if err != nil {
return err return err
} }
rB, err := NewBigIntFromBytes(lr[32:])
rB, err := NewBigIntFromHashBytes(lr[32:])
if err != nil { if err != nil {
return err return err
} }

+ 45
- 12
merkletree_test.go

@ -7,7 +7,8 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/iden3/go-iden3-core/common"
"github.com/iden3/go-iden3-crypto/constants"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -46,11 +47,11 @@ func TestHashParsers(t *testing.T) {
h := NewHashFromBigInt(b) h := NewHashFromBigInt(b)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String())
assert.Equal(t, "49322979...", h.String()) assert.Equal(t, "49322979...", h.String())
assert.Equal(t, "0ae794eb9c3d8bbb9002e993fc2ed301dcbd2af5508ed072c375e861f1aa5b26", h.Hex())
assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
b1, err := NewBigIntFromBytes(b.Bytes())
b1, err := NewBigIntFromHashBytes(b.Bytes())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, new(big.Int).SetBytes(common.SwapEndianness(b.Bytes())).String(), b1.String())
assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
b2, err := NewHashFromBytes(b.Bytes()) b2, err := NewHashFromBytes(b.Bytes())
assert.Nil(t, err) assert.Nil(t, err)
@ -61,6 +62,31 @@ func TestHashParsers(t *testing.T) {
assert.Equal(t, h, h2) assert.Equal(t, h, h2)
_, err = NewHashFromHex("0x12") _, err = NewHashFromHex("0x12")
assert.NotNil(t, err) assert.NotNil(t, err)
// check limits
a := new(big.Int).Sub(constants.Q, big.NewInt(1))
testHashParsers(t, a)
a = big.NewInt(int64(1))
testHashParsers(t, a)
}
func testHashParsers(t *testing.T, a *big.Int) {
require.True(t, cryptoUtils.CheckBigIntInField(a))
h := NewHashFromBigInt(a)
assert.Equal(t, a, h.BigInt())
hFromBytes, err := NewHashFromBytes(h.Bytes())
assert.Nil(t, err)
assert.Equal(t, h, hFromBytes)
assert.Equal(t, a, hFromBytes.BigInt())
assert.Equal(t, a.String(), hFromBytes.BigInt().String())
hFromHex, err := NewHashFromHex(h.Hex())
assert.Nil(t, err)
assert.Equal(t, h, hFromHex)
aBIFromHBytes, err := NewBigIntFromHashBytes(h.Bytes())
assert.Nil(t, err)
assert.Equal(t, a, aBIFromHBytes)
assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
} }
func TestNewTree(t *testing.T) { func TestNewTree(t *testing.T) {
@ -111,7 +137,7 @@ func TestAddDifferentOrder(t *testing.T) {
} }
assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex()) assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
assert.Equal(t, "0630b27c6f8c7d36d144369ab1ac408552b544ebe96ad642bad6a94a96258e26", mt1.Root().Hex())
assert.Equal(t, "268e25964aa9d6ba42d66ae9eb44b5528540acb19a3644d1367d8c6f7cb23006", mt1.Root().Hex())
} }
func TestAddRepeatedIndex(t *testing.T) { func TestAddRepeatedIndex(t *testing.T) {
@ -264,12 +290,12 @@ func TestSiblingsFromProof(t *testing.T) {
siblings := SiblingsFromProof(proof) siblings := SiblingsFromProof(proof)
assert.Equal(t, 6, len(siblings)) assert.Equal(t, 6, len(siblings))
assert.Equal(t, "2f59aeef9e5b881609aa56940dba76b5cb1440a794f4eb03ad5e5958dd8b475b", siblings[0].Hex())
assert.Equal(t, "2eb29ffbded0987f36a62aecddf748d2b9bf28326300bfa15e474e0a12abe8c1", siblings[1].Hex())
assert.Equal(t, "0c6ee1298933d073a390cc3d267a8a4d5a7df65a126d3fdc5a16b9c28afddaf4", siblings[2].Hex())
assert.Equal(t, "1575898b0b4e7802a6be130e7b76ede64fe42079b6852eba6af985bd46a34aa9", siblings[3].Hex())
assert.Equal(t, "1d15b701c1fd521841120980c5cbfa86f15b1f22bf1d3079ed0d0314751d7954", siblings[4].Hex())
assert.Equal(t, "1ee00f37756159cfefaa0bce02779460b449a049165f3bb9fef81105bc285d43", siblings[5].Hex())
assert.Equal(t, "5b478bdd58595ead03ebf494a74014cbb576ba0d9456aa0916885b9eefae592f", siblings[0].Hex())
assert.Equal(t, "c1e8ab120a4e475ea1bf00633228bfb9d248f7ddec2aa6367f98d0defb9fb22e", siblings[1].Hex())
assert.Equal(t, "f4dafd8ac2b9165adc3f6d125af67d5a4d8a7a263dcc90a373d0338929e16e0c", siblings[2].Hex())
assert.Equal(t, "a94aa346bd85f96aba2e85b67920e44fe6ed767b0e13bea602784e0b8b897515", siblings[3].Hex())
assert.Equal(t, "54791d7514030ded79301dbf221f5bf186facbc5800912411852fdc101b7151d", siblings[4].Hex())
assert.Equal(t, "435d28bc0511f8feb93b5f1649a049b460947702ce0baaefcf596175370fe01e", siblings[5].Hex())
} }
func TestVerifyProofCases(t *testing.T) { func TestVerifyProofCases(t *testing.T) {
@ -556,11 +582,18 @@ func TestDumpLeafsImportLeafs(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
defer mt.db.Close() defer mt.db.Close()
q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
k := big.NewInt(int64(i))
// use numbers near under Q
k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
v := big.NewInt(0) v := big.NewInt(0)
err = mt.Add(k, v) err = mt.Add(k, v)
require.Nil(t, err) require.Nil(t, err)
// use numbers near above 0
k = big.NewInt(int64(i))
err = mt.Add(k, v)
require.Nil(t, err)
} }
d, err := mt.DumpLeafs(nil) d, err := mt.DumpLeafs(nil)

Loading…
Cancel
Save