From 113995d6f4a8c6f130fd55fc92c8dd71a56e5254 Mon Sep 17 00:00:00 2001 From: Oleksandr Brezhniev Date: Fri, 25 Jun 2021 23:34:20 +0300 Subject: [PATCH] Fixed sql tx close, fixed unit tests. Refactoring. Added missing structs and methods. --- data.go | 52 +++ db/memory/memory.go | 13 +- db/memory/memory_test.go | 25 ++ db/sql/sql.go | 40 +- db/sql/sql_test.go | 91 ++++- db/test/test.go | 780 +++++++++++++++++++++++++++++++++++++-- elembytes.go | 49 +++ entry.go | 98 +++++ hash.go | 124 +++++++ merkletree.go | 353 +++++------------- merkletree_test.go | 717 ----------------------------------- proof.go | 165 +++++++++ utils.go | 68 ++++ 13 files changed, 1531 insertions(+), 1044 deletions(-) create mode 100644 data.go create mode 100644 elembytes.go create mode 100644 entry.go create mode 100644 hash.go create mode 100644 proof.go diff --git a/data.go b/data.go new file mode 100644 index 0000000..282543b --- /dev/null +++ b/data.go @@ -0,0 +1,52 @@ +package merkletree + +import ( + "bytes" + "encoding/hex" + "fmt" +) + +// Data is the type used to represent the data stored in an entry of the MT. +// It consists of 8 elements: e0, e1, e2, e3, ...; +// where v = [e0,e1], index = [e2,e3]. +type Data [DataLen]ElemBytes + +func (d *Data) String() string { + return fmt.Sprintf("%s%s%s%s", hex.EncodeToString(d[0][:]), hex.EncodeToString(d[1][:]), + hex.EncodeToString(d[2][:]), hex.EncodeToString(d[3][:])) +} + +func (d *Data) Bytes() (b [ElemBytesLen * DataLen]byte) { + for i := 0; i < DataLen; i++ { + copy(b[i*ElemBytesLen:(i+1)*ElemBytesLen], d[i][:]) + } + return b +} + +func (d1 *Data) Equal(d2 *Data) bool { + return bytes.Equal(d1[0][:], d2[0][:]) && bytes.Equal(d1[1][:], d2[1][:]) && + bytes.Equal(d1[2][:], d2[2][:]) && bytes.Equal(d1[3][:], d2[3][:]) +} + +func (d Data) MarshalText() ([]byte, error) { + dataBytes := d.Bytes() + return []byte(hex.EncodeToString(dataBytes[:])), nil +} + +func (d *Data) UnmarshalText(text []byte) error { + var dataBytes [ElemBytesLen * DataLen]byte + _, err := hex.Decode(dataBytes[:], text) + if err != nil { + return err + } + *d = *NewDataFromBytes(dataBytes) + return nil +} + +func NewDataFromBytes(b [ElemBytesLen * DataLen]byte) *Data { + d := &Data{} + for i := 0; i < DataLen; i++ { + copy(d[i][:], b[i*ElemBytesLen : (i+1)*ElemBytesLen][:]) + } + return d +} diff --git a/db/memory/memory.go b/db/memory/memory.go index 0d37fed..13ac8d0 100644 --- a/db/memory/memory.go +++ b/db/memory/memory.go @@ -46,7 +46,9 @@ func (m *Storage) Get(key []byte) (*merkletree.Node, error) { func (m *Storage) GetRoot() (*merkletree.Hash, error) { if m.currentRoot != nil { - return m.currentRoot, nil + hash := merkletree.Hash{} + copy(hash[:], m.currentRoot[:]) + return &hash, nil } return nil, merkletree.ErrNotFound } @@ -97,7 +99,7 @@ func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error { func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) { if tx.currentRoot != nil { hash := merkletree.Hash{} - copy(tx.currentRoot[:], hash[:]) + copy(hash[:], tx.currentRoot[:]) return &hash, nil } return nil, merkletree.ErrNotFound @@ -105,6 +107,9 @@ func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) { // SetRoot sets a hash of merkle tree root in the interface db.Tx func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error { + + // TODO: do tx.Put('currentroot', hash) here ? + root := &merkletree.Hash{} copy(root[:], hash[:]) tx.currentRoot = root @@ -116,6 +121,10 @@ func (tx *StorageTx) Commit() error { for _, v := range tx.kv { tx.s.kv.Put(v.K, v.V) } + //if tx.currentRoot == nil { + // tx.currentRoot = &merkletree.Hash{} + //} + tx.s.currentRoot = tx.currentRoot tx.kv = nil return nil } diff --git a/db/memory/memory_test.go b/db/memory/memory_test.go index 32cdcc9..f796157 100644 --- a/db/memory/memory_test.go +++ b/db/memory/memory_test.go @@ -22,4 +22,29 @@ func TestMemory(t *testing.T) { test.TestConcatTx(t, NewMemoryStorage()) test.TestList(t, NewMemoryStorage()) test.TestIterate(t, NewMemoryStorage()) + + test.TestNewTree(t, NewMemoryStorage()) + test.TestAddDifferentOrder(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestAddRepeatedIndex(t, NewMemoryStorage()) + test.TestGet(t, NewMemoryStorage()) + test.TestUpdate(t, NewMemoryStorage()) + test.TestUpdate2(t, NewMemoryStorage()) + test.TestGenerateAndVerifyProof128(t, NewMemoryStorage()) + test.TestTreeLimit(t, NewMemoryStorage()) + test.TestSiblingsFromProof(t, NewMemoryStorage()) + test.TestVerifyProofCases(t, NewMemoryStorage()) + test.TestVerifyProofFalse(t, NewMemoryStorage()) + test.TestGraphViz(t, NewMemoryStorage()) + test.TestDelete(t, NewMemoryStorage()) + test.TestDelete2(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestDelete3(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestDelete4(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestDelete5(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestDeleteNonExistingKeys(t, NewMemoryStorage()) + test.TestDumpLeafsImportLeafs(t, NewMemoryStorage(), NewMemoryStorage()) + test.TestAddAndGetCircomProof(t, NewMemoryStorage()) + test.TestUpdateCircomProcessorProof(t, NewMemoryStorage()) + test.TestSmtVerifier(t, NewMemoryStorage()) + test.TestTypesMarshalers(t, NewMemoryStorage()) + } diff --git a/db/sql/sql.go b/db/sql/sql.go index 9938937..de0a941 100644 --- a/db/sql/sql.go +++ b/db/sql/sql.go @@ -1,6 +1,7 @@ package sql import ( + "crypto/sha256" "database/sql" "encoding/binary" "errors" @@ -29,7 +30,7 @@ type Storage struct { type StorageTx struct { *Storage tx *sqlx.Tx - cache merkletree.KvMap + cache KvMap currentRoot *merkletree.Hash } @@ -74,7 +75,7 @@ func (s *Storage) NewTx() (merkletree.Tx, error) { if err != nil { return nil, err } - return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil + return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil } // Get retrieves a value from a key in the db.Storage @@ -167,7 +168,7 @@ func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) { func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error { //fullKey := append(tx.mtId, k...) fullKey := k - tx.cache.Put(fullKey, *v) + tx.cache.Put(tx.mtId, fullKey, *v) fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v) return nil } @@ -204,17 +205,13 @@ func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error { // Add implements the method Add of the interface db.Tx func (tx *StorageTx) Add(atx merkletree.Tx) error { dbtx := atx.(*StorageTx) - //if !bytes.Equal(tx.prefix, dbtx.prefix) { - // // TODO: change cache to store prefix too! - // return errors.New("adding StorageTx with different prefix is not implemented") - //} if tx.mtId != dbtx.mtId { - // TODO: change cache to store prefix too! return errors.New("adding StorageTx with different prefix is not implemented") } for _, v := range dbtx.cache { - tx.cache.Put(v.K, v.V) + tx.cache.Put(v.MTId, v.K, v.V) } + // TODO: change cache to store different currentRoots for different mtIds too! tx.currentRoot = dbtx.currentRoot return nil } @@ -246,7 +243,7 @@ func (tx *StorageTx) Commit() error { if err != nil { return err } - _, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry) + _, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry) if err != nil { return err } @@ -266,7 +263,7 @@ func (tx *StorageTx) Commit() error { // Close implements the method Close of the interface db.Tx func (tx *StorageTx) Close() { - //tx.tx.Rollback() + tx.tx.Rollback() tx.cache = nil } @@ -313,3 +310,24 @@ func (item *NodeItem) Node() (*merkletree.Node, error) { } return &node, nil } + +// KV contains a key (K) and a value (V) +type KV struct { + MTId uint64 + K []byte + V merkletree.Node +} + +// KvMap is a key-value map between a sha256 byte array hash, and a KV struct +type KvMap map[[sha256.Size]byte]KV + +// Get retrieves the value respective to a key from the KvMap +func (m KvMap) Get(k []byte) (merkletree.Node, bool) { + v, ok := m[sha256.Sum256(k)] + return v.V, ok +} + +// Put stores a key and a value in the KvMap +func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) { + m[sha256.Sum256(k)] = KV{mtId, k, v} +} diff --git a/db/sql/sql_test.go b/db/sql/sql_test.go index c0114f0..c06d840 100644 --- a/db/sql/sql_test.go +++ b/db/sql/sql_test.go @@ -4,11 +4,13 @@ import ( "bytes" "encoding/hex" "encoding/json" + "errors" "fmt" "github.com/iden3/go-iden3-crypto/constants" cryptoUtils "github.com/iden3/go-iden3-crypto/utils" "github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree/db/memory" + "github.com/iden3/go-merkletree/db/test" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,7 +20,11 @@ import ( "testing" ) -func sqlStorage(t *testing.T) merkletree.Storage { +var maxMTId uint64 = 0 +var cleared = false + +func setupDB() (*sqlx.DB, error) { + var err error host := os.Getenv("PGHOST") if host == "" { host = "localhost" @@ -33,7 +39,7 @@ func sqlStorage(t *testing.T) merkletree.Storage { } password := os.Getenv("PGPASSWORD") if password == "" { - panic("No PGPASSWORD envvar specified") + return nil, errors.New("No PGPASSWORD envvar specified") } dbname := os.Getenv("PGDATABASE") if dbname == "" { @@ -50,19 +56,34 @@ func sqlStorage(t *testing.T) merkletree.Storage { ) dbx, err := sqlx.Connect("postgres", psqlconn) if err != nil { - t.Fatal(err) - return nil + return nil, err } // clear MerkleTree table + //if !cleared { dbx.Exec("TRUNCATE TABLE mt_roots") dbx.Exec("TRUNCATE TABLE mt_nodes") + cleared = true + //} + + return dbx, nil +} + +func sqlStorage(t *testing.T) merkletree.Storage { + + dbx, err := setupDB() + if err != nil { + t.Fatal(err) + return nil + } sto, err := NewSqlStorage(dbx, false) if err != nil { t.Fatal(err) return nil } + sto.mtId = maxMTId + maxMTId++ t.Cleanup(func() { }) @@ -70,26 +91,60 @@ func sqlStorage(t *testing.T) merkletree.Storage { return sto } +func TestReturnKnownErrIfNotExists(t *testing.T) { + test.TestReturnKnownErrIfNotExists(t, sqlStorage(t)) +} + +func TestStorageInsertGet(t *testing.T) { + test.TestStorageInsertGet(t, sqlStorage(t)) +} + +func TestStorageWithPrefix(t *testing.T) { + test.TestStorageWithPrefix(t, sqlStorage(t)) +} + func TestSql(t *testing.T) { //sto := sqlStorage(t) - //t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) { - // test.TestReturnKnownErrIfNotExists(t, sqlStorage(t)) - //}) - //t.Run("TestStorageInsertGet", func(t *testing.T) { - // test.TestStorageInsertGet(t, sqlStorage(t)) - //}) - //test.TestStorageWithPrefix(t, sqlStorage(t)) - //test.TestConcatTx(t, sqlStorage(t)) - //test.TestList(t, sqlStorage(t)) - //test.TestIterate(t, sqlStorage(t)) + t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) { + test.TestReturnKnownErrIfNotExists(t, sqlStorage(t)) + }) + t.Run("TestStorageInsertGet", func(t *testing.T) { + test.TestStorageInsertGet(t, sqlStorage(t)) + }) + t.Run("TestStorageWithPrefix", func(t *testing.T) { + test.TestStorageWithPrefix(t, sqlStorage(t)) + }) + test.TestConcatTx(t, sqlStorage(t)) + test.TestList(t, sqlStorage(t)) + test.TestIterate(t, sqlStorage(t)) + + test.TestNewTree(t, sqlStorage(t)) + test.TestAddDifferentOrder(t, sqlStorage(t), sqlStorage(t)) + test.TestAddRepeatedIndex(t, sqlStorage(t)) + test.TestGet(t, sqlStorage(t)) + test.TestUpdate(t, sqlStorage(t)) + test.TestUpdate2(t, sqlStorage(t)) + test.TestGenerateAndVerifyProof128(t, sqlStorage(t)) + test.TestTreeLimit(t, sqlStorage(t)) + test.TestSiblingsFromProof(t, sqlStorage(t)) + test.TestVerifyProofCases(t, sqlStorage(t)) + test.TestVerifyProofFalse(t, sqlStorage(t)) + test.TestGraphViz(t, sqlStorage(t)) + test.TestDelete(t, sqlStorage(t)) + test.TestDelete2(t, sqlStorage(t), sqlStorage(t)) + test.TestDelete3(t, sqlStorage(t), sqlStorage(t)) + test.TestDelete4(t, sqlStorage(t), sqlStorage(t)) + test.TestDelete5(t, sqlStorage(t), sqlStorage(t)) + test.TestDeleteNonExistingKeys(t, sqlStorage(t)) + test.TestDumpLeafsImportLeafs(t, sqlStorage(t), sqlStorage(t)) + test.TestAddAndGetCircomProof(t, sqlStorage(t)) + test.TestUpdateCircomProcessorProof(t, sqlStorage(t)) + test.TestSmtVerifier(t, sqlStorage(t)) + test.TestTypesMarshalers(t, sqlStorage(t)) } var debug = false -type Fatalable interface { - Fatal(args ...interface{}) -} - func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree { sto := sqlStorage(f) diff --git a/db/test/test.go b/db/test/test.go index 416ee92..8dd3eb5 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -2,23 +2,39 @@ package test import ( + "bytes" + "encoding/hex" + "encoding/json" + "fmt" + "github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-merkletree" + "github.com/stretchr/testify/require" + "math/big" "testing" "github.com/stretchr/testify/assert" ) +var debug = false + +func newTestingMerkle(t *testing.T, sto merkletree.Storage, numLevels int) *merkletree.MerkleTree { + mt, err := merkletree.NewMerkleTree(sto, numLevels) + if err != nil { + t.Fatal(err) + return nil + } + return mt +} + // TestReturnKnownErrIfNotExists checks that the implementation of the // db.Storage interface returns the expected error in the case that the value // is not found func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) { + //defer sto.Close() k := []byte("key") tx, err := sto.NewTx() - //defer func() { - // tx.Close() - // sto.Close() - //}() + defer tx.Close() assert.Nil(t, err) _, err = tx.Get(k) @@ -28,28 +44,30 @@ func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) { // TestStorageInsertGet checks that the implementation of the db.Storage // interface behaves as expected func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) { - key := []byte("key") + defer sto.Close() value := merkletree.Hash{1, 1, 1, 1} tx, err := sto.NewTx() - //defer func() { - // tx.Close() - // sto.Close() - //}() + defer tx.Close() + assert.Nil(t, err) node := merkletree.NewNodeMiddle(&value, &value) - err = tx.Put(key, node) + key, err := node.Key() assert.Nil(t, err) - v, err := tx.Get(key) + err = tx.Put(key[:], node) + assert.Nil(t, err) + v, err := tx.Get(key[:]) assert.Nil(t, err) assert.Equal(t, value, *v.ChildL) assert.Equal(t, value, *v.ChildR) assert.Nil(t, tx.Commit()) - tx, err = sto.NewTx() + tx2, err := sto.NewTx() + defer tx2.Close() assert.Nil(t, err) - v, err = tx.Get(key) + v, err = tx2.Get(key[:]) assert.Nil(t, err) + require.NotNil(t, v) assert.Equal(t, value, *v.ChildL) assert.Equal(t, value, *v.ChildR) } @@ -57,7 +75,7 @@ func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) { // TestStorageWithPrefix checks that the implementation of the db.Storage // interface behaves as expected for the WithPrefix method func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) { - k := []byte{9} + defer sto.Close() sto1 := sto.WithPrefix([]byte{1}) sto2 := sto.WithPrefix([]byte{2}) @@ -67,37 +85,44 @@ func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) { sto1tx, err := sto1.NewTx() assert.Nil(t, err) node := merkletree.NewNodeLeaf(&merkletree.Hash{1, 2, 3}, &merkletree.Hash{4, 5, 6}) - err = sto1tx.Put(k, node) + k, err := node.Key() + err = sto1tx.Put(k[:], node) assert.Nil(t, err) - v1, err := sto1tx.Get(k) + v1, err := sto1tx.Get(k[:]) assert.Nil(t, err) assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1]) assert.Nil(t, sto1tx.Commit()) sto2tx, err := sto2.NewTx() assert.Nil(t, err) - node.Entry[1] = &merkletree.Hash{9, 10} - err = sto2tx.Put(k, node) + + v2, err := sto2tx.Get(k[:]) + assert.Equal(t, merkletree.ErrNotFound, err) + + err = sto2tx.Put(k[:], node) assert.Nil(t, err) - v2, err := sto2tx.Get(k) + v2, err = sto2tx.Get(k[:]) assert.Nil(t, err) - assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1]) + assert.Equal(t, merkletree.Hash{4, 5, 6}, *v2.Entry[1]) assert.Nil(t, sto2tx.Commit()) // check outside tx - v1, err = sto1.Get(k) + v1, err = sto1.Get(k[:]) assert.Nil(t, err) + require.NotNil(t, v1) assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1]) - v2, err = sto2.Get(k) + v2, err = sto2.Get(k[:]) assert.Nil(t, err) - assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1]) + require.NotNil(t, v2) + assert.Equal(t, merkletree.Hash{4, 5, 6}, *v2.Entry[1]) } // TestIterate checks that the implementation of the db.Storage interface // behaves as expected for the Iterate method func TestIterate(t *testing.T, sto merkletree.Storage) { + defer sto.Close() r := []merkletree.KV{} lister := func(k []byte, v *merkletree.Node) (bool, error) { r = append(r, merkletree.KV{K: merkletree.Clone(k), V: *v}) @@ -175,12 +200,12 @@ func TestConcatTx(t *testing.T, sto merkletree.Storage) { // check outside tx v1, err := sto1.Get(k) - assert.Nil(t, err) - assert.Equal(t, v1, merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9})) + require.Nil(t, err) + assert.Equal(t, *merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9}), *v1) v2, err := sto2.Get(k) - assert.Nil(t, err) - assert.Equal(t, v2, merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11})) + require.Nil(t, err) + assert.Equal(t, *merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11}), *v2) } // TestList checks that the implementation of the db.Storage interface behaves @@ -223,3 +248,704 @@ func TestList(t *testing.T, sto merkletree.Storage) { assert.Equal(t, r[0], merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})}) assert.Equal(t, r[1], merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})}) } + +// +// TODO: Add tests for each storage +// + +func TestNewTree(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Add(big.NewInt(1234), big.NewInt(9876)) + assert.Nil(t, err) + assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll + + dbRoot, err := mt.DB().GetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) + + proof, v, err := mt.GenerateProof(big.NewInt(33), nil) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(44), v) + + assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) + assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) +} + +func TestAddDifferentOrder(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt1 := newTestingMerkle(t, sto, 140) + defer mt1.DB().Close() + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt1.Add(k, v); err != nil { + t.Fatal(err) + } + } + + mt2 := newTestingMerkle(t, sto2, 140) + defer mt2.DB().Close() + for i := 16 - 1; i >= 0; i-- { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt2.Add(k, v); err != nil { + t.Fatal(err) + } + } + + assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex()) + assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll +} + +func TestAddRepeatedIndex(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + k := big.NewInt(int64(3)) + v := big.NewInt(int64(12)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + err := mt.Add(k, v) + assert.NotNil(t, err) + assert.Equal(t, err, merkletree.ErrEntryIndexAlreadyExists) +} + +func TestGet(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + k, v, _, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(10), k) + assert.Equal(t, big.NewInt(20), v) + + k, v, _, err = mt.Get(big.NewInt(15)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(15), k) + assert.Equal(t, big.NewInt(30), v) + + k, v, _, err = mt.Get(big.NewInt(16)) + assert.NotNil(t, err) + assert.Equal(t, merkletree.ErrKeyNotFound, err) + assert.Equal(t, "0", k.String()) + assert.Equal(t, "0", v.String()) +} + +func TestUpdate(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + _, v, _, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(20), v) + + _, err = mt.Update(big.NewInt(10), big.NewInt(1024)) + assert.Nil(t, err) + _, v, _, err = mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(1024), v) + + _, err = mt.Update(big.NewInt(1000), big.NewInt(1024)) + assert.Equal(t, merkletree.ErrKeyNotFound, err) + + dbRoot, err := mt.DB().GetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) +} + +func TestUpdate2(t *testing.T, sto merkletree.Storage) { + mt1 := newTestingMerkle(t, sto, 140) + defer mt1.DB().Close() + mt2 := newTestingMerkle(t, sto, 140) + defer mt2.DB().Close() + + err := mt1.Add(big.NewInt(1), big.NewInt(119)) + assert.Nil(t, err) + err = mt1.Add(big.NewInt(2), big.NewInt(229)) + assert.Nil(t, err) + err = mt1.Add(big.NewInt(9876), big.NewInt(6789)) + assert.Nil(t, err) + + err = mt2.Add(big.NewInt(1), big.NewInt(11)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(2), big.NewInt(22)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(9876), big.NewInt(10)) + assert.Nil(t, err) + + _, err = mt1.Update(big.NewInt(1), big.NewInt(11)) + assert.Nil(t, err) + _, err = mt1.Update(big.NewInt(2), big.NewInt(22)) + assert.Nil(t, err) + _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789)) + assert.Nil(t, err) + + assert.Equal(t, mt1.Root(), mt2.Root()) +} + +func TestGenerateAndVerifyProof128(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 140) + require.Nil(t, err) + defer mt.DB().Close() + + for i := 0; i < 128; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + proof, v, err := mt.GenerateProof(big.NewInt(42), nil) + assert.Nil(t, err) + assert.Equal(t, "0", v.String()) + assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0))) +} + +func TestTreeLimit(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 5) + require.Nil(t, err) + defer mt.DB().Close() + + for i := 0; i < 16; i++ { + err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i))) + assert.Nil(t, err) + } + + // here the tree is full, should not allow to add more data as reaches the maximum number of levels + err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16))) + assert.NotNil(t, err) + assert.Equal(t, merkletree.ErrReachedMaxLevel, err) +} + +func TestSiblingsFromProof(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 140) + require.Nil(t, err) + defer mt.DB().Close() + + for i := 0; i < 64; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + + proof, _, err := mt.GenerateProof(big.NewInt(4), nil) + if err != nil { + t.Fatal(err) + } + + siblings := merkletree.SiblingsFromProof(proof) + assert.Equal(t, 6, len(siblings)) + assert.Equal(t, + "d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b", + siblings[0].Hex()) + assert.Equal(t, + "9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16", + siblings[1].Hex()) + assert.Equal(t, + "de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27", + siblings[2].Hex()) + assert.Equal(t, + "5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317", + siblings[3].Hex()) + assert.Equal(t, + "77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122", + siblings[4].Hex()) + assert.Equal(t, + "943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07", + siblings[5].Hex()) +} + +func TestVerifyProofCases(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + for i := 0; i < 8; i++ { + if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil { + t.Fatal(err) + } + } + + // Existence proof + proof, _, err := mt.GenerateProof(big.NewInt(4), nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, proof.Existence, true) + assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0))) + assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll + + for i := 8; i < 32; i++ { + proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil) + assert.Nil(t, err) + if debug { + fmt.Println(i, proof) + } + } + // Non-existence proof, empty aux + proof, _, err = mt.GenerateProof(big.NewInt(12), nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, proof.Existence, false) + // assert.True(t, proof.nodeAux == nil) + assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0))) + assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll + + // Non-existence proof, diff. node aux + proof, _, err = mt.GenerateProof(big.NewInt(10), nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, proof.Existence, false) + assert.True(t, proof.NodeAux != nil) + assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0))) + assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll +} + +func TestVerifyProofFalse(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + for i := 0; i < 8; i++ { + if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil { + t.Fatal(err) + } + } + + // Invalid existence proof (node used for verification doesn't + // correspond to node in the proof) + proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, proof.Existence, true) + assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5)))) + + // Invalid non-existence proof (Non-existence proof, diff. node aux) + proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, proof.Existence, true) + // Now we change the proof from existence to non-existence, and add e's + // data as auxiliary node. + proof.Existence = false + proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))), + Value: merkletree.NewHashFromBigInt(big.NewInt(4))} + assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0))) +} + +func TestGraphViz(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + + _ = mt.Add(big.NewInt(1), big.NewInt(0)) + _ = mt.Add(big.NewInt(2), big.NewInt(0)) + _ = mt.Add(big.NewInt(3), big.NewInt(0)) + _ = mt.Add(big.NewInt(4), big.NewInt(0)) + _ = mt.Add(big.NewInt(5), big.NewInt(0)) + _ = mt.Add(big.NewInt(100), big.NewInt(0)) + + // mt.PrintGraphViz(nil) + + expected := `digraph hierarchy { +node [fontname=Monospace,fontsize=10,shape=box] +"56332309..." -> {"18483622..." "20902180..."} +"18483622..." -> {"75768243..." "16893244..."} +"75768243..." -> {"empty0" "21857056..."} +"empty0" [style=dashed,label=0]; +"21857056..." -> {"51072523..." "empty1"} +"empty1" [style=dashed,label=0]; +"51072523..." -> {"17311038..." "empty2"} +"empty2" [style=dashed,label=0]; +"17311038..." -> {"69499803..." "21008290..."} +"69499803..." [style=filled]; +"21008290..." [style=filled]; +"16893244..." [style=filled]; +"20902180..." -> {"12496585..." "18055627..."} +"12496585..." -> {"19374975..." "15739329..."} +"19374975..." [style=filled]; +"15739329..." [style=filled]; +"18055627..." [style=filled]; +} +` + w := bytes.NewBufferString("") + err = mt.GraphViz(w, nil) + assert.Nil(t, err) + assert.Equal(t, []byte(expected), w.Bytes()) +} + +func TestDelete(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Add(big.NewInt(1234), big.NewInt(9876)) + assert.Nil(t, err) + assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll + + // mt.PrintGraphViz(nil) + + err = mt.Delete(big.NewInt(33)) + // mt.PrintGraphViz(nil) + assert.Nil(t, err) + assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Delete(big.NewInt(1234)) + assert.Nil(t, err) + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + dbRoot, err := mt.DB().GetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) +} + +func TestDelete2(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + for i := 0; i < 8; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + + expectedRoot := mt.Root() + + k := big.NewInt(8) + v := big.NewInt(0) + err := mt.Add(k, v) + require.Nil(t, err) + + err = mt.Delete(big.NewInt(8)) + assert.Nil(t, err) + assert.Equal(t, expectedRoot, mt.Root()) + + mt2 := newTestingMerkle(t, sto2, 140) + defer mt2.DB().Close() + for i := 0; i < 8; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(0) + if err := mt2.Add(k, v); err != nil { + t.Fatal(err) + } + } + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete3(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + err := mt.Add(big.NewInt(1), big.NewInt(1)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + + assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll + + mt2 := newTestingMerkle(t, sto2, 140) + defer mt2.DB().Close() + err = mt2.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete4(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt := newTestingMerkle(t, sto, 140) + defer mt.DB().Close() + + err := mt.Add(big.NewInt(1), big.NewInt(1)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(3), big.NewInt(3)) + assert.Nil(t, err) + + assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll + + mt2 := newTestingMerkle(t, sto2, 140) + defer mt2.DB().Close() + err = mt2.Add(big.NewInt(2), big.NewInt(2)) + assert.Nil(t, err) + err = mt2.Add(big.NewInt(3), big.NewInt(3)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDelete5(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll + + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll + + mt2 := newTestingMerkle(t, sto2, 140) + defer mt2.DB().Close() + err = mt2.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, mt2.Root(), mt.Root()) +} + +func TestDeleteNonExistingKeys(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + + err = mt.Delete(big.NewInt(33)) + assert.Nil(t, err) + err = mt.Delete(big.NewInt(33)) + assert.Equal(t, merkletree.ErrKeyNotFound, err) + + err = mt.Delete(big.NewInt(1)) + assert.Nil(t, err) + + assert.Equal(t, "0", mt.Root().String()) + + err = mt.Delete(big.NewInt(33)) + assert.Equal(t, merkletree.ErrKeyNotFound, err) +} + +func TestDumpLeafsImportLeafs(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 140) + require.Nil(t, err) + defer mt.DB().Close() + + q1 := new(big.Int).Sub(constants.Q, big.NewInt(1)) + for i := 0; i < 10; i++ { + // use numbers near under Q + k := new(big.Int).Sub(q1, big.NewInt(int64(i))) + v := big.NewInt(0) + err = mt.Add(k, v) + require.Nil(t, err) + + // use numbers near above 0 + k = big.NewInt(int64(i)) + err = mt.Add(k, v) + require.Nil(t, err) + } + + d, err := mt.DumpLeafs(nil) + assert.Nil(t, err) + + mt2, err := merkletree.NewMerkleTree(sto2, 140) + require.Nil(t, err) + defer mt2.DB().Close() + err = mt2.ImportDumpedLeafs(d) + assert.Nil(t, err) + + assert.Equal(t, mt.Root(), mt2.Root()) +} + +func TestAddAndGetCircomProof(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "0", cpp.OldRoot.String()) + assert.Equal(t, "13578938...", cpp.NewRoot.String()) + assert.Equal(t, "0", cpp.OldKey.String()) + assert.Equal(t, "0", cpp.OldValue.String()) + assert.Equal(t, "1", cpp.NewKey.String()) + assert.Equal(t, "2", cpp.NewValue.String()) + assert.Equal(t, true, cpp.IsOld0) + assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings)) + + cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "13578938...", cpp.OldRoot.String()) + assert.Equal(t, "54123936...", cpp.NewRoot.String()) + assert.Equal(t, "1", cpp.OldKey.String()) + assert.Equal(t, "2", cpp.OldValue.String()) + assert.Equal(t, "33", cpp.NewKey.String()) + assert.Equal(t, "44", cpp.NewValue.String()) + assert.Equal(t, false, cpp.IsOld0) + assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings)) + + cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66)) + assert.Nil(t, err) + assert.Equal(t, "54123936...", cpp.OldRoot.String()) + assert.Equal(t, "50943640...", cpp.NewRoot.String()) + assert.Equal(t, "0", cpp.OldKey.String()) + assert.Equal(t, "0", cpp.OldValue.String()) + assert.Equal(t, "55", cpp.NewKey.String()) + assert.Equal(t, "66", cpp.NewValue.String()) + assert.Equal(t, true, cpp.IsOld0) + assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings)) +} + +func TestUpdateCircomProcessorProof(t *testing.T, sto merkletree.Storage) { + mt := newTestingMerkle(t, sto, 10) + defer mt.DB().Close() + + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + _, v, _, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(20), v) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024)) + assert.Nil(t, err) + assert.Equal(t, "39010880...", cpp.OldRoot.String()) + assert.Equal(t, "18587862...", cpp.NewRoot.String()) + assert.Equal(t, "10", cpp.OldKey.String()) + assert.Equal(t, "20", cpp.OldValue.String()) + assert.Equal(t, "10", cpp.NewKey.String()) + assert.Equal(t, "1024", cpp.NewValue.String()) + assert.Equal(t, false, cpp.IsOld0) + assert.Equal(t, + "[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]", + fmt.Sprintf("%v", cpp.Siblings)) +} + +func TestSmtVerifier(t *testing.T, sto merkletree.Storage) { + mt, err := merkletree.NewMerkleTree(sto, 4) + assert.Nil(t, err) + + err = mt.Add(big.NewInt(1), big.NewInt(11)) + assert.Nil(t, err) + + cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil) + assert.Nil(t, err) + jCvp, err := json.Marshal(cvp) + assert.Nil(t, err) + // expect siblings to be '[]', instead of 'null' + expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll + + assert.Equal(t, expected, string(jCvp)) + err = mt.Add(big.NewInt(2), big.NewInt(22)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(3), big.NewInt(33)) + assert.Nil(t, err) + err = mt.Add(big.NewInt(4), big.NewInt(44)) + assert.Nil(t, err) + + cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil) + assert.Nil(t, err) + + jCvp, err = json.Marshal(cvp) + assert.Nil(t, err) + // Test vectors generated using https://github.com/iden3/circomlib smt.js + // Expect siblings with the extra 0 that the circom circuits need + expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll + assert.Equal(t, expected, string(jCvp)) + + cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil) + assert.Nil(t, err) + + jCvp, err = json.Marshal(cvp) + assert.Nil(t, err) + // Test vectors generated using https://github.com/iden3/circomlib smt.js + // Without the extra 0 that the circom circuits need, but that are not + // needed at a smart contract verification + expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll + assert.Equal(t, expected, string(jCvp)) +} + +func TestTypesMarshalers(t *testing.T, sto merkletree.Storage) { + // test Hash marshalers + h, err := merkletree.NewHashFromString("42") + assert.Nil(t, err) + s, err := json.Marshal(h) + assert.Nil(t, err) + var h2 *merkletree.Hash + err = json.Unmarshal(s, &h2) + assert.Nil(t, err) + assert.Equal(t, h, h2) + + // create CircomProcessorProof + mt := newTestingMerkle(t, sto, 10) + defer mt.DB().Close() + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + _, v, _, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(20), v) + cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024)) + assert.Nil(t, err) + + // test CircomProcessorProof marshalers + b, err := json.Marshal(&cpp) + assert.Nil(t, err) + + var cpp2 *merkletree.CircomProcessorProof + err = json.Unmarshal(b, &cpp2) + assert.Nil(t, err) + assert.Equal(t, cpp, cpp2) +} diff --git a/elembytes.go b/elembytes.go new file mode 100644 index 0000000..93a472e --- /dev/null +++ b/elembytes.go @@ -0,0 +1,49 @@ +package merkletree + +import ( + "encoding/hex" + "fmt" + "math/big" +) + +const ( + // ElemBytesLen is the length of the Hash byte array + ElemBytesLen = 32 +) + +// ElemBytes is the basic type used to store data in the MT. ElemBytes +// corresponds to the serialization of an element from mimc7. +type ElemBytes [ElemBytesLen]byte + +func NewElemBytesFromBigInt(v *big.Int) (e ElemBytes) { + bs := SwapEndianness(v.Bytes()) + copy(e[:], bs) + return e +} + +func (e *ElemBytes) BigInt() *big.Int { + return new(big.Int).SetBytes(SwapEndianness(e[:])) +} + +// String returns the first 4 bytes of ElemBytes in hex. +func (e *ElemBytes) String() string { + return fmt.Sprintf("%v...", hex.EncodeToString(e[:4])) +} + +// ElemBytesToBytes serializes an array of ElemBytes to []byte. +func ElemBytesToBytes(es []ElemBytes) []byte { + bs := make([]byte, len(es)*ElemBytesLen) + for i := 0; i < len(es); i++ { + copy(bs[i*ElemBytesLen:(i+1)*ElemBytesLen], es[i][:]) + } + return bs +} + +// ElemBytesToBigInts serializes an array of ElemBytes to []byte. +func ElemBytesToBigInts(es []ElemBytes) []*big.Int { + bs := make([]*big.Int, len(es)) + for i := 0; i < len(es); i++ { + bs[i] = es[i].BigInt() + } + return bs +} diff --git a/entry.go b/entry.go new file mode 100644 index 0000000..faf5674 --- /dev/null +++ b/entry.go @@ -0,0 +1,98 @@ +package merkletree + +import ( + "encoding/hex" + cryptoUtils "github.com/iden3/go-iden3-crypto/utils" +) + +// Entry is the generic type that is stored in the MT. The entry should not be +// modified after creating because the cached hIndex and hValue won't be +// updated. +type Entry struct { + Data Data + // hIndex is a cache used to avoid recalculating hIndex + hIndex *Hash + // hValue is a cache used to avoid recalculating hValue + hValue *Hash +} + +type Entrier interface { + Entry() *Entry +} + +func (e *Entry) Index() []ElemBytes { + return e.Data[:IndexLen] +} + +func (e *Entry) Value() []ElemBytes { + return e.Data[IndexLen:] +} + +// HIndex calculates the hash of the Index of the Entry, used to find the path +// from the root to the leaf in the MT. +func (e *Entry) HIndex() (*Hash, error) { + var err error + if e.hIndex == nil { // Cache the hIndex. + hIndex, err := HashElems(ElemBytesToBigInts(e.Index())...) + if err != nil { + return nil, err + } + e.hIndex = hIndex + } + return e.hIndex, err +} + +// HValue calculates the hash of the Value of the Entry +func (e *Entry) HValue() (*Hash, error) { + var err error + if e.hValue == nil { // Cache the hValue. + hValue, err := HashElems(ElemBytesToBigInts(e.Value())...) + if err != nil { + return nil, err + } + e.hValue = hValue + } + return e.hValue, err +} + +// HiHv returns the HIndex and HValue of the Entry +func (e *Entry) HiHv() (*Hash, *Hash, error) { + hi, err := e.HIndex() + if err != nil { + return nil, nil, err + } + hv, err := e.HValue() + if err != nil { + return nil, nil, err + } + + return hi, hv, nil +} + +func (e *Entry) Bytes() []byte { + b := e.Data.Bytes() + return b[:] +} + +func (e1 *Entry) Equal(e2 *Entry) bool { + return e1.Data.Equal(&e2.Data) +} + +func (e Entry) MarshalText() ([]byte, error) { + return []byte(hex.EncodeToString(e.Bytes())), nil +} + +func (e *Entry) UnmarshalText(text []byte) error { + return e.Data.UnmarshalText(text) +} + +func (e *Entry) Clone() *Entry { + data := NewDataFromBytes(e.Data.Bytes()) + return &Entry{Data: *data} +} + +func CheckEntryInField(e Entry) bool { + bigints := ElemBytesToBigInts(e.Data[:]) + ok := cryptoUtils.CheckBigIntArrayInField(bigints) + return ok +} diff --git a/hash.go b/hash.go new file mode 100644 index 0000000..166fa51 --- /dev/null +++ b/hash.go @@ -0,0 +1,124 @@ +package merkletree + +import ( + "bytes" + "encoding/hex" + "fmt" + cryptoUtils "github.com/iden3/go-iden3-crypto/utils" + "math/big" + "strings" +) + +var ( + // HashZero is used at Empty nodes + HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +) + +// Hash is the generic type stored in the MerkleTree +type Hash [32]byte + +// MarshalText implements the marshaler for the Hash type +func (h Hash) MarshalText() ([]byte, error) { + return []byte(h.BigInt().String()), nil +} + +// UnmarshalText implements the unmarshaler for the Hash type +func (h *Hash) UnmarshalText(b []byte) error { + ha, err := NewHashFromString(string(b)) + copy(h[:], ha[:]) + return err +} + +// String returns decimal representation in string format of the Hash +func (h Hash) String() string { + s := h.BigInt().String() + if len(s) < numCharPrint { + return s + } + return s[0:numCharPrint] + "..." +} + +// Hex returns the hexadecimal representation of the Hash +func (h Hash) Hex() string { + return hex.EncodeToString(h[:]) + // alternatively equivalent, but with too extra steps: + // bRaw := h.BigInt().Bytes() + // b := [32]byte{} + // copy(b[:], SwapEndianness(bRaw[:])) + // return hex.EncodeToString(b[:]) +} + +// BigInt returns the *big.Int representation of the *Hash +func (h *Hash) BigInt() *big.Int { + if new(big.Int).SetBytes(SwapEndianness(h[:])) == nil { + return big.NewInt(0) + } + return new(big.Int).SetBytes(SwapEndianness(h[:])) +} + +// Bytes returns the []byte representation of the *Hash, which always is 32 +// bytes length. +func (h *Hash) Bytes() []byte { + bi := new(big.Int).SetBytes(h[:]).Bytes() + b := [32]byte{} + copy(b[:], SwapEndianness(bi[:])) + return b[:] +} + +func (h *Hash) Equals(h2 *Hash) bool { + return bytes.Equal(h[:], h2[:]) +} + +// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the +// endianness in the process. This is the intended method to get a *big.Int +// from a byte array that previously has ben generated by the Hash.Bytes() +// method. +func NewBigIntFromHashBytes(b []byte) (*big.Int, error) { + if len(b) != ElemBytesLen { + return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b)) + } + bi := new(big.Int).SetBytes(b[:ElemBytesLen]) + if !cryptoUtils.CheckBigIntInField(bi) { + return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field") + } + return bi, nil +} + +// NewHashFromBigInt returns a *Hash representation of the given *big.Int +func NewHashFromBigInt(b *big.Int) *Hash { + r := &Hash{} + copy(r[:], SwapEndianness(b.Bytes())) + return r +} + +// NewHashFromBytes returns a *Hash from a byte array, swapping the endianness +// in the process. This is the intended method to get a *Hash from a byte array +// that previously has ben generated by the Hash.Bytes() method. +func NewHashFromBytes(b []byte) (*Hash, error) { + if len(b) != ElemBytesLen { + return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b)) + } + var h Hash + copy(h[:], SwapEndianness(b)) + return &h, nil +} + +// NewHashFromHex returns a *Hash representation of the given hex string +func NewHashFromHex(h string) (*Hash, error) { + h = strings.TrimPrefix(h, "0x") + b, err := hex.DecodeString(h) + if err != nil { + return nil, err + } + return NewHashFromBytes(SwapEndianness(b[:])) +} + +// NewHashFromString returns a *Hash representation of the given decimal string +func NewHashFromString(s string) (*Hash, error) { + bi, ok := new(big.Int).SetString(s, 10) + if !ok { + return nil, fmt.Errorf("Can not parse string to Hash") + } + return NewHashFromBigInt(bi), nil +} diff --git a/merkletree.go b/merkletree.go index 904d589..1858c65 100644 --- a/merkletree.go +++ b/merkletree.go @@ -2,12 +2,10 @@ package merkletree import ( "bytes" - "encoding/hex" "errors" "fmt" "io" "math/big" - "strings" "sync" cryptoUtils "github.com/iden3/go-iden3-crypto/utils" @@ -17,10 +15,13 @@ const ( // proofFlagsLen is the byte length of the flags in the proof header // (first 32 bytes). proofFlagsLen = 2 - // ElemBytesLen is the length of the Hash byte array - ElemBytesLen = 32 numCharPrint = 8 + + // IndexLen indicates how many elements are used for the index. + IndexLen = 4 + // DataLen indicates how many elements are used for the data. + DataLen = 8 ) var ( @@ -51,115 +52,8 @@ var ( ErrNotWritable = errors.New("Merkle Tree not writable") dbKeyRootNode = []byte("currentroot") - // HashZero is used at Empty nodes - HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) -// Hash is the generic type stored in the MerkleTree -type Hash [32]byte - -// MarshalText implements the marshaler for the Hash type -func (h Hash) MarshalText() ([]byte, error) { - return []byte(h.BigInt().String()), nil -} - -// UnmarshalText implements the unmarshaler for the Hash type -func (h *Hash) UnmarshalText(b []byte) error { - ha, err := NewHashFromString(string(b)) - copy(h[:], ha[:]) - return err -} - -// String returns decimal representation in string format of the Hash -func (h Hash) String() string { - s := h.BigInt().String() - if len(s) < numCharPrint { - return s - } - return s[0:numCharPrint] + "..." -} - -// Hex returns the hexadecimal representation of the Hash -func (h Hash) Hex() string { - return hex.EncodeToString(h[:]) - // alternatively equivalent, but with too extra steps: - // bRaw := h.BigInt().Bytes() - // b := [32]byte{} - // copy(b[:], SwapEndianness(bRaw[:])) - // return hex.EncodeToString(b[:]) -} - -// BigInt returns the *big.Int representation of the *Hash -func (h *Hash) BigInt() *big.Int { - if new(big.Int).SetBytes(SwapEndianness(h[:])) == nil { - return big.NewInt(0) - } - return new(big.Int).SetBytes(SwapEndianness(h[:])) -} - -// Bytes returns the []byte representation of the *Hash, which always is 32 -// bytes length. -func (h *Hash) Bytes() []byte { - bi := new(big.Int).SetBytes(h[:]).Bytes() - b := [32]byte{} - copy(b[:], SwapEndianness(bi[:])) - return b[:] -} - -// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the -// endianness in the process. This is the intended method to get a *big.Int -// from a byte array that previously has ben generated by the Hash.Bytes() -// method. -func NewBigIntFromHashBytes(b []byte) (*big.Int, error) { - if len(b) != ElemBytesLen { - return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b)) - } - bi := new(big.Int).SetBytes(b[:ElemBytesLen]) - if !cryptoUtils.CheckBigIntInField(bi) { - return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field") - } - return bi, nil -} - -// NewHashFromBigInt returns a *Hash representation of the given *big.Int -func NewHashFromBigInt(b *big.Int) *Hash { - r := &Hash{} - copy(r[:], SwapEndianness(b.Bytes())) - return r -} - -// NewHashFromBytes returns a *Hash from a byte array, swapping the endianness -// in the process. This is the intended method to get a *Hash from a byte array -// that previously has ben generated by the Hash.Bytes() method. -func NewHashFromBytes(b []byte) (*Hash, error) { - if len(b) != ElemBytesLen { - return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b)) - } - var h Hash - copy(h[:], SwapEndianness(b)) - return &h, nil -} - -// NewHashFromHex returns a *Hash representation of the given hex string -func NewHashFromHex(h string) (*Hash, error) { - h = strings.TrimPrefix(h, "0x") - b, err := hex.DecodeString(h) - if err != nil { - return nil, err - } - return NewHashFromBytes(SwapEndianness(b[:])) -} - -// NewHashFromString returns a *Hash representation of the given decimal string -func NewHashFromString(s string) (*Hash, error) { - bi, ok := new(big.Int).SetString(s, 10) - if !ok { - return nil, fmt.Errorf("Can not parse string to Hash") - } - return NewHashFromBigInt(bi), nil -} - // MerkleTree is the struct with the main elements of the MerkleTree type MerkleTree struct { sync.RWMutex @@ -276,6 +170,51 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return nil } +// AddEntry adds the Entry to the MerkleTree +func (mt *MerkleTree) AddEntry(e *Entry) error { + // verify that the MerkleTree is writable + if !mt.writable { + return ErrNotWritable + } + // verify that the ElemBytes are valid and fit inside the mimc7 field. + if !CheckEntryInField(*e) { + return errors.New("Elements not inside the Finite Field over R") + } + tx, err := mt.db.NewTx() + if err != nil { + return err + } + mt.Lock() + defer mt.Unlock() + + hIndex, err := e.HIndex() + if err != nil { + return err + } + hValue, err := e.HValue() + if err != nil { + return err + } + newNodeLeaf := NewNodeLeaf(hIndex, hValue) + path := getPath(mt.maxLevels, hIndex[:]) + + newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path) + if err != nil { + return err + } + mt.rootKey = newRootKey + err = mt.setCurrentRoot(tx, mt.rootKey) + if err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { @@ -757,103 +696,6 @@ type NodeAux struct { Value *Hash } -// Proof defines the required elements for a MT proof of existence or -// non-existence. -type Proof struct { - // existence indicates wether this is a proof of existence or - // non-existence. - Existence bool - // depth indicates how deep in the tree the proof goes. - depth uint - // notempties is a bitmap of non-empty Siblings found in Siblings. - notempties [ElemBytesLen - proofFlagsLen]byte - // Siblings is a list of non-empty sibling keys. - Siblings []*Hash - NodeAux *NodeAux -} - -// NewProofFromBytes parses a byte array into a Proof. -func NewProofFromBytes(bs []byte) (*Proof, error) { - if len(bs) < ElemBytesLen { - return nil, ErrInvalidProofBytes - } - p := &Proof{} - if (bs[0] & 0x01) == 0 { - p.Existence = true - } - p.depth = uint(bs[1]) - copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen]) - siblingBytes := bs[ElemBytesLen:] - sibIdx := 0 - for i := uint(0); i < p.depth; i++ { - if TestBitBigEndian(p.notempties[:], i) { - if len(siblingBytes) < (sibIdx+1)*ElemBytesLen { - return nil, ErrInvalidProofBytes - } - var sib Hash - copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen]) - p.Siblings = append(p.Siblings, &sib) - sibIdx++ - } - } - - if !p.Existence && ((bs[0] & 0x02) != 0) { - p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}} - nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:] - if len(nodeAuxBytes) != 2*ElemBytesLen { - return nil, ErrInvalidProofBytes - } - copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen]) - copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen]) - } - return p, nil -} - -// Bytes serializes a Proof into a byte array. -func (p *Proof) Bytes() []byte { - bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings) - if p.NodeAux != nil { - bsLen += 2 * ElemBytesLen //nolint:gomnd - } - bs := make([]byte, bsLen) - - if !p.Existence { - bs[0] |= 0x01 - } - bs[1] = byte(p.depth) - copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:]) - siblingsBytes := bs[len(p.notempties)+proofFlagsLen:] - for i, k := range p.Siblings { - copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:]) - } - if p.NodeAux != nil { - bs[0] |= 0x02 - copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:]) - copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:]) - } - return bs -} - -// SiblingsFromProof returns all the siblings of the proof. -func SiblingsFromProof(proof *Proof) []*Hash { - sibIdx := 0 - siblings := []*Hash{} - for lvl := 0; lvl < int(proof.depth); lvl++ { - if TestBitBigEndian(proof.notempties[:], uint(lvl)) { - siblings = append(siblings, proof.Siblings[sibIdx]) - sibIdx++ - } else { - siblings = append(siblings, &HashZero) - } - } - return siblings -} - -// AllSiblings returns all the siblings of the proof. -func (p *Proof) AllSiblings() []*Hash { - return SiblingsFromProof(p) -} - // CircomSiblingsFromSiblings returns the full siblings compatible with circom func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash { // Add the rest of empty levels to the siblings @@ -1008,67 +850,6 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, return nil, nil, ErrKeyNotFound } -// VerifyProof verifies the Merkle Proof for the entry and root. -func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool { - rootFromProof, err := RootFromProof(proof, k, v) - if err != nil { - return false - } - return bytes.Equal(rootKey[:], rootFromProof[:]) -} - -// RootFromProof calculates the root that would correspond to a tree whose -// siblings are the ones in the proof with the leaf hashing to hIndex and -// hValue. -func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) { - kHash := NewHashFromBigInt(k) - vHash := NewHashFromBigInt(v) - sibIdx := len(proof.Siblings) - 1 - var err error - var midKey *Hash - if proof.Existence { - midKey, err = LeafKey(kHash, vHash) - if err != nil { - return nil, err - } - } else { - if proof.NodeAux == nil { - midKey = &HashZero - } else { - if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) { - return nil, - fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux") - } - midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value) - if err != nil { - return nil, err - } - } - } - path := getPath(int(proof.depth), kHash[:]) - var siblingKey *Hash - for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- { - if TestBitBigEndian(proof.notempties[:], uint(lvl)) { - siblingKey = proof.Siblings[sibIdx] - sibIdx-- - } else { - siblingKey = &HashZero - } - if path[lvl] { - midKey, err = NewNodeMiddle(siblingKey, midKey).Key() - if err != nil { - return nil, err - } - } else { - midKey, err = NewNodeMiddle(midKey, siblingKey).Key() - if err != nil { - return nil, err - } - } - } - return midKey, nil -} - // walk is a helper recursive function to iterate over all tree branches func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error { n, err := mt.GetNode(key) @@ -1199,3 +980,37 @@ func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error { } return nil } + +//// ImportTree imports the tree from the output from the DumpTree function +//func (mt *MerkleTree) ImportTree(i io.Reader) error { +// tx, err := mt.DB().NewTx() +// if err != nil { +// return err +// } +// mt.Lock() +// defer mt.Unlock() +// +// r := bufio.NewReader(i) +// for { +// k, v, err := deserializeKV(r) +// if err == io.EOF { +// break +// } else if err != nil { +// return err +// } +// tx.Put(k, v) +// } +// +// v, err := tx.GetRoot() +// if err != nil { +// return err +// } +// +// if err := tx.Commit(); err != nil { +// return err +// } +// mt.rootKey = &Hash{} +// copy(mt.rootKey[:], v[:]) +// +// return nil +//} diff --git a/merkletree_test.go b/merkletree_test.go index fb67899..4e25b66 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -1,35 +1,15 @@ package merkletree import ( - "bytes" - "encoding/hex" - "encoding/json" - "fmt" "math/big" "testing" "github.com/iden3/go-iden3-crypto/constants" cryptoUtils "github.com/iden3/go-iden3-crypto/utils" - "github.com/iden3/go-merkletree/db/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var debug = false - -type Fatalable interface { - Fatal(args ...interface{}) -} - -func newTestingMerkle(f Fatalable, numLevels int) *MerkleTree { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), numLevels) - if err != nil { - f.Fatal(err) - return nil - } - return mt -} - func TestHashParsers(t *testing.T) { h0 := NewHashFromBigInt(big.NewInt(0)) assert.Equal(t, "0", h0.String()) @@ -89,700 +69,3 @@ func testHashParsers(t *testing.T, a *big.Int) { assert.Equal(t, a, aBIFromHBytes) assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String()) } - -func TestNewTree(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - assert.Equal(t, "0", mt.Root().String()) - - // test vectors generated using https://github.com/iden3/circomlib smt.js - err = mt.Add(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Add(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Add(big.NewInt(1234), big.NewInt(9876)) - assert.Nil(t, err) - assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll - - dbRoot, err := mt.dbGetRoot() - require.Nil(t, err) - assert.Equal(t, mt.Root(), dbRoot) - - proof, v, err := mt.GenerateProof(big.NewInt(33), nil) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(44), v) - - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) - assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) -} - -func TestAddDifferentOrder(t *testing.T) { - mt1 := newTestingMerkle(t, 140) - defer mt1.db.Close() - for i := 0; i < 16; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt1.Add(k, v); err != nil { - t.Fatal(err) - } - } - - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - for i := 16 - 1; i >= 0; i-- { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt2.Add(k, v); err != nil { - t.Fatal(err) - } - } - - assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex()) - assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll -} - -func TestAddRepeatedIndex(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - k := big.NewInt(int64(3)) - v := big.NewInt(int64(12)) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - err := mt.Add(k, v) - assert.NotNil(t, err) - assert.Equal(t, err, ErrEntryIndexAlreadyExists) -} - -func TestGet(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - - for i := 0; i < 16; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(int64(i * 2)) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - k, v, _, err := mt.Get(big.NewInt(10)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(10), k) - assert.Equal(t, big.NewInt(20), v) - - k, v, _, err = mt.Get(big.NewInt(15)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(15), k) - assert.Equal(t, big.NewInt(30), v) - - k, v, _, err = mt.Get(big.NewInt(16)) - assert.NotNil(t, err) - assert.Equal(t, ErrKeyNotFound, err) - assert.Equal(t, "0", k.String()) - assert.Equal(t, "0", v.String()) -} - -func TestUpdate(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - - for i := 0; i < 16; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(int64(i * 2)) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - _, v, _, err := mt.Get(big.NewInt(10)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(20), v) - - _, err = mt.Update(big.NewInt(10), big.NewInt(1024)) - assert.Nil(t, err) - _, v, _, err = mt.Get(big.NewInt(10)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(1024), v) - - _, err = mt.Update(big.NewInt(1000), big.NewInt(1024)) - assert.Equal(t, ErrKeyNotFound, err) - - dbRoot, err := mt.dbGetRoot() - require.Nil(t, err) - assert.Equal(t, mt.Root(), dbRoot) -} - -func TestUpdate2(t *testing.T) { - mt1 := newTestingMerkle(t, 140) - defer mt1.db.Close() - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - - err := mt1.Add(big.NewInt(1), big.NewInt(119)) - assert.Nil(t, err) - err = mt1.Add(big.NewInt(2), big.NewInt(229)) - assert.Nil(t, err) - err = mt1.Add(big.NewInt(9876), big.NewInt(6789)) - assert.Nil(t, err) - - err = mt2.Add(big.NewInt(1), big.NewInt(11)) - assert.Nil(t, err) - err = mt2.Add(big.NewInt(2), big.NewInt(22)) - assert.Nil(t, err) - err = mt2.Add(big.NewInt(9876), big.NewInt(10)) - assert.Nil(t, err) - - _, err = mt1.Update(big.NewInt(1), big.NewInt(11)) - assert.Nil(t, err) - _, err = mt1.Update(big.NewInt(2), big.NewInt(22)) - assert.Nil(t, err) - _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789)) - assert.Nil(t, err) - - assert.Equal(t, mt1.Root(), mt2.Root()) -} - -func TestGenerateAndVerifyProof128(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140) - require.Nil(t, err) - defer mt.db.Close() - - for i := 0; i < 128; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - proof, v, err := mt.GenerateProof(big.NewInt(42), nil) - assert.Nil(t, err) - assert.Equal(t, "0", v.String()) - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0))) -} - -func TestTreeLimit(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 5) - require.Nil(t, err) - defer mt.db.Close() - - for i := 0; i < 16; i++ { - err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i))) - assert.Nil(t, err) - } - - // here the tree is full, should not allow to add more data as reaches the maximum number of levels - err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16))) - assert.NotNil(t, err) - assert.Equal(t, ErrReachedMaxLevel, err) -} - -func TestSiblingsFromProof(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140) - require.Nil(t, err) - defer mt.db.Close() - - for i := 0; i < 64; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - - proof, _, err := mt.GenerateProof(big.NewInt(4), nil) - if err != nil { - t.Fatal(err) - } - - siblings := SiblingsFromProof(proof) - assert.Equal(t, 6, len(siblings)) - assert.Equal(t, - "d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b", - siblings[0].Hex()) - assert.Equal(t, - "9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16", - siblings[1].Hex()) - assert.Equal(t, - "de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27", - siblings[2].Hex()) - assert.Equal(t, - "5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317", - siblings[3].Hex()) - assert.Equal(t, - "77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122", - siblings[4].Hex()) - assert.Equal(t, - "943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07", - siblings[5].Hex()) -} - -func TestVerifyProofCases(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.DB().Close() - - for i := 0; i < 8; i++ { - if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil { - t.Fatal(err) - } - } - - // Existence proof - proof, _, err := mt.GenerateProof(big.NewInt(4), nil) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, proof.Existence, true) - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0))) - assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll - - for i := 8; i < 32; i++ { - proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil) - assert.Nil(t, err) - if debug { - fmt.Println(i, proof) - } - } - // Non-existence proof, empty aux - proof, _, err = mt.GenerateProof(big.NewInt(12), nil) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, proof.Existence, false) - // assert.True(t, proof.nodeAux == nil) - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0))) - assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll - - // Non-existence proof, diff. node aux - proof, _, err = mt.GenerateProof(big.NewInt(10), nil) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, proof.Existence, false) - assert.True(t, proof.NodeAux != nil) - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0))) - assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll -} - -func TestVerifyProofFalse(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.DB().Close() - - for i := 0; i < 8; i++ { - if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil { - t.Fatal(err) - } - } - - // Invalid existence proof (node used for verification doesn't - // correspond to node in the proof) - proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, proof.Existence, true) - assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5)))) - - // Invalid non-existence proof (Non-existence proof, diff. node aux) - proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, proof.Existence, true) - // Now we change the proof from existence to non-existence, and add e's - // data as auxiliary node. - proof.Existence = false - proof.NodeAux = &NodeAux{Key: NewHashFromBigInt(big.NewInt(int64(4))), - Value: NewHashFromBigInt(big.NewInt(4))} - assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0))) -} - -func TestGraphViz(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - - _ = mt.Add(big.NewInt(1), big.NewInt(0)) - _ = mt.Add(big.NewInt(2), big.NewInt(0)) - _ = mt.Add(big.NewInt(3), big.NewInt(0)) - _ = mt.Add(big.NewInt(4), big.NewInt(0)) - _ = mt.Add(big.NewInt(5), big.NewInt(0)) - _ = mt.Add(big.NewInt(100), big.NewInt(0)) - - // mt.PrintGraphViz(nil) - - expected := `digraph hierarchy { -node [fontname=Monospace,fontsize=10,shape=box] -"56332309..." -> {"18483622..." "20902180..."} -"18483622..." -> {"75768243..." "16893244..."} -"75768243..." -> {"empty0" "21857056..."} -"empty0" [style=dashed,label=0]; -"21857056..." -> {"51072523..." "empty1"} -"empty1" [style=dashed,label=0]; -"51072523..." -> {"17311038..." "empty2"} -"empty2" [style=dashed,label=0]; -"17311038..." -> {"69499803..." "21008290..."} -"69499803..." [style=filled]; -"21008290..." [style=filled]; -"16893244..." [style=filled]; -"20902180..." -> {"12496585..." "18055627..."} -"12496585..." -> {"19374975..." "15739329..."} -"19374975..." [style=filled]; -"15739329..." [style=filled]; -"18055627..." [style=filled]; -} -` - w := bytes.NewBufferString("") - err = mt.GraphViz(w, nil) - assert.Nil(t, err) - assert.Equal(t, []byte(expected), w.Bytes()) -} - -func TestDelete(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - assert.Equal(t, "0", mt.Root().String()) - - // test vectors generated using https://github.com/iden3/circomlib smt.js - err = mt.Add(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Add(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Add(big.NewInt(1234), big.NewInt(9876)) - assert.Nil(t, err) - assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll - - // mt.PrintGraphViz(nil) - - err = mt.Delete(big.NewInt(33)) - // mt.PrintGraphViz(nil) - assert.Nil(t, err) - assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Delete(big.NewInt(1234)) - assert.Nil(t, err) - err = mt.Delete(big.NewInt(1)) - assert.Nil(t, err) - assert.Equal(t, "0", mt.Root().String()) - - dbRoot, err := mt.dbGetRoot() - require.Nil(t, err) - assert.Equal(t, mt.Root(), dbRoot) -} - -func TestDelete2(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - for i := 0; i < 8; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - - expectedRoot := mt.Root() - - k := big.NewInt(8) - v := big.NewInt(0) - err := mt.Add(k, v) - require.Nil(t, err) - - err = mt.Delete(big.NewInt(8)) - assert.Nil(t, err) - assert.Equal(t, expectedRoot, mt.Root()) - - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - for i := 0; i < 8; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(0) - if err := mt2.Add(k, v); err != nil { - t.Fatal(err) - } - } - assert.Equal(t, mt2.Root(), mt.Root()) -} - -func TestDelete3(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - - err := mt.Add(big.NewInt(1), big.NewInt(1)) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(2), big.NewInt(2)) - assert.Nil(t, err) - - assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll - err = mt.Delete(big.NewInt(1)) - assert.Nil(t, err) - assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll - - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - err = mt2.Add(big.NewInt(2), big.NewInt(2)) - assert.Nil(t, err) - assert.Equal(t, mt2.Root(), mt.Root()) -} - -func TestDelete4(t *testing.T) { - mt := newTestingMerkle(t, 140) - defer mt.db.Close() - - err := mt.Add(big.NewInt(1), big.NewInt(1)) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(2), big.NewInt(2)) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(3), big.NewInt(3)) - assert.Nil(t, err) - - assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll - err = mt.Delete(big.NewInt(1)) - assert.Nil(t, err) - assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll - - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - err = mt2.Add(big.NewInt(2), big.NewInt(2)) - assert.Nil(t, err) - err = mt2.Add(big.NewInt(3), big.NewInt(3)) - assert.Nil(t, err) - assert.Equal(t, mt2.Root(), mt.Root()) -} - -func TestDelete5(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - err = mt.Add(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll - - err = mt.Delete(big.NewInt(1)) - assert.Nil(t, err) - assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll - - mt2 := newTestingMerkle(t, 140) - defer mt2.db.Close() - err = mt2.Add(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, mt2.Root(), mt.Root()) -} - -func TestDeleteNonExistingKeys(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - err = mt.Add(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - - err = mt.Delete(big.NewInt(33)) - assert.Nil(t, err) - err = mt.Delete(big.NewInt(33)) - assert.Equal(t, ErrKeyNotFound, err) - - err = mt.Delete(big.NewInt(1)) - assert.Nil(t, err) - - assert.Equal(t, "0", mt.Root().String()) - - err = mt.Delete(big.NewInt(33)) - assert.Equal(t, ErrKeyNotFound, err) -} - -func TestDumpLeafsImportLeafs(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140) - require.Nil(t, err) - defer mt.db.Close() - - q1 := new(big.Int).Sub(constants.Q, big.NewInt(1)) - for i := 0; i < 10; i++ { - // use numbers near under Q - k := new(big.Int).Sub(q1, big.NewInt(int64(i))) - v := big.NewInt(0) - err = mt.Add(k, v) - require.Nil(t, err) - - // use numbers near above 0 - k = big.NewInt(int64(i)) - err = mt.Add(k, v) - require.Nil(t, err) - } - - d, err := mt.DumpLeafs(nil) - assert.Nil(t, err) - - mt2, err := NewMerkleTree(memory.NewMemoryStorage(), 140) - require.Nil(t, err) - defer mt2.db.Close() - err = mt2.ImportDumpedLeafs(d) - assert.Nil(t, err) - - assert.Equal(t, mt.Root(), mt2.Root()) -} - -func TestAddAndGetCircomProof(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) - assert.Nil(t, err) - assert.Equal(t, "0", mt.Root().String()) - - // test vectors generated using https://github.com/iden3/circomlib smt.js - cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - assert.Equal(t, "0", cpp.OldRoot.String()) - assert.Equal(t, "13578938...", cpp.NewRoot.String()) - assert.Equal(t, "0", cpp.OldKey.String()) - assert.Equal(t, "0", cpp.OldValue.String()) - assert.Equal(t, "1", cpp.NewKey.String()) - assert.Equal(t, "2", cpp.NewValue.String()) - assert.Equal(t, true, cpp.IsOld0) - assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) - assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings)) - - cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, "13578938...", cpp.OldRoot.String()) - assert.Equal(t, "54123936...", cpp.NewRoot.String()) - assert.Equal(t, "1", cpp.OldKey.String()) - assert.Equal(t, "2", cpp.OldValue.String()) - assert.Equal(t, "33", cpp.NewKey.String()) - assert.Equal(t, "44", cpp.NewValue.String()) - assert.Equal(t, false, cpp.IsOld0) - assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) - assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings)) - - cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66)) - assert.Nil(t, err) - assert.Equal(t, "54123936...", cpp.OldRoot.String()) - assert.Equal(t, "50943640...", cpp.NewRoot.String()) - assert.Equal(t, "0", cpp.OldKey.String()) - assert.Equal(t, "0", cpp.OldValue.String()) - assert.Equal(t, "55", cpp.NewKey.String()) - assert.Equal(t, "66", cpp.NewValue.String()) - assert.Equal(t, true, cpp.IsOld0) - assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) - assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings)) -} - -func TestUpdateCircomProcessorProof(t *testing.T) { - mt := newTestingMerkle(t, 10) - defer mt.db.Close() - - for i := 0; i < 16; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(int64(i * 2)) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - _, v, _, err := mt.Get(big.NewInt(10)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(20), v) - - // test vectors generated using https://github.com/iden3/circomlib smt.js - cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024)) - assert.Nil(t, err) - assert.Equal(t, "39010880...", cpp.OldRoot.String()) - assert.Equal(t, "18587862...", cpp.NewRoot.String()) - assert.Equal(t, "10", cpp.OldKey.String()) - assert.Equal(t, "20", cpp.OldValue.String()) - assert.Equal(t, "10", cpp.NewKey.String()) - assert.Equal(t, "1024", cpp.NewValue.String()) - assert.Equal(t, false, cpp.IsOld0) - assert.Equal(t, - "[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]", - fmt.Sprintf("%v", cpp.Siblings)) -} - -func TestSmtVerifier(t *testing.T) { - mt, err := NewMerkleTree(memory.NewMemoryStorage(), 4) - assert.Nil(t, err) - - err = mt.Add(big.NewInt(1), big.NewInt(11)) - assert.Nil(t, err) - - cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil) - assert.Nil(t, err) - jCvp, err := json.Marshal(cvp) - assert.Nil(t, err) - // expect siblings to be '[]', instead of 'null' - expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll - - assert.Equal(t, expected, string(jCvp)) - err = mt.Add(big.NewInt(2), big.NewInt(22)) - assert.Nil(t, err) - err = mt.Add(big.NewInt(3), big.NewInt(33)) - assert.Nil(t, err) - err = mt.Add(big.NewInt(4), big.NewInt(44)) - assert.Nil(t, err) - - cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil) - assert.Nil(t, err) - - jCvp, err = json.Marshal(cvp) - assert.Nil(t, err) - // Test vectors generated using https://github.com/iden3/circomlib smt.js - // Expect siblings with the extra 0 that the circom circuits need - expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll - assert.Equal(t, expected, string(jCvp)) - - cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil) - assert.Nil(t, err) - - jCvp, err = json.Marshal(cvp) - assert.Nil(t, err) - // Test vectors generated using https://github.com/iden3/circomlib smt.js - // Without the extra 0 that the circom circuits need, but that are not - // needed at a smart contract verification - expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll - assert.Equal(t, expected, string(jCvp)) -} - -func TestTypesMarshalers(t *testing.T) { - // test Hash marshalers - h, err := NewHashFromString("42") - assert.Nil(t, err) - s, err := json.Marshal(h) - assert.Nil(t, err) - var h2 *Hash - err = json.Unmarshal(s, &h2) - assert.Nil(t, err) - assert.Equal(t, h, h2) - - // create CircomProcessorProof - mt := newTestingMerkle(t, 10) - defer mt.db.Close() - for i := 0; i < 16; i++ { - k := big.NewInt(int64(i)) - v := big.NewInt(int64(i * 2)) - if err := mt.Add(k, v); err != nil { - t.Fatal(err) - } - } - _, v, _, err := mt.Get(big.NewInt(10)) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(20), v) - cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024)) - assert.Nil(t, err) - - // test CircomProcessorProof marshalers - b, err := json.Marshal(&cpp) - assert.Nil(t, err) - - var cpp2 *CircomProcessorProof - err = json.Unmarshal(b, &cpp2) - assert.Nil(t, err) - assert.Equal(t, cpp, cpp2) -} diff --git a/proof.go b/proof.go new file mode 100644 index 0000000..6d243a2 --- /dev/null +++ b/proof.go @@ -0,0 +1,165 @@ +package merkletree + +import ( + "bytes" + "fmt" + "math/big" +) + +// Proof defines the required elements for a MT proof of existence or +// non-existence. +type Proof struct { + // existence indicates wether this is a proof of existence or + // non-existence. + Existence bool + // depth indicates how deep in the tree the proof goes. + depth uint + // notempties is a bitmap of non-empty Siblings found in Siblings. + notempties [ElemBytesLen - proofFlagsLen]byte + // Siblings is a list of non-empty sibling keys. + Siblings []*Hash + NodeAux *NodeAux +} + +// NewProofFromBytes parses a byte array into a Proof. +func NewProofFromBytes(bs []byte) (*Proof, error) { + if len(bs) < ElemBytesLen { + return nil, ErrInvalidProofBytes + } + p := &Proof{} + if (bs[0] & 0x01) == 0 { + p.Existence = true + } + p.depth = uint(bs[1]) + copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen]) + siblingBytes := bs[ElemBytesLen:] + sibIdx := 0 + for i := uint(0); i < p.depth; i++ { + if TestBitBigEndian(p.notempties[:], i) { + if len(siblingBytes) < (sibIdx+1)*ElemBytesLen { + return nil, ErrInvalidProofBytes + } + var sib Hash + copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen]) + p.Siblings = append(p.Siblings, &sib) + sibIdx++ + } + } + + if !p.Existence && ((bs[0] & 0x02) != 0) { + p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}} + nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:] + if len(nodeAuxBytes) != 2*ElemBytesLen { + return nil, ErrInvalidProofBytes + } + copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen]) + copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen]) + } + return p, nil +} + +// Bytes serializes a Proof into a byte array. +func (p *Proof) Bytes() []byte { + bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings) + if p.NodeAux != nil { + bsLen += 2 * ElemBytesLen //nolint:gomnd + } + bs := make([]byte, bsLen) + + if !p.Existence { + bs[0] |= 0x01 + } + bs[1] = byte(p.depth) + copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:]) + siblingsBytes := bs[len(p.notempties)+proofFlagsLen:] + for i, k := range p.Siblings { + copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:]) + } + if p.NodeAux != nil { + bs[0] |= 0x02 + copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:]) + copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:]) + } + return bs +} + +// SiblingsFromProof returns all the siblings of the proof. +func SiblingsFromProof(proof *Proof) []*Hash { + sibIdx := 0 + siblings := []*Hash{} + for lvl := 0; lvl < int(proof.depth); lvl++ { + if TestBitBigEndian(proof.notempties[:], uint(lvl)) { + siblings = append(siblings, proof.Siblings[sibIdx]) + sibIdx++ + } else { + siblings = append(siblings, &HashZero) + } + } + return siblings +} + +// AllSiblings returns all the siblings of the proof. +func (p *Proof) AllSiblings() []*Hash { + return SiblingsFromProof(p) +} + +// VerifyProof verifies the Merkle Proof for the entry and root. +func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool { + rootFromProof, err := RootFromProof(proof, k, v) + if err != nil { + return false + } + return bytes.Equal(rootKey[:], rootFromProof[:]) +} + +// RootFromProof calculates the root that would correspond to a tree whose +// siblings are the ones in the proof with the leaf hashing to hIndex and +// hValue. +func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) { + kHash := NewHashFromBigInt(k) + vHash := NewHashFromBigInt(v) + sibIdx := len(proof.Siblings) - 1 + var err error + var midKey *Hash + if proof.Existence { + midKey, err = LeafKey(kHash, vHash) + if err != nil { + return nil, err + } + } else { + if proof.NodeAux == nil { + midKey = &HashZero + } else { + if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) { + return nil, + fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux") + } + midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value) + if err != nil { + return nil, err + } + } + } + path := getPath(int(proof.depth), kHash[:]) + var siblingKey *Hash + for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- { + if TestBitBigEndian(proof.notempties[:], uint(lvl)) { + siblingKey = proof.Siblings[sibIdx] + sibIdx-- + } else { + siblingKey = &HashZero + } + if path[lvl] { + midKey, err = NewNodeMiddle(siblingKey, midKey).Key() + if err != nil { + return nil, err + } + } else { + midKey, err = NewNodeMiddle(midKey, siblingKey).Key() + if err != nil { + return nil, err + } + } + } + return midKey, nil +} diff --git a/utils.go b/utils.go index af8b606..a1af586 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,9 @@ package merkletree import ( + "encoding/binary" + "fmt" + "io" "math/big" "github.com/iden3/go-iden3-crypto/poseidon" @@ -56,3 +59,68 @@ func SwapEndianness(b []byte) []byte { } return o } + +func checkKVLen(kLen, vLen int) error { + if kLen > 0xff { + return fmt.Errorf("len(k) %d > 0xff", kLen) + } + if vLen > 0xffff { + return fmt.Errorf("len(v) %d > 0xffff", vLen) + } + return nil +} + +func serializeKV(w io.Writer, k, v []byte) error { + if err := checkKVLen(len(k), len(v)); err != nil { + return err + } + kH := byte(len(k)) + vH := Uint16ToBytes(uint16(len(v))) + _, err := w.Write([]byte{kH}) + if err != nil { + return err + } + _, err = w.Write(vH) + if err != nil { + return err + } + _, err = w.Write(k) + if err != nil { + return err + } + _, err = w.Write(v) + if err != nil { + return err + } + return nil +} + +func deserializeKV(r io.Reader) ([]byte, []byte, error) { + header := make([]byte, 3) + _, err := io.ReadFull(r, header) + if err != nil { + return nil, nil, err + } + kLen := int(header[0]) + vLen := int(BytesToUint16(header[1:])) + kv := make([]byte, kLen+vLen) + _, err = io.ReadFull(r, kv) + if err == io.EOF { + return nil, nil, io.ErrUnexpectedEOF + } else if err != nil { + return nil, nil, err + } + return kv[:kLen], kv[kLen:], nil +} + +// Uint16ToBytes returns a byte array from a uint16 +func Uint16ToBytes(u uint16) []byte { + var b [2]byte + binary.LittleEndian.PutUint16(b[:], u) + return b[:] +} + +// BytesToUint16 returns a uint16 from a byte array +func BytesToUint16(b []byte) uint16 { + return binary.LittleEndian.Uint16(b[:2]) +}