diff --git a/common/account_test.go b/common/account_test.go index 12867ae..5724d81 100644 --- a/common/account_test.go +++ b/common/account_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "math/big" + "strings" "testing" ethCommon "github.com/ethereum/go-ethereum/common" @@ -185,6 +186,9 @@ func TestAccountLoopRandom(t *testing.T) { } func bigFromStr(h string, u int) *big.Int { + if u == 16 { + h = strings.TrimPrefix(h, "0x") + } b, ok := new(big.Int).SetString(h, u) if !ok { panic("bigFromStr err") diff --git a/db/statedb/statedb_test.go b/db/statedb/statedb_test.go index 56219cc..be29a02 100644 --- a/db/statedb/statedb_test.go +++ b/db/statedb/statedb_test.go @@ -6,8 +6,10 @@ import ( "io/ioutil" "math/big" "os" + "strings" "testing" + ethCommon "github.com/ethereum/go-ethereum/common" ethCrypto "github.com/ethereum/go-ethereum/crypto" "github.com/hermeznetwork/hermez-node/common" "github.com/iden3/go-iden3-crypto/babyjub" @@ -376,3 +378,82 @@ func printCheckpoints(t *testing.T, path string) { fmt.Println(" " + f.Name()) } } + +func bigFromStr(h string, u int) *big.Int { + if u == 16 { + h = strings.TrimPrefix(h, "0x") + } + b, ok := new(big.Int).SetString(h, u) + if !ok { + panic("bigFromStr err") + } + return b +} + +func TestCheckAccountsTreeTestVectors(t *testing.T) { + dir, err := ioutil.TempDir("", "tmpdb") + require.Nil(t, err) + defer assert.Nil(t, os.RemoveAll(dir)) + + sdb, err := NewStateDB(dir, TypeSynchronizer, 32) + require.Nil(t, err) + + ay0 := new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(253), nil), big.NewInt(1)) + // test value from js version (compatibility-canary) + assert.Equal(t, "1fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", (hex.EncodeToString(ay0.Bytes()))) + bjj0, err := babyjub.PointFromSignAndY(true, ay0) + require.Nil(t, err) + + ay1 := bigFromStr("00", 16) + bjj1, err := babyjub.PointFromSignAndY(false, ay1) + require.Nil(t, err) + ay2 := bigFromStr("21b0a1688b37f77b1d1d5539ec3b826db5ac78b2513f574a04c50a7d4f8246d7", 16) + bjj2, err := babyjub.PointFromSignAndY(false, ay2) + require.Nil(t, err) + + ay3 := bigFromStr("0x10", 16) // 0x10=16 + bjj3, err := babyjub.PointFromSignAndY(false, ay3) + require.Nil(t, err) + accounts := []*common.Account{ + { + Idx: 1, + TokenID: 0xFFFFFFFF, + PublicKey: (*babyjub.PublicKey)(bjj0), + EthAddr: ethCommon.HexToAddress("0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"), + Nonce: common.Nonce(0xFFFFFFFFFF), + Balance: bigFromStr("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16), + }, + { + Idx: 100, + TokenID: 0, + PublicKey: (*babyjub.PublicKey)(bjj1), + EthAddr: ethCommon.HexToAddress("0x00"), + Nonce: common.Nonce(0), + Balance: bigFromStr("0", 10), + }, + { + Idx: 0xFFFFFFFFFFFF, + TokenID: 3, + PublicKey: (*babyjub.PublicKey)(bjj2), + EthAddr: ethCommon.HexToAddress("0xA3C88ac39A76789437AED31B9608da72e1bbfBF9"), + Nonce: common.Nonce(129), + Balance: bigFromStr("42000000000000000000", 10), + }, + { + Idx: 10000, + TokenID: 1000, + PublicKey: (*babyjub.PublicKey)(bjj3), + EthAddr: ethCommon.HexToAddress("0x64"), + Nonce: common.Nonce(1900), + Balance: bigFromStr("14000000000000000000", 10), + }, + } + for i := 0; i < len(accounts); i++ { + _, err = accounts[i].HashValue() + require.Nil(t, err) + _, err = sdb.CreateAccount(accounts[i].Idx, accounts[i]) + require.Nil(t, err) + } + // root value generated by js version: + assert.Equal(t, "17298264051379321456969039521810887093935433569451713402227686942080129181291", sdb.mt.Root().BigInt().String()) +} diff --git a/db/statedb/utils.go b/db/statedb/utils.go index f38609f..b39112e 100644 --- a/db/statedb/utils.go +++ b/db/statedb/utils.go @@ -45,6 +45,10 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk } } + if pk == nil { + return fmt.Errorf("BabyJubJub pk not defined") + } + // store idx for EthAddr & BJJ assuming that EthAddr & BJJ still don't // have an Idx stored in the DB, and if so, the already stored Idx is // bigger than the given one, so should be updated to the new one @@ -53,12 +57,12 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk if err != nil { return err } - k := concatEthAddrBJJTokenID(addr, pk, tokenID) - // store Addr&BJJ-idx idxBytes, err := idx.Bytes() if err != nil { return err } + // store Addr&BJJ-idx + k := concatEthAddrBJJTokenID(addr, pk, tokenID) err = tx.Put(append(PrefixKeyAddrBJJ, k...), idxBytes[:]) if err != nil { return err