diff --git a/common/account.go b/common/account.go index 72342e5..d7e8ea1 100644 --- a/common/account.go +++ b/common/account.go @@ -21,9 +21,9 @@ const ( // maxBalanceBytes is the maximum bytes that can use the Account.Balance *big.Int maxBalanceBytes = 24 - idxBytesLen = 4 - // maxIdxValue is the maximum value that Idx can have (32 bits: maxIdxValue=2**32-1) - maxIdxValue = 0xffffffff + idxBytesLen = 6 + // maxIdxValue is the maximum value that Idx can have (48 bits: maxIdxValue=2**48-1) + maxIdxValue = 0xffffffffffff // userThreshold determines the threshold from the User Idxs can be userThreshold = 256 @@ -40,7 +40,7 @@ var ( ) // Idx represents the account Index in the MerkleTree -type Idx uint32 +type Idx uint64 // String returns a string representation of the Idx func (idx Idx) String() string { @@ -48,10 +48,15 @@ func (idx Idx) String() string { } // Bytes returns a byte array representing the Idx -func (idx Idx) Bytes() []byte { - var b [4]byte - binary.BigEndian.PutUint32(b[:], uint32(idx)) - return b[:] +func (idx Idx) Bytes() ([6]byte, error) { + if idx > maxIdxValue { + return [6]byte{}, ErrIdxOverflow + } + var idxBytes [8]byte + binary.BigEndian.PutUint64(idxBytes[:], uint64(idx)) + var b [6]byte + copy(b[:], idxBytes[2:]) + return b, nil } // BigInt returns a *big.Int representing the Idx @@ -62,9 +67,11 @@ func (idx Idx) BigInt() *big.Int { // IdxFromBytes returns Idx from a byte array func IdxFromBytes(b []byte) (Idx, error) { if len(b) != idxBytesLen { - return 0, fmt.Errorf("can not parse Idx, bytes len %d, expected 4", len(b)) + return 0, fmt.Errorf("can not parse Idx, bytes len %d, expected %d", len(b), idxBytesLen) } - idx := binary.BigEndian.Uint32(b[:4]) + var idxBytes [8]byte + copy(idxBytes[2:], b[:]) + idx := binary.BigEndian.Uint64(idxBytes[:]) return Idx(idx), nil } @@ -73,7 +80,35 @@ func IdxFromBigInt(b *big.Int) (Idx, error) { if b.Int64() > maxIdxValue { return 0, ErrNumOverflow } - return Idx(uint32(b.Int64())), nil + return Idx(uint64(b.Int64())), nil +} + +// Nonce represents the nonce value in a uint64, which has the method Bytes that returns a byte array of length 5 (40 bits). +type Nonce uint64 + +// Bytes returns a byte array of length 5 representing the Nonce +func (n Nonce) Bytes() ([5]byte, error) { + if n > maxNonceValue { + return [5]byte{}, ErrNonceOverflow + } + var nonceBytes [8]byte + binary.BigEndian.PutUint64(nonceBytes[:], uint64(n)) + var b [5]byte + copy(b[:], nonceBytes[3:]) + return b, nil +} + +// BigInt returns the *big.Int representation of the Nonce value +func (n Nonce) BigInt() *big.Int { + return big.NewInt(int64(n)) +} + +// NonceFromBytes returns Nonce from a [5]byte +func NonceFromBytes(b [5]byte) Nonce { + var nonceBytes [8]byte + copy(nonceBytes[3:], b[:]) + nonce := binary.BigEndian.Uint64(nonceBytes[:]) + return Nonce(nonce) } // Account is a struct that gives information of the holdings of an address and a specific token. Is the data structure that generates the Value stored in the leaf of the MerkleTree diff --git a/common/account_test.go b/common/account_test.go index eeebdf5..1c92cc6 100644 --- a/common/account_test.go +++ b/common/account_test.go @@ -15,6 +15,61 @@ import ( "github.com/stretchr/testify/assert" ) +func TestIdxParser(t *testing.T) { + i := Idx(1) + iBytes, err := i.Bytes() + assert.Nil(t, err) + assert.Equal(t, 6, len(iBytes)) + assert.Equal(t, "000000000001", hex.EncodeToString(iBytes[:])) + i2, err := IdxFromBytes(iBytes[:]) + assert.Nil(t, err) + assert.Equal(t, i, i2) + + i = Idx(100) + assert.Equal(t, big.NewInt(100), i.BigInt()) + + // value before overflow + i = Idx(281474976710655) + iBytes, err = i.Bytes() + assert.Nil(t, err) + assert.Equal(t, 6, len(iBytes)) + assert.Equal(t, "ffffffffffff", hex.EncodeToString(iBytes[:])) + i2, err = IdxFromBytes(iBytes[:]) + assert.Nil(t, err) + assert.Equal(t, i, i2) + + // expect value overflow + i = Idx(281474976710656) + iBytes, err = i.Bytes() + assert.NotNil(t, err) + assert.Equal(t, ErrIdxOverflow, err) +} + +func TestNonceParser(t *testing.T) { + n := Nonce(1) + nBytes, err := n.Bytes() + assert.Nil(t, err) + assert.Equal(t, 5, len(nBytes)) + assert.Equal(t, "0000000001", hex.EncodeToString(nBytes[:])) + n2 := NonceFromBytes(nBytes) + assert.Equal(t, n, n2) + + // value before overflow + n = Nonce(1099511627775) + nBytes, err = n.Bytes() + assert.Nil(t, err) + assert.Equal(t, 5, len(nBytes)) + assert.Equal(t, "ffffffffff", hex.EncodeToString(nBytes[:])) + n2 = NonceFromBytes(nBytes) + assert.Equal(t, n, n2) + + // expect value overflow + n = Nonce(1099511627776) + nBytes, err = n.Bytes() + assert.NotNil(t, err) + assert.Equal(t, ErrNonceOverflow, err) +} + func TestAccount(t *testing.T) { var sk babyjub.PrivateKey _, err := hex.Decode(sk[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) diff --git a/common/errors.go b/common/errors.go index 4fc0a6a..bb04f59 100644 --- a/common/errors.go +++ b/common/errors.go @@ -11,5 +11,8 @@ var ErrNumOverflow = errors.New("Value overflows the type") // ErrNonceOverflow is used when a given nonce overflows the maximum capacity of the Nonce (2**40-1) var ErrNonceOverflow = errors.New("Nonce overflow, max value: 2**40 -1") +// ErrIdxOverflow is used when a given nonce overflows the maximum capacity of the Idx (2**48-1) +var ErrIdxOverflow = errors.New("Idx overflow, max value: 2**48 -1") + // ErrBatchQueueEmpty is used when the coordinator.BatchQueue.Pop() is called and has no elements var ErrBatchQueueEmpty = errors.New("BatchQueue empty") diff --git a/common/l1tx.go b/common/l1tx.go index cffcfc9..4e211e5 100644 --- a/common/l1tx.go +++ b/common/l1tx.go @@ -10,7 +10,7 @@ import ( const ( // L1TxBytesLen is the length of the byte array that represents the L1Tx - L1TxBytesLen = 68 + L1TxBytesLen = 72 ) // L1Tx is a struct that represents a L1 tx @@ -63,23 +63,31 @@ func (tx *L1Tx) Tx() *Tx { // Bytes encodes a L1Tx into []byte func (tx *L1Tx) Bytes(nLevels int) ([]byte, error) { - var b [68]byte + var b [L1TxBytesLen]byte copy(b[0:20], tx.FromEthAddr.Bytes()) pkComp := tx.FromBJJ.Compress() copy(b[20:52], pkComp[:]) - copy(b[52:56], tx.FromIdx.Bytes()) + fromIdxBytes, err := tx.FromIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[52:58], fromIdxBytes[:]) loadAmountFloat16, err := NewFloat16(tx.LoadAmount) if err != nil { return nil, err } - copy(b[56:58], loadAmountFloat16.Bytes()) + copy(b[58:60], loadAmountFloat16.Bytes()) amountFloat16, err := NewFloat16(tx.Amount) if err != nil { return nil, err } - copy(b[58:60], amountFloat16.Bytes()) - copy(b[60:64], tx.TokenID.Bytes()) - copy(b[64:68], tx.ToIdx.Bytes()) + copy(b[60:62], amountFloat16.Bytes()) + copy(b[62:66], tx.TokenID.Bytes()) + toIdxBytes, err := tx.ToIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[66:72], toIdxBytes[:]) return b[:], nil } @@ -99,17 +107,17 @@ func L1TxFromBytes(b []byte) (*L1Tx, error) { if err != nil { return nil, err } - tx.FromIdx, err = IdxFromBytes(b[52:56]) + tx.FromIdx, err = IdxFromBytes(b[52:58]) if err != nil { return nil, err } - tx.LoadAmount = Float16FromBytes(b[56:58]).BigInt() - tx.Amount = Float16FromBytes(b[58:60]).BigInt() - tx.TokenID, err = TokenIDFromBytes(b[60:64]) + tx.LoadAmount = Float16FromBytes(b[58:60]).BigInt() + tx.Amount = Float16FromBytes(b[60:62]).BigInt() + tx.TokenID, err = TokenIDFromBytes(b[62:66]) if err != nil { return nil, err } - tx.ToIdx, err = IdxFromBytes(b[64:68]) + tx.ToIdx, err = IdxFromBytes(b[66:72]) if err != nil { return nil, err } diff --git a/common/l1tx_test.go b/common/l1tx_test.go index 0f4636a..b436aeb 100644 --- a/common/l1tx_test.go +++ b/common/l1tx_test.go @@ -29,7 +29,7 @@ func TestL1TxByteParsers(t *testing.T) { FromEthAddr: ethCommon.HexToAddress("0xc58d29fA6e86E4FAe04DDcEd660d45BCf3Cb2370"), } - expected, err := utils.HexDecode("c58d29fa6e86e4fae04ddced660d45bcf3cb237056ca90f80d7c374ae7485e9bcc47d4ac399460948da6aeeb899311097925a72c00000002000200010000000500000003") + expected, err := utils.HexDecode("c58d29fa6e86e4fae04ddced660d45bcf3cb237056ca90f80d7c374ae7485e9bcc47d4ac399460948da6aeeb899311097925a72c0000000000020002000100000005000000000003") require.Nil(t, err) encodedData, err := l1Tx.Bytes(32) diff --git a/common/pooll2tx.go b/common/pooll2tx.go index d64caaa..3264e62 100644 --- a/common/pooll2tx.go +++ b/common/pooll2tx.go @@ -1,7 +1,6 @@ package common import ( - "encoding/binary" "fmt" "math/big" "time" @@ -11,34 +10,6 @@ import ( "github.com/iden3/go-iden3-crypto/poseidon" ) -// Nonce represents the nonce value in a uint64, which has the method Bytes that returns a byte array of length 5 (40 bits). -type Nonce uint64 - -// Bytes returns a byte array of length 5 representing the Nonce -func (n Nonce) Bytes() ([5]byte, error) { - if n > maxNonceValue { - return [5]byte{}, ErrNonceOverflow - } - var nonceBytes [8]byte - binary.BigEndian.PutUint64(nonceBytes[:], uint64(n)) - var b [5]byte - copy(b[:], nonceBytes[3:]) - return b, nil -} - -// BigInt returns the *big.Int representation of the Nonce value -func (n Nonce) BigInt() *big.Int { - return big.NewInt(int64(n)) -} - -// NonceFromBytes returns Nonce from a [5]byte -func NonceFromBytes(b [5]byte) Nonce { - var nonceBytes [8]byte - copy(nonceBytes[3:], b[:]) - nonce := binary.BigEndian.Uint64(nonceBytes[:]) - return Nonce(nonce) -} - // PoolL2Tx is a struct that represents a L2Tx sent by an account to the coordinator hat is waiting to be forged type PoolL2Tx struct { // Stored in DB: mandatory fileds @@ -109,8 +80,16 @@ func (tx *PoolL2Tx) TxCompressedData() (*big.Int, error) { copy(b[2:7], nonceBytes[:]) copy(b[7:11], tx.TokenID.Bytes()) copy(b[11:13], amountFloat16.Bytes()) - copy(b[13+2:19], tx.ToIdx.Bytes()) - copy(b[19+2:25], tx.FromIdx.Bytes()) + toIdxBytes, err := tx.ToIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[13:19], toIdxBytes[:]) + fromIdxBytes, err := tx.FromIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[19:25], fromIdxBytes[:]) copy(b[25:27], []byte{0, 1, 0, 0}) // TODO check js implementation (unexpected behaviour from test vector generated from js) copy(b[27:31], sc.Bytes()) @@ -146,8 +125,16 @@ func (tx *PoolL2Tx) TxCompressedDataV2() (*big.Int, error) { copy(b[2:7], nonceBytes[:]) copy(b[7:11], tx.TokenID.Bytes()) copy(b[11:13], amountFloat16.Bytes()) - copy(b[13+2:19], tx.ToIdx.Bytes()) - copy(b[19+2:25], tx.FromIdx.Bytes()) + toIdxBytes, err := tx.ToIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[13:19], toIdxBytes[:]) + fromIdxBytes, err := tx.FromIdx.Bytes() + if err != nil { + return nil, err + } + copy(b[19:25], fromIdxBytes[:]) bi := new(big.Int).SetBytes(b[:]) return bi, nil diff --git a/common/pooll2tx_test.go b/common/pooll2tx_test.go index 2f67739..b52ee89 100644 --- a/common/pooll2tx_test.go +++ b/common/pooll2tx_test.go @@ -10,31 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNonceParser(t *testing.T) { - n := Nonce(1) - nBytes, err := n.Bytes() - assert.Nil(t, err) - assert.Equal(t, 5, len(nBytes)) - assert.Equal(t, "0000000001", hex.EncodeToString(nBytes[:])) - n2 := NonceFromBytes(nBytes) - assert.Equal(t, n, n2) - - // value before overflow - n = Nonce(1099511627775) - nBytes, err = n.Bytes() - assert.Nil(t, err) - assert.Equal(t, 5, len(nBytes)) - assert.Equal(t, "ffffffffff", hex.EncodeToString(nBytes[:])) - n2 = NonceFromBytes(nBytes) - assert.Equal(t, n, n2) - - // expect value overflow - n = Nonce(1099511627776) - nBytes, err = n.Bytes() - assert.NotNil(t, err) - assert.Equal(t, ErrNonceOverflow, err) -} - func TestTxCompressedData(t *testing.T) { var sk babyjub.PrivateKey _, err := hex.Decode(sk[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) diff --git a/common/tx_test.go b/common/tx_test.go deleted file mode 100644 index bce7ded..0000000 --- a/common/tx_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package common - -import ( - "math/big" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestIdx(t *testing.T) { - i := Idx(100) - assert.Equal(t, big.NewInt(100), i.BigInt()) - - i = Idx(uint32(4294967295)) - assert.Equal(t, "4294967295", i.BigInt().String()) - - b := big.NewInt(4294967296) - i, err := IdxFromBigInt(b) - assert.NotNil(t, err) - assert.Equal(t, ErrNumOverflow, err) - assert.Equal(t, Idx(0), i) -} diff --git a/db/statedb/statedb.go b/db/statedb/statedb.go index ba04218..1d5d4d4 100644 --- a/db/statedb/statedb.go +++ b/db/statedb/statedb.go @@ -235,7 +235,11 @@ func (s *StateDB) GetAccount(idx common.Idx) (*common.Account, error) { // getAccountInTreeDB is abstracted from StateDB to be used from StateDB and // from ExitTree. GetAccount returns the account for the given Idx func getAccountInTreeDB(sto db.Storage, idx common.Idx) (*common.Account, error) { - vBytes, err := sto.Get(idx.Bytes()) + idxBytes, err := idx.Bytes() + if err != nil { + return nil, err + } + vBytes, err := sto.Get(idxBytes[:]) if err != nil { return nil, err } @@ -282,7 +286,11 @@ func createAccountInTreeDB(sto db.Storage, mt *merkletree.MerkleTree, idx common return nil, err } - _, err = tx.Get(idx.Bytes()) + idxBytes, err := idx.Bytes() + if err != nil { + return nil, err + } + _, err = tx.Get(idxBytes[:]) if err != db.ErrNotFound { return nil, ErrAccountAlreadyExists } @@ -291,7 +299,7 @@ func createAccountInTreeDB(sto db.Storage, mt *merkletree.MerkleTree, idx common if err != nil { return nil, err } - err = tx.Put(idx.Bytes(), v.Bytes()) + err = tx.Put(idxBytes[:], v.Bytes()) if err != nil { return nil, err } @@ -337,7 +345,11 @@ func updateAccountInTreeDB(sto db.Storage, mt *merkletree.MerkleTree, idx common if err != nil { return nil, err } - err = tx.Put(idx.Bytes(), v.Bytes()) + idxBytes, err := idx.Bytes() + if err != nil { + return nil, err + } + err = tx.Put(idxBytes[:], v.Bytes()) if err != nil { return nil, err } diff --git a/db/statedb/txprocessors.go b/db/statedb/txprocessors.go index f7d48ef..4a9bfa4 100644 --- a/db/statedb/txprocessors.go +++ b/db/statedb/txprocessors.go @@ -678,7 +678,11 @@ func (s *StateDB) setIdx(idx common.Idx) error { if err != nil { return err } - err = tx.Put(keyidx, idx.Bytes()) + idxBytes, err := idx.Bytes() + if err != nil { + return err + } + err = tx.Put(keyidx, idxBytes[:]) if err != nil { return err } diff --git a/db/statedb/utils.go b/db/statedb/utils.go index 52e48fc..20492b8 100644 --- a/db/statedb/utils.go +++ b/db/statedb/utils.go @@ -46,12 +46,16 @@ func (s *StateDB) setIdxByEthAddrBJJ(idx common.Idx, addr ethCommon.Address, pk } k := concatEthAddrBJJ(addr, pk) // store Addr&BJJ-idx - err = tx.Put(k, idx.Bytes()) + idxBytes, err := idx.Bytes() + if err != nil { + return err + } + err = tx.Put(k, idxBytes[:]) if err != nil { return err } // store Addr-idx - err = tx.Put(addr.Bytes(), idx.Bytes()) + err = tx.Put(addr.Bytes(), idxBytes[:]) if err != nil { return err }