Browse Source

Fixed sql tx close, fixed unit tests. Refactoring. Added missing structs and methods.

feature/postgres
Oleksandr Brezhniev 3 years ago
parent
commit
113995d6f4
13 changed files with 1531 additions and 1044 deletions
  1. +52
    -0
      data.go
  2. +11
    -2
      db/memory/memory.go
  3. +25
    -0
      db/memory/memory_test.go
  4. +29
    -11
      db/sql/sql.go
  5. +73
    -18
      db/sql/sql_test.go
  6. +753
    -27
      db/test/test.go
  7. +49
    -0
      elembytes.go
  8. +98
    -0
      entry.go
  9. +124
    -0
      hash.go
  10. +84
    -269
      merkletree.go
  11. +0
    -717
      merkletree_test.go
  12. +165
    -0
      proof.go
  13. +68
    -0
      utils.go

+ 52
- 0
data.go

@ -0,0 +1,52 @@
package merkletree
import (
"bytes"
"encoding/hex"
"fmt"
)
// Data is the type used to represent the data stored in an entry of the MT.
// It consists of 8 elements: e0, e1, e2, e3, ...;
// where v = [e0,e1], index = [e2,e3].
type Data [DataLen]ElemBytes
func (d *Data) String() string {
return fmt.Sprintf("%s%s%s%s", hex.EncodeToString(d[0][:]), hex.EncodeToString(d[1][:]),
hex.EncodeToString(d[2][:]), hex.EncodeToString(d[3][:]))
}
func (d *Data) Bytes() (b [ElemBytesLen * DataLen]byte) {
for i := 0; i < DataLen; i++ {
copy(b[i*ElemBytesLen:(i+1)*ElemBytesLen], d[i][:])
}
return b
}
func (d1 *Data) Equal(d2 *Data) bool {
return bytes.Equal(d1[0][:], d2[0][:]) && bytes.Equal(d1[1][:], d2[1][:]) &&
bytes.Equal(d1[2][:], d2[2][:]) && bytes.Equal(d1[3][:], d2[3][:])
}
func (d Data) MarshalText() ([]byte, error) {
dataBytes := d.Bytes()
return []byte(hex.EncodeToString(dataBytes[:])), nil
}
func (d *Data) UnmarshalText(text []byte) error {
var dataBytes [ElemBytesLen * DataLen]byte
_, err := hex.Decode(dataBytes[:], text)
if err != nil {
return err
}
*d = *NewDataFromBytes(dataBytes)
return nil
}
func NewDataFromBytes(b [ElemBytesLen * DataLen]byte) *Data {
d := &Data{}
for i := 0; i < DataLen; i++ {
copy(d[i][:], b[i*ElemBytesLen : (i+1)*ElemBytesLen][:])
}
return d
}

+ 11
- 2
db/memory/memory.go

@ -46,7 +46,9 @@ func (m *Storage) Get(key []byte) (*merkletree.Node, error) {
func (m *Storage) GetRoot() (*merkletree.Hash, error) { func (m *Storage) GetRoot() (*merkletree.Hash, error) {
if m.currentRoot != nil { if m.currentRoot != nil {
return m.currentRoot, nil
hash := merkletree.Hash{}
copy(hash[:], m.currentRoot[:])
return &hash, nil
} }
return nil, merkletree.ErrNotFound return nil, merkletree.ErrNotFound
} }
@ -97,7 +99,7 @@ func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) { func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
if tx.currentRoot != nil { if tx.currentRoot != nil {
hash := merkletree.Hash{} hash := merkletree.Hash{}
copy(tx.currentRoot[:], hash[:])
copy(hash[:], tx.currentRoot[:])
return &hash, nil return &hash, nil
} }
return nil, merkletree.ErrNotFound return nil, merkletree.ErrNotFound
@ -105,6 +107,9 @@ func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
// SetRoot sets a hash of merkle tree root in the interface db.Tx // SetRoot sets a hash of merkle tree root in the interface db.Tx
func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error { func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
// TODO: do tx.Put('currentroot', hash) here ?
root := &merkletree.Hash{} root := &merkletree.Hash{}
copy(root[:], hash[:]) copy(root[:], hash[:])
tx.currentRoot = root tx.currentRoot = root
@ -116,6 +121,10 @@ func (tx *StorageTx) Commit() error {
for _, v := range tx.kv { for _, v := range tx.kv {
tx.s.kv.Put(v.K, v.V) tx.s.kv.Put(v.K, v.V)
} }
//if tx.currentRoot == nil {
// tx.currentRoot = &merkletree.Hash{}
//}
tx.s.currentRoot = tx.currentRoot
tx.kv = nil tx.kv = nil
return nil return nil
} }

+ 25
- 0
db/memory/memory_test.go

@ -22,4 +22,29 @@ func TestMemory(t *testing.T) {
test.TestConcatTx(t, NewMemoryStorage()) test.TestConcatTx(t, NewMemoryStorage())
test.TestList(t, NewMemoryStorage()) test.TestList(t, NewMemoryStorage())
test.TestIterate(t, NewMemoryStorage()) test.TestIterate(t, NewMemoryStorage())
test.TestNewTree(t, NewMemoryStorage())
test.TestAddDifferentOrder(t, NewMemoryStorage(), NewMemoryStorage())
test.TestAddRepeatedIndex(t, NewMemoryStorage())
test.TestGet(t, NewMemoryStorage())
test.TestUpdate(t, NewMemoryStorage())
test.TestUpdate2(t, NewMemoryStorage())
test.TestGenerateAndVerifyProof128(t, NewMemoryStorage())
test.TestTreeLimit(t, NewMemoryStorage())
test.TestSiblingsFromProof(t, NewMemoryStorage())
test.TestVerifyProofCases(t, NewMemoryStorage())
test.TestVerifyProofFalse(t, NewMemoryStorage())
test.TestGraphViz(t, NewMemoryStorage())
test.TestDelete(t, NewMemoryStorage())
test.TestDelete2(t, NewMemoryStorage(), NewMemoryStorage())
test.TestDelete3(t, NewMemoryStorage(), NewMemoryStorage())
test.TestDelete4(t, NewMemoryStorage(), NewMemoryStorage())
test.TestDelete5(t, NewMemoryStorage(), NewMemoryStorage())
test.TestDeleteNonExistingKeys(t, NewMemoryStorage())
test.TestDumpLeafsImportLeafs(t, NewMemoryStorage(), NewMemoryStorage())
test.TestAddAndGetCircomProof(t, NewMemoryStorage())
test.TestUpdateCircomProcessorProof(t, NewMemoryStorage())
test.TestSmtVerifier(t, NewMemoryStorage())
test.TestTypesMarshalers(t, NewMemoryStorage())
} }

+ 29
- 11
db/sql/sql.go

@ -1,6 +1,7 @@
package sql package sql
import ( import (
"crypto/sha256"
"database/sql" "database/sql"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -29,7 +30,7 @@ type Storage struct {
type StorageTx struct { type StorageTx struct {
*Storage *Storage
tx *sqlx.Tx tx *sqlx.Tx
cache merkletree.KvMap
cache KvMap
currentRoot *merkletree.Hash currentRoot *merkletree.Hash
} }
@ -74,7 +75,7 @@ func (s *Storage) NewTx() (merkletree.Tx, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil
return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil
} }
// Get retrieves a value from a key in the db.Storage // Get retrieves a value from a key in the db.Storage
@ -167,7 +168,7 @@ func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error { func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
//fullKey := append(tx.mtId, k...) //fullKey := append(tx.mtId, k...)
fullKey := k fullKey := k
tx.cache.Put(fullKey, *v)
tx.cache.Put(tx.mtId, fullKey, *v)
fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v) fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
return nil return nil
} }
@ -204,17 +205,13 @@ func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
// Add implements the method Add of the interface db.Tx // Add implements the method Add of the interface db.Tx
func (tx *StorageTx) Add(atx merkletree.Tx) error { func (tx *StorageTx) Add(atx merkletree.Tx) error {
dbtx := atx.(*StorageTx) dbtx := atx.(*StorageTx)
//if !bytes.Equal(tx.prefix, dbtx.prefix) {
// // TODO: change cache to store prefix too!
// return errors.New("adding StorageTx with different prefix is not implemented")
//}
if tx.mtId != dbtx.mtId { if tx.mtId != dbtx.mtId {
// TODO: change cache to store prefix too!
return errors.New("adding StorageTx with different prefix is not implemented") return errors.New("adding StorageTx with different prefix is not implemented")
} }
for _, v := range dbtx.cache { for _, v := range dbtx.cache {
tx.cache.Put(v.K, v.V)
tx.cache.Put(v.MTId, v.K, v.V)
} }
// TODO: change cache to store different currentRoots for different mtIds too!
tx.currentRoot = dbtx.currentRoot tx.currentRoot = dbtx.currentRoot
return nil return nil
} }
@ -246,7 +243,7 @@ func (tx *StorageTx) Commit() error {
if err != nil { if err != nil {
return err return err
} }
_, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry)
_, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry)
if err != nil { if err != nil {
return err return err
} }
@ -266,7 +263,7 @@ func (tx *StorageTx) Commit() error {
// Close implements the method Close of the interface db.Tx // Close implements the method Close of the interface db.Tx
func (tx *StorageTx) Close() { func (tx *StorageTx) Close() {
//tx.tx.Rollback()
tx.tx.Rollback()
tx.cache = nil tx.cache = nil
} }
@ -313,3 +310,24 @@ func (item *NodeItem) Node() (*merkletree.Node, error) {
} }
return &node, nil return &node, nil
} }
// KV contains a key (K) and a value (V)
type KV struct {
MTId uint64
K []byte
V merkletree.Node
}
// KvMap is a key-value map between a sha256 byte array hash, and a KV struct
type KvMap map[[sha256.Size]byte]KV
// Get retrieves the value respective to a key from the KvMap
func (m KvMap) Get(k []byte) (merkletree.Node, bool) {
v, ok := m[sha256.Sum256(k)]
return v.V, ok
}
// Put stores a key and a value in the KvMap
func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) {
m[sha256.Sum256(k)] = KV{mtId, k, v}
}

+ 73
- 18
db/sql/sql_test.go

@ -4,11 +4,13 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/constants"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils" cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
"github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree"
"github.com/iden3/go-merkletree/db/memory" "github.com/iden3/go-merkletree/db/memory"
"github.com/iden3/go-merkletree/db/test"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -18,7 +20,11 @@ import (
"testing" "testing"
) )
func sqlStorage(t *testing.T) merkletree.Storage {
var maxMTId uint64 = 0
var cleared = false
func setupDB() (*sqlx.DB, error) {
var err error
host := os.Getenv("PGHOST") host := os.Getenv("PGHOST")
if host == "" { if host == "" {
host = "localhost" host = "localhost"
@ -33,7 +39,7 @@ func sqlStorage(t *testing.T) merkletree.Storage {
} }
password := os.Getenv("PGPASSWORD") password := os.Getenv("PGPASSWORD")
if password == "" { if password == "" {
panic("No PGPASSWORD envvar specified")
return nil, errors.New("No PGPASSWORD envvar specified")
} }
dbname := os.Getenv("PGDATABASE") dbname := os.Getenv("PGDATABASE")
if dbname == "" { if dbname == "" {
@ -50,19 +56,34 @@ func sqlStorage(t *testing.T) merkletree.Storage {
) )
dbx, err := sqlx.Connect("postgres", psqlconn) dbx, err := sqlx.Connect("postgres", psqlconn)
if err != nil { if err != nil {
t.Fatal(err)
return nil
return nil, err
} }
// clear MerkleTree table // clear MerkleTree table
//if !cleared {
dbx.Exec("TRUNCATE TABLE mt_roots") dbx.Exec("TRUNCATE TABLE mt_roots")
dbx.Exec("TRUNCATE TABLE mt_nodes") dbx.Exec("TRUNCATE TABLE mt_nodes")
cleared = true
//}
return dbx, nil
}
func sqlStorage(t *testing.T) merkletree.Storage {
dbx, err := setupDB()
if err != nil {
t.Fatal(err)
return nil
}
sto, err := NewSqlStorage(dbx, false) sto, err := NewSqlStorage(dbx, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return nil return nil
} }
sto.mtId = maxMTId
maxMTId++
t.Cleanup(func() { t.Cleanup(func() {
}) })
@ -70,26 +91,60 @@ func sqlStorage(t *testing.T) merkletree.Storage {
return sto return sto
} }
func TestReturnKnownErrIfNotExists(t *testing.T) {
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
}
func TestStorageInsertGet(t *testing.T) {
test.TestStorageInsertGet(t, sqlStorage(t))
}
func TestStorageWithPrefix(t *testing.T) {
test.TestStorageWithPrefix(t, sqlStorage(t))
}
func TestSql(t *testing.T) { func TestSql(t *testing.T) {
//sto := sqlStorage(t) //sto := sqlStorage(t)
//t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
// test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
//})
//t.Run("TestStorageInsertGet", func(t *testing.T) {
// test.TestStorageInsertGet(t, sqlStorage(t))
//})
//test.TestStorageWithPrefix(t, sqlStorage(t))
//test.TestConcatTx(t, sqlStorage(t))
//test.TestList(t, sqlStorage(t))
//test.TestIterate(t, sqlStorage(t))
t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
})
t.Run("TestStorageInsertGet", func(t *testing.T) {
test.TestStorageInsertGet(t, sqlStorage(t))
})
t.Run("TestStorageWithPrefix", func(t *testing.T) {
test.TestStorageWithPrefix(t, sqlStorage(t))
})
test.TestConcatTx(t, sqlStorage(t))
test.TestList(t, sqlStorage(t))
test.TestIterate(t, sqlStorage(t))
test.TestNewTree(t, sqlStorage(t))
test.TestAddDifferentOrder(t, sqlStorage(t), sqlStorage(t))
test.TestAddRepeatedIndex(t, sqlStorage(t))
test.TestGet(t, sqlStorage(t))
test.TestUpdate(t, sqlStorage(t))
test.TestUpdate2(t, sqlStorage(t))
test.TestGenerateAndVerifyProof128(t, sqlStorage(t))
test.TestTreeLimit(t, sqlStorage(t))
test.TestSiblingsFromProof(t, sqlStorage(t))
test.TestVerifyProofCases(t, sqlStorage(t))
test.TestVerifyProofFalse(t, sqlStorage(t))
test.TestGraphViz(t, sqlStorage(t))
test.TestDelete(t, sqlStorage(t))
test.TestDelete2(t, sqlStorage(t), sqlStorage(t))
test.TestDelete3(t, sqlStorage(t), sqlStorage(t))
test.TestDelete4(t, sqlStorage(t), sqlStorage(t))
test.TestDelete5(t, sqlStorage(t), sqlStorage(t))
test.TestDeleteNonExistingKeys(t, sqlStorage(t))
test.TestDumpLeafsImportLeafs(t, sqlStorage(t), sqlStorage(t))
test.TestAddAndGetCircomProof(t, sqlStorage(t))
test.TestUpdateCircomProcessorProof(t, sqlStorage(t))
test.TestSmtVerifier(t, sqlStorage(t))
test.TestTypesMarshalers(t, sqlStorage(t))
} }
var debug = false var debug = false
type Fatalable interface {
Fatal(args ...interface{})
}
func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree { func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
sto := sqlStorage(f) sto := sqlStorage(f)

+ 753
- 27
db/test/test.go

@ -2,23 +2,39 @@
package test package test
import ( import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/iden3/go-iden3-crypto/constants"
"github.com/iden3/go-merkletree" "github.com/iden3/go-merkletree"
"github.com/stretchr/testify/require"
"math/big"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var debug = false
func newTestingMerkle(t *testing.T, sto merkletree.Storage, numLevels int) *merkletree.MerkleTree {
mt, err := merkletree.NewMerkleTree(sto, numLevels)
if err != nil {
t.Fatal(err)
return nil
}
return mt
}
// TestReturnKnownErrIfNotExists checks that the implementation of the // TestReturnKnownErrIfNotExists checks that the implementation of the
// db.Storage interface returns the expected error in the case that the value // db.Storage interface returns the expected error in the case that the value
// is not found // is not found
func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) { func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) {
//defer sto.Close()
k := []byte("key") k := []byte("key")
tx, err := sto.NewTx() tx, err := sto.NewTx()
//defer func() {
// tx.Close()
// sto.Close()
//}()
defer tx.Close()
assert.Nil(t, err) assert.Nil(t, err)
_, err = tx.Get(k) _, err = tx.Get(k)
@ -28,28 +44,30 @@ func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) {
// TestStorageInsertGet checks that the implementation of the db.Storage // TestStorageInsertGet checks that the implementation of the db.Storage
// interface behaves as expected // interface behaves as expected
func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) { func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) {
key := []byte("key")
defer sto.Close()
value := merkletree.Hash{1, 1, 1, 1} value := merkletree.Hash{1, 1, 1, 1}
tx, err := sto.NewTx() tx, err := sto.NewTx()
//defer func() {
// tx.Close()
// sto.Close()
//}()
defer tx.Close()
assert.Nil(t, err) assert.Nil(t, err)
node := merkletree.NewNodeMiddle(&value, &value) node := merkletree.NewNodeMiddle(&value, &value)
err = tx.Put(key, node)
key, err := node.Key()
assert.Nil(t, err) assert.Nil(t, err)
v, err := tx.Get(key)
err = tx.Put(key[:], node)
assert.Nil(t, err)
v, err := tx.Get(key[:])
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, value, *v.ChildL) assert.Equal(t, value, *v.ChildL)
assert.Equal(t, value, *v.ChildR) assert.Equal(t, value, *v.ChildR)
assert.Nil(t, tx.Commit()) assert.Nil(t, tx.Commit())
tx, err = sto.NewTx()
tx2, err := sto.NewTx()
defer tx2.Close()
assert.Nil(t, err) assert.Nil(t, err)
v, err = tx.Get(key)
v, err = tx2.Get(key[:])
assert.Nil(t, err) assert.Nil(t, err)
require.NotNil(t, v)
assert.Equal(t, value, *v.ChildL) assert.Equal(t, value, *v.ChildL)
assert.Equal(t, value, *v.ChildR) assert.Equal(t, value, *v.ChildR)
} }
@ -57,7 +75,7 @@ func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) {
// TestStorageWithPrefix checks that the implementation of the db.Storage // TestStorageWithPrefix checks that the implementation of the db.Storage
// interface behaves as expected for the WithPrefix method // interface behaves as expected for the WithPrefix method
func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) { func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) {
k := []byte{9}
defer sto.Close()
sto1 := sto.WithPrefix([]byte{1}) sto1 := sto.WithPrefix([]byte{1})
sto2 := sto.WithPrefix([]byte{2}) sto2 := sto.WithPrefix([]byte{2})
@ -67,37 +85,44 @@ func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) {
sto1tx, err := sto1.NewTx() sto1tx, err := sto1.NewTx()
assert.Nil(t, err) assert.Nil(t, err)
node := merkletree.NewNodeLeaf(&merkletree.Hash{1, 2, 3}, &merkletree.Hash{4, 5, 6}) node := merkletree.NewNodeLeaf(&merkletree.Hash{1, 2, 3}, &merkletree.Hash{4, 5, 6})
err = sto1tx.Put(k, node)
k, err := node.Key()
err = sto1tx.Put(k[:], node)
assert.Nil(t, err) assert.Nil(t, err)
v1, err := sto1tx.Get(k)
v1, err := sto1tx.Get(k[:])
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1]) assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1])
assert.Nil(t, sto1tx.Commit()) assert.Nil(t, sto1tx.Commit())
sto2tx, err := sto2.NewTx() sto2tx, err := sto2.NewTx()
assert.Nil(t, err) assert.Nil(t, err)
node.Entry[1] = &merkletree.Hash{9, 10}
err = sto2tx.Put(k, node)
v2, err := sto2tx.Get(k[:])
assert.Equal(t, merkletree.ErrNotFound, err)
err = sto2tx.Put(k[:], node)
assert.Nil(t, err) assert.Nil(t, err)
v2, err := sto2tx.Get(k)
v2, err = sto2tx.Get(k[:])
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1])
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v2.Entry[1])
assert.Nil(t, sto2tx.Commit()) assert.Nil(t, sto2tx.Commit())
// check outside tx // check outside tx
v1, err = sto1.Get(k)
v1, err = sto1.Get(k[:])
assert.Nil(t, err) assert.Nil(t, err)
require.NotNil(t, v1)
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1]) assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1])
v2, err = sto2.Get(k)
v2, err = sto2.Get(k[:])
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1])
require.NotNil(t, v2)
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v2.Entry[1])
} }
// TestIterate checks that the implementation of the db.Storage interface // TestIterate checks that the implementation of the db.Storage interface
// behaves as expected for the Iterate method // behaves as expected for the Iterate method
func TestIterate(t *testing.T, sto merkletree.Storage) { func TestIterate(t *testing.T, sto merkletree.Storage) {
defer sto.Close()
r := []merkletree.KV{} r := []merkletree.KV{}
lister := func(k []byte, v *merkletree.Node) (bool, error) { lister := func(k []byte, v *merkletree.Node) (bool, error) {
r = append(r, merkletree.KV{K: merkletree.Clone(k), V: *v}) r = append(r, merkletree.KV{K: merkletree.Clone(k), V: *v})
@ -175,12 +200,12 @@ func TestConcatTx(t *testing.T, sto merkletree.Storage) {
// check outside tx // check outside tx
v1, err := sto1.Get(k) v1, err := sto1.Get(k)
assert.Nil(t, err)
assert.Equal(t, v1, merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9}))
require.Nil(t, err)
assert.Equal(t, *merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9}), *v1)
v2, err := sto2.Get(k) v2, err := sto2.Get(k)
assert.Nil(t, err)
assert.Equal(t, v2, merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11}))
require.Nil(t, err)
assert.Equal(t, *merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11}), *v2)
} }
// TestList checks that the implementation of the db.Storage interface behaves // TestList checks that the implementation of the db.Storage interface behaves
@ -223,3 +248,704 @@ func TestList(t *testing.T, sto merkletree.Storage) {
assert.Equal(t, r[0], merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})}) assert.Equal(t, r[0], merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})})
assert.Equal(t, r[1], merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})}) assert.Equal(t, r[1], merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})})
} }
//
// TODO: Add tests for each storage
//
func TestNewTree(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
dbRoot, err := mt.DB().GetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
}
func TestAddDifferentOrder(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt1 := newTestingMerkle(t, sto, 140)
defer mt1.DB().Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt1.Add(k, v); err != nil {
t.Fatal(err)
}
}
mt2 := newTestingMerkle(t, sto2, 140)
defer mt2.DB().Close()
for i := 16 - 1; i >= 0; i-- {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
}
func TestAddRepeatedIndex(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
k := big.NewInt(int64(3))
v := big.NewInt(int64(12))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
err := mt.Add(k, v)
assert.NotNil(t, err)
assert.Equal(t, err, merkletree.ErrEntryIndexAlreadyExists)
}
func TestGet(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
k, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(10), k)
assert.Equal(t, big.NewInt(20), v)
k, v, _, err = mt.Get(big.NewInt(15))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(15), k)
assert.Equal(t, big.NewInt(30), v)
k, v, _, err = mt.Get(big.NewInt(16))
assert.NotNil(t, err)
assert.Equal(t, merkletree.ErrKeyNotFound, err)
assert.Equal(t, "0", k.String())
assert.Equal(t, "0", v.String())
}
func TestUpdate(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
_, v, _, err = mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(1024), v)
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
assert.Equal(t, merkletree.ErrKeyNotFound, err)
dbRoot, err := mt.DB().GetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestUpdate2(t *testing.T, sto merkletree.Storage) {
mt1 := newTestingMerkle(t, sto, 140)
defer mt1.DB().Close()
mt2 := newTestingMerkle(t, sto, 140)
defer mt2.DB().Close()
err := mt1.Add(big.NewInt(1), big.NewInt(119))
assert.Nil(t, err)
err = mt1.Add(big.NewInt(2), big.NewInt(229))
assert.Nil(t, err)
err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(9876), big.NewInt(10))
assert.Nil(t, err)
_, err = mt1.Update(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
_, err = mt1.Update(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
_, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
assert.Nil(t, err)
assert.Equal(t, mt1.Root(), mt2.Root())
}
func TestGenerateAndVerifyProof128(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 140)
require.Nil(t, err)
defer mt.DB().Close()
for i := 0; i < 128; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
assert.Nil(t, err)
assert.Equal(t, "0", v.String())
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
}
func TestTreeLimit(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 5)
require.Nil(t, err)
defer mt.DB().Close()
for i := 0; i < 16; i++ {
err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
assert.Nil(t, err)
}
// here the tree is full, should not allow to add more data as reaches the maximum number of levels
err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
assert.NotNil(t, err)
assert.Equal(t, merkletree.ErrReachedMaxLevel, err)
}
func TestSiblingsFromProof(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 140)
require.Nil(t, err)
defer mt.DB().Close()
for i := 0; i < 64; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil {
t.Fatal(err)
}
siblings := merkletree.SiblingsFromProof(proof)
assert.Equal(t, 6, len(siblings))
assert.Equal(t,
"d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b",
siblings[0].Hex())
assert.Equal(t,
"9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16",
siblings[1].Hex())
assert.Equal(t,
"de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27",
siblings[2].Hex())
assert.Equal(t,
"5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317",
siblings[3].Hex())
assert.Equal(t,
"77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122",
siblings[4].Hex())
assert.Equal(t,
"943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07",
siblings[5].Hex())
}
func TestVerifyProofCases(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
for i := 0; i < 8; i++ {
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
t.Fatal(err)
}
}
// Existence proof
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
for i := 8; i < 32; i++ {
proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
assert.Nil(t, err)
if debug {
fmt.Println(i, proof)
}
}
// Non-existence proof, empty aux
proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, false)
// assert.True(t, proof.nodeAux == nil)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
// Non-existence proof, diff. node aux
proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, false)
assert.True(t, proof.NodeAux != nil)
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
}
func TestVerifyProofFalse(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
for i := 0; i < 8; i++ {
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
t.Fatal(err)
}
}
// Invalid existence proof (node used for verification doesn't
// correspond to node in the proof)
proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
// Invalid non-existence proof (Non-existence proof, diff. node aux)
proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
// Now we change the proof from existence to non-existence, and add e's
// data as auxiliary node.
proof.Existence = false
proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))),
Value: merkletree.NewHashFromBigInt(big.NewInt(4))}
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
}
func TestGraphViz(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
_ = mt.Add(big.NewInt(1), big.NewInt(0))
_ = mt.Add(big.NewInt(2), big.NewInt(0))
_ = mt.Add(big.NewInt(3), big.NewInt(0))
_ = mt.Add(big.NewInt(4), big.NewInt(0))
_ = mt.Add(big.NewInt(5), big.NewInt(0))
_ = mt.Add(big.NewInt(100), big.NewInt(0))
// mt.PrintGraphViz(nil)
expected := `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box]
"56332309..." -> {"18483622..." "20902180..."}
"18483622..." -> {"75768243..." "16893244..."}
"75768243..." -> {"empty0" "21857056..."}
"empty0" [style=dashed,label=0];
"21857056..." -> {"51072523..." "empty1"}
"empty1" [style=dashed,label=0];
"51072523..." -> {"17311038..." "empty2"}
"empty2" [style=dashed,label=0];
"17311038..." -> {"69499803..." "21008290..."}
"69499803..." [style=filled];
"21008290..." [style=filled];
"16893244..." [style=filled];
"20902180..." -> {"12496585..." "18055627..."}
"12496585..." -> {"19374975..." "15739329..."}
"19374975..." [style=filled];
"15739329..." [style=filled];
"18055627..." [style=filled];
}
`
w := bytes.NewBufferString("")
err = mt.GraphViz(w, nil)
assert.Nil(t, err)
assert.Equal(t, []byte(expected), w.Bytes())
}
func TestDelete(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
// mt.PrintGraphViz(nil)
err = mt.Delete(big.NewInt(33))
// mt.PrintGraphViz(nil)
assert.Nil(t, err)
assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1234))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
dbRoot, err := mt.DB().GetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestDelete2(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
expectedRoot := mt.Root()
k := big.NewInt(8)
v := big.NewInt(0)
err := mt.Add(k, v)
require.Nil(t, err)
err = mt.Delete(big.NewInt(8))
assert.Nil(t, err)
assert.Equal(t, expectedRoot, mt.Root())
mt2 := newTestingMerkle(t, sto2, 140)
defer mt2.DB().Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete3(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, sto2, 140)
defer mt2.DB().Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete4(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt := newTestingMerkle(t, sto, 140)
defer mt.DB().Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, sto2, 140)
defer mt2.DB().Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete5(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, sto2, 140)
defer mt2.DB().Close()
err = mt2.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDeleteNonExistingKeys(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(33))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(33))
assert.Equal(t, merkletree.ErrKeyNotFound, err)
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
err = mt.Delete(big.NewInt(33))
assert.Equal(t, merkletree.ErrKeyNotFound, err)
}
func TestDumpLeafsImportLeafs(t *testing.T, sto merkletree.Storage, sto2 merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 140)
require.Nil(t, err)
defer mt.DB().Close()
q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
for i := 0; i < 10; i++ {
// use numbers near under Q
k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
v := big.NewInt(0)
err = mt.Add(k, v)
require.Nil(t, err)
// use numbers near above 0
k = big.NewInt(int64(i))
err = mt.Add(k, v)
require.Nil(t, err)
}
d, err := mt.DumpLeafs(nil)
assert.Nil(t, err)
mt2, err := merkletree.NewMerkleTree(sto2, 140)
require.Nil(t, err)
defer mt2.DB().Close()
err = mt2.ImportDumpedLeafs(d)
assert.Nil(t, err)
assert.Equal(t, mt.Root(), mt2.Root())
}
func TestAddAndGetCircomProof(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "0", cpp.OldRoot.String())
assert.Equal(t, "13578938...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "1", cpp.NewKey.String())
assert.Equal(t, "2", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "13578938...", cpp.OldRoot.String())
assert.Equal(t, "54123936...", cpp.NewRoot.String())
assert.Equal(t, "1", cpp.OldKey.String())
assert.Equal(t, "2", cpp.OldValue.String())
assert.Equal(t, "33", cpp.NewKey.String())
assert.Equal(t, "44", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
assert.Nil(t, err)
assert.Equal(t, "54123936...", cpp.OldRoot.String())
assert.Equal(t, "50943640...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "55", cpp.NewKey.String())
assert.Equal(t, "66", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.MaxLevels()+1, len(cpp.Siblings))
}
func TestUpdateCircomProcessorProof(t *testing.T, sto merkletree.Storage) {
mt := newTestingMerkle(t, sto, 10)
defer mt.DB().Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
assert.Equal(t, "39010880...", cpp.OldRoot.String())
assert.Equal(t, "18587862...", cpp.NewRoot.String())
assert.Equal(t, "10", cpp.OldKey.String())
assert.Equal(t, "20", cpp.OldValue.String())
assert.Equal(t, "10", cpp.NewKey.String())
assert.Equal(t, "1024", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t,
"[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]",
fmt.Sprintf("%v", cpp.Siblings))
}
func TestSmtVerifier(t *testing.T, sto merkletree.Storage) {
mt, err := merkletree.NewMerkleTree(sto, 4)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
assert.Nil(t, err)
jCvp, err := json.Marshal(cvp)
assert.Nil(t, err)
// expect siblings to be '[]', instead of 'null'
expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
err = mt.Add(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
err = mt.Add(big.NewInt(3), big.NewInt(33))
assert.Nil(t, err)
err = mt.Add(big.NewInt(4), big.NewInt(44))
assert.Nil(t, err)
cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
assert.Nil(t, err)
jCvp, err = json.Marshal(cvp)
assert.Nil(t, err)
// Test vectors generated using https://github.com/iden3/circomlib smt.js
// Expect siblings with the extra 0 that the circom circuits need
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
assert.Nil(t, err)
jCvp, err = json.Marshal(cvp)
assert.Nil(t, err)
// Test vectors generated using https://github.com/iden3/circomlib smt.js
// Without the extra 0 that the circom circuits need, but that are not
// needed at a smart contract verification
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
}
func TestTypesMarshalers(t *testing.T, sto merkletree.Storage) {
// test Hash marshalers
h, err := merkletree.NewHashFromString("42")
assert.Nil(t, err)
s, err := json.Marshal(h)
assert.Nil(t, err)
var h2 *merkletree.Hash
err = json.Unmarshal(s, &h2)
assert.Nil(t, err)
assert.Equal(t, h, h2)
// create CircomProcessorProof
mt := newTestingMerkle(t, sto, 10)
defer mt.DB().Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
// test CircomProcessorProof marshalers
b, err := json.Marshal(&cpp)
assert.Nil(t, err)
var cpp2 *merkletree.CircomProcessorProof
err = json.Unmarshal(b, &cpp2)
assert.Nil(t, err)
assert.Equal(t, cpp, cpp2)
}

+ 49
- 0
elembytes.go

@ -0,0 +1,49 @@
package merkletree
import (
"encoding/hex"
"fmt"
"math/big"
)
const (
// ElemBytesLen is the length of the Hash byte array
ElemBytesLen = 32
)
// ElemBytes is the basic type used to store data in the MT. ElemBytes
// corresponds to the serialization of an element from mimc7.
type ElemBytes [ElemBytesLen]byte
func NewElemBytesFromBigInt(v *big.Int) (e ElemBytes) {
bs := SwapEndianness(v.Bytes())
copy(e[:], bs)
return e
}
func (e *ElemBytes) BigInt() *big.Int {
return new(big.Int).SetBytes(SwapEndianness(e[:]))
}
// String returns the first 4 bytes of ElemBytes in hex.
func (e *ElemBytes) String() string {
return fmt.Sprintf("%v...", hex.EncodeToString(e[:4]))
}
// ElemBytesToBytes serializes an array of ElemBytes to []byte.
func ElemBytesToBytes(es []ElemBytes) []byte {
bs := make([]byte, len(es)*ElemBytesLen)
for i := 0; i < len(es); i++ {
copy(bs[i*ElemBytesLen:(i+1)*ElemBytesLen], es[i][:])
}
return bs
}
// ElemBytesToBigInts serializes an array of ElemBytes to []byte.
func ElemBytesToBigInts(es []ElemBytes) []*big.Int {
bs := make([]*big.Int, len(es))
for i := 0; i < len(es); i++ {
bs[i] = es[i].BigInt()
}
return bs
}

+ 98
- 0
entry.go

@ -0,0 +1,98 @@
package merkletree
import (
"encoding/hex"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
)
// Entry is the generic type that is stored in the MT. The entry should not be
// modified after creating because the cached hIndex and hValue won't be
// updated.
type Entry struct {
Data Data
// hIndex is a cache used to avoid recalculating hIndex
hIndex *Hash
// hValue is a cache used to avoid recalculating hValue
hValue *Hash
}
type Entrier interface {
Entry() *Entry
}
func (e *Entry) Index() []ElemBytes {
return e.Data[:IndexLen]
}
func (e *Entry) Value() []ElemBytes {
return e.Data[IndexLen:]
}
// HIndex calculates the hash of the Index of the Entry, used to find the path
// from the root to the leaf in the MT.
func (e *Entry) HIndex() (*Hash, error) {
var err error
if e.hIndex == nil { // Cache the hIndex.
hIndex, err := HashElems(ElemBytesToBigInts(e.Index())...)
if err != nil {
return nil, err
}
e.hIndex = hIndex
}
return e.hIndex, err
}
// HValue calculates the hash of the Value of the Entry
func (e *Entry) HValue() (*Hash, error) {
var err error
if e.hValue == nil { // Cache the hValue.
hValue, err := HashElems(ElemBytesToBigInts(e.Value())...)
if err != nil {
return nil, err
}
e.hValue = hValue
}
return e.hValue, err
}
// HiHv returns the HIndex and HValue of the Entry
func (e *Entry) HiHv() (*Hash, *Hash, error) {
hi, err := e.HIndex()
if err != nil {
return nil, nil, err
}
hv, err := e.HValue()
if err != nil {
return nil, nil, err
}
return hi, hv, nil
}
func (e *Entry) Bytes() []byte {
b := e.Data.Bytes()
return b[:]
}
func (e1 *Entry) Equal(e2 *Entry) bool {
return e1.Data.Equal(&e2.Data)
}
func (e Entry) MarshalText() ([]byte, error) {
return []byte(hex.EncodeToString(e.Bytes())), nil
}
func (e *Entry) UnmarshalText(text []byte) error {
return e.Data.UnmarshalText(text)
}
func (e *Entry) Clone() *Entry {
data := NewDataFromBytes(e.Data.Bytes())
return &Entry{Data: *data}
}
func CheckEntryInField(e Entry) bool {
bigints := ElemBytesToBigInts(e.Data[:])
ok := cryptoUtils.CheckBigIntArrayInField(bigints)
return ok
}

+ 124
- 0
hash.go

@ -0,0 +1,124 @@
package merkletree
import (
"bytes"
"encoding/hex"
"fmt"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
"math/big"
"strings"
)
var (
// HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Hash is the generic type stored in the MerkleTree
type Hash [32]byte
// MarshalText implements the marshaler for the Hash type
func (h Hash) MarshalText() ([]byte, error) {
return []byte(h.BigInt().String()), nil
}
// UnmarshalText implements the unmarshaler for the Hash type
func (h *Hash) UnmarshalText(b []byte) error {
ha, err := NewHashFromString(string(b))
copy(h[:], ha[:])
return err
}
// String returns decimal representation in string format of the Hash
func (h Hash) String() string {
s := h.BigInt().String()
if len(s) < numCharPrint {
return s
}
return s[0:numCharPrint] + "..."
}
// Hex returns the hexadecimal representation of the Hash
func (h Hash) Hex() string {
return hex.EncodeToString(h[:])
// alternatively equivalent, but with too extra steps:
// bRaw := h.BigInt().Bytes()
// b := [32]byte{}
// copy(b[:], SwapEndianness(bRaw[:]))
// return hex.EncodeToString(b[:])
}
// BigInt returns the *big.Int representation of the *Hash
func (h *Hash) BigInt() *big.Int {
if new(big.Int).SetBytes(SwapEndianness(h[:])) == nil {
return big.NewInt(0)
}
return new(big.Int).SetBytes(SwapEndianness(h[:]))
}
// Bytes returns the []byte representation of the *Hash, which always is 32
// bytes length.
func (h *Hash) Bytes() []byte {
bi := new(big.Int).SetBytes(h[:]).Bytes()
b := [32]byte{}
copy(b[:], SwapEndianness(bi[:]))
return b[:]
}
func (h *Hash) Equals(h2 *Hash) bool {
return bytes.Equal(h[:], h2[:])
}
// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
// endianness in the process. This is the intended method to get a *big.Int
// from a byte array that previously has ben generated by the Hash.Bytes()
// method.
func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
}
bi := new(big.Int).SetBytes(b[:ElemBytesLen])
if !cryptoUtils.CheckBigIntInField(bi) {
return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
}
return bi, nil
}
// NewHashFromBigInt returns a *Hash representation of the given *big.Int
func NewHashFromBigInt(b *big.Int) *Hash {
r := &Hash{}
copy(r[:], SwapEndianness(b.Bytes()))
return r
}
// NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
// in the process. This is the intended method to get a *Hash from a byte array
// that previously has ben generated by the Hash.Bytes() method.
func NewHashFromBytes(b []byte) (*Hash, error) {
if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
}
var h Hash
copy(h[:], SwapEndianness(b))
return &h, nil
}
// NewHashFromHex returns a *Hash representation of the given hex string
func NewHashFromHex(h string) (*Hash, error) {
h = strings.TrimPrefix(h, "0x")
b, err := hex.DecodeString(h)
if err != nil {
return nil, err
}
return NewHashFromBytes(SwapEndianness(b[:]))
}
// NewHashFromString returns a *Hash representation of the given decimal string
func NewHashFromString(s string) (*Hash, error) {
bi, ok := new(big.Int).SetString(s, 10)
if !ok {
return nil, fmt.Errorf("Can not parse string to Hash")
}
return NewHashFromBigInt(bi), nil
}

+ 84
- 269
merkletree.go

@ -2,12 +2,10 @@ package merkletree
import ( import (
"bytes" "bytes"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
"strings"
"sync" "sync"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils" cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
@ -17,10 +15,13 @@ const (
// proofFlagsLen is the byte length of the flags in the proof header // proofFlagsLen is the byte length of the flags in the proof header
// (first 32 bytes). // (first 32 bytes).
proofFlagsLen = 2 proofFlagsLen = 2
// ElemBytesLen is the length of the Hash byte array
ElemBytesLen = 32
numCharPrint = 8 numCharPrint = 8
// IndexLen indicates how many elements are used for the index.
IndexLen = 4
// DataLen indicates how many elements are used for the data.
DataLen = 8
) )
var ( var (
@ -51,115 +52,8 @@ var (
ErrNotWritable = errors.New("Merkle Tree not writable") ErrNotWritable = errors.New("Merkle Tree not writable")
dbKeyRootNode = []byte("currentroot") dbKeyRootNode = []byte("currentroot")
// HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
) )
// Hash is the generic type stored in the MerkleTree
type Hash [32]byte
// MarshalText implements the marshaler for the Hash type
func (h Hash) MarshalText() ([]byte, error) {
return []byte(h.BigInt().String()), nil
}
// UnmarshalText implements the unmarshaler for the Hash type
func (h *Hash) UnmarshalText(b []byte) error {
ha, err := NewHashFromString(string(b))
copy(h[:], ha[:])
return err
}
// String returns decimal representation in string format of the Hash
func (h Hash) String() string {
s := h.BigInt().String()
if len(s) < numCharPrint {
return s
}
return s[0:numCharPrint] + "..."
}
// Hex returns the hexadecimal representation of the Hash
func (h Hash) Hex() string {
return hex.EncodeToString(h[:])
// alternatively equivalent, but with too extra steps:
// bRaw := h.BigInt().Bytes()
// b := [32]byte{}
// copy(b[:], SwapEndianness(bRaw[:]))
// return hex.EncodeToString(b[:])
}
// BigInt returns the *big.Int representation of the *Hash
func (h *Hash) BigInt() *big.Int {
if new(big.Int).SetBytes(SwapEndianness(h[:])) == nil {
return big.NewInt(0)
}
return new(big.Int).SetBytes(SwapEndianness(h[:]))
}
// Bytes returns the []byte representation of the *Hash, which always is 32
// bytes length.
func (h *Hash) Bytes() []byte {
bi := new(big.Int).SetBytes(h[:]).Bytes()
b := [32]byte{}
copy(b[:], SwapEndianness(bi[:]))
return b[:]
}
// NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
// endianness in the process. This is the intended method to get a *big.Int
// from a byte array that previously has ben generated by the Hash.Bytes()
// method.
func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
}
bi := new(big.Int).SetBytes(b[:ElemBytesLen])
if !cryptoUtils.CheckBigIntInField(bi) {
return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
}
return bi, nil
}
// NewHashFromBigInt returns a *Hash representation of the given *big.Int
func NewHashFromBigInt(b *big.Int) *Hash {
r := &Hash{}
copy(r[:], SwapEndianness(b.Bytes()))
return r
}
// NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
// in the process. This is the intended method to get a *Hash from a byte array
// that previously has ben generated by the Hash.Bytes() method.
func NewHashFromBytes(b []byte) (*Hash, error) {
if len(b) != ElemBytesLen {
return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
}
var h Hash
copy(h[:], SwapEndianness(b))
return &h, nil
}
// NewHashFromHex returns a *Hash representation of the given hex string
func NewHashFromHex(h string) (*Hash, error) {
h = strings.TrimPrefix(h, "0x")
b, err := hex.DecodeString(h)
if err != nil {
return nil, err
}
return NewHashFromBytes(SwapEndianness(b[:]))
}
// NewHashFromString returns a *Hash representation of the given decimal string
func NewHashFromString(s string) (*Hash, error) {
bi, ok := new(big.Int).SetString(s, 10)
if !ok {
return nil, fmt.Errorf("Can not parse string to Hash")
}
return NewHashFromBigInt(bi), nil
}
// MerkleTree is the struct with the main elements of the MerkleTree // MerkleTree is the struct with the main elements of the MerkleTree
type MerkleTree struct { type MerkleTree struct {
sync.RWMutex sync.RWMutex
@ -276,6 +170,51 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return nil return nil
} }
// AddEntry adds the Entry to the MerkleTree
func (mt *MerkleTree) AddEntry(e *Entry) error {
// verify that the MerkleTree is writable
if !mt.writable {
return ErrNotWritable
}
// verify that the ElemBytes are valid and fit inside the mimc7 field.
if !CheckEntryInField(*e) {
return errors.New("Elements not inside the Finite Field over R")
}
tx, err := mt.db.NewTx()
if err != nil {
return err
}
mt.Lock()
defer mt.Unlock()
hIndex, err := e.HIndex()
if err != nil {
return err
}
hValue, err := e.HValue()
if err != nil {
return err
}
newNodeLeaf := NewNodeLeaf(hIndex, hValue)
path := getPath(mt.maxLevels, hIndex[:])
newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
if err != nil {
return err
}
mt.rootKey = newRootKey
err = mt.setCurrentRoot(tx, mt.rootKey)
if err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// AddAndGetCircomProof does an Add, and returns a CircomProcessorProof // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
func (mt *MerkleTree) AddAndGetCircomProof(k, func (mt *MerkleTree) AddAndGetCircomProof(k,
v *big.Int) (*CircomProcessorProof, error) { v *big.Int) (*CircomProcessorProof, error) {
@ -757,103 +696,6 @@ type NodeAux struct {
Value *Hash Value *Hash
} }
// Proof defines the required elements for a MT proof of existence or
// non-existence.
type Proof struct {
// existence indicates wether this is a proof of existence or
// non-existence.
Existence bool
// depth indicates how deep in the tree the proof goes.
depth uint
// notempties is a bitmap of non-empty Siblings found in Siblings.
notempties [ElemBytesLen - proofFlagsLen]byte
// Siblings is a list of non-empty sibling keys.
Siblings []*Hash
NodeAux *NodeAux
}
// NewProofFromBytes parses a byte array into a Proof.
func NewProofFromBytes(bs []byte) (*Proof, error) {
if len(bs) < ElemBytesLen {
return nil, ErrInvalidProofBytes
}
p := &Proof{}
if (bs[0] & 0x01) == 0 {
p.Existence = true
}
p.depth = uint(bs[1])
copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
siblingBytes := bs[ElemBytesLen:]
sibIdx := 0
for i := uint(0); i < p.depth; i++ {
if TestBitBigEndian(p.notempties[:], i) {
if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
return nil, ErrInvalidProofBytes
}
var sib Hash
copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
p.Siblings = append(p.Siblings, &sib)
sibIdx++
}
}
if !p.Existence && ((bs[0] & 0x02) != 0) {
p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
if len(nodeAuxBytes) != 2*ElemBytesLen {
return nil, ErrInvalidProofBytes
}
copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
}
return p, nil
}
// Bytes serializes a Proof into a byte array.
func (p *Proof) Bytes() []byte {
bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
if p.NodeAux != nil {
bsLen += 2 * ElemBytesLen //nolint:gomnd
}
bs := make([]byte, bsLen)
if !p.Existence {
bs[0] |= 0x01
}
bs[1] = byte(p.depth)
copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
for i, k := range p.Siblings {
copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
}
if p.NodeAux != nil {
bs[0] |= 0x02
copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
}
return bs
}
// SiblingsFromProof returns all the siblings of the proof.
func SiblingsFromProof(proof *Proof) []*Hash {
sibIdx := 0
siblings := []*Hash{}
for lvl := 0; lvl < int(proof.depth); lvl++ {
if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
siblings = append(siblings, proof.Siblings[sibIdx])
sibIdx++
} else {
siblings = append(siblings, &HashZero)
}
}
return siblings
}
// AllSiblings returns all the siblings of the proof.
func (p *Proof) AllSiblings() []*Hash {
return SiblingsFromProof(p)
}
// CircomSiblingsFromSiblings returns the full siblings compatible with circom // CircomSiblingsFromSiblings returns the full siblings compatible with circom
func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash { func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
// Add the rest of empty levels to the siblings // Add the rest of empty levels to the siblings
@ -1008,67 +850,6 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof,
return nil, nil, ErrKeyNotFound return nil, nil, ErrKeyNotFound
} }
// VerifyProof verifies the Merkle Proof for the entry and root.
func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
rootFromProof, err := RootFromProof(proof, k, v)
if err != nil {
return false
}
return bytes.Equal(rootKey[:], rootFromProof[:])
}
// RootFromProof calculates the root that would correspond to a tree whose
// siblings are the ones in the proof with the leaf hashing to hIndex and
// hValue.
func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
kHash := NewHashFromBigInt(k)
vHash := NewHashFromBigInt(v)
sibIdx := len(proof.Siblings) - 1
var err error
var midKey *Hash
if proof.Existence {
midKey, err = LeafKey(kHash, vHash)
if err != nil {
return nil, err
}
} else {
if proof.NodeAux == nil {
midKey = &HashZero
} else {
if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
return nil,
fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
}
midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
if err != nil {
return nil, err
}
}
}
path := getPath(int(proof.depth), kHash[:])
var siblingKey *Hash
for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
siblingKey = proof.Siblings[sibIdx]
sibIdx--
} else {
siblingKey = &HashZero
}
if path[lvl] {
midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
if err != nil {
return nil, err
}
} else {
midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
if err != nil {
return nil, err
}
}
}
return midKey, nil
}
// walk is a helper recursive function to iterate over all tree branches // walk is a helper recursive function to iterate over all tree branches
func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error { func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
n, err := mt.GetNode(key) n, err := mt.GetNode(key)
@ -1199,3 +980,37 @@ func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
} }
return nil return nil
} }
//// ImportTree imports the tree from the output from the DumpTree function
//func (mt *MerkleTree) ImportTree(i io.Reader) error {
// tx, err := mt.DB().NewTx()
// if err != nil {
// return err
// }
// mt.Lock()
// defer mt.Unlock()
//
// r := bufio.NewReader(i)
// for {
// k, v, err := deserializeKV(r)
// if err == io.EOF {
// break
// } else if err != nil {
// return err
// }
// tx.Put(k, v)
// }
//
// v, err := tx.GetRoot()
// if err != nil {
// return err
// }
//
// if err := tx.Commit(); err != nil {
// return err
// }
// mt.rootKey = &Hash{}
// copy(mt.rootKey[:], v[:])
//
// return nil
//}

+ 0
- 717
merkletree_test.go

@ -1,35 +1,15 @@
package merkletree package merkletree
import ( import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"math/big" "math/big"
"testing" "testing"
"github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/constants"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils" cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
"github.com/iden3/go-merkletree/db/memory"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var debug = false
type Fatalable interface {
Fatal(args ...interface{})
}
func newTestingMerkle(f Fatalable, numLevels int) *MerkleTree {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), numLevels)
if err != nil {
f.Fatal(err)
return nil
}
return mt
}
func TestHashParsers(t *testing.T) { func TestHashParsers(t *testing.T) {
h0 := NewHashFromBigInt(big.NewInt(0)) h0 := NewHashFromBigInt(big.NewInt(0))
assert.Equal(t, "0", h0.String()) assert.Equal(t, "0", h0.String())
@ -89,700 +69,3 @@ func testHashParsers(t *testing.T, a *big.Int) {
assert.Equal(t, a, aBIFromHBytes) assert.Equal(t, a, aBIFromHBytes)
assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String()) assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
} }
func TestNewTree(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
}
func TestAddDifferentOrder(t *testing.T) {
mt1 := newTestingMerkle(t, 140)
defer mt1.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt1.Add(k, v); err != nil {
t.Fatal(err)
}
}
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
for i := 16 - 1; i >= 0; i-- {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
}
func TestAddRepeatedIndex(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
k := big.NewInt(int64(3))
v := big.NewInt(int64(12))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
err := mt.Add(k, v)
assert.NotNil(t, err)
assert.Equal(t, err, ErrEntryIndexAlreadyExists)
}
func TestGet(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
k, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(10), k)
assert.Equal(t, big.NewInt(20), v)
k, v, _, err = mt.Get(big.NewInt(15))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(15), k)
assert.Equal(t, big.NewInt(30), v)
k, v, _, err = mt.Get(big.NewInt(16))
assert.NotNil(t, err)
assert.Equal(t, ErrKeyNotFound, err)
assert.Equal(t, "0", k.String())
assert.Equal(t, "0", v.String())
}
func TestUpdate(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
_, v, _, err = mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(1024), v)
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
assert.Equal(t, ErrKeyNotFound, err)
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestUpdate2(t *testing.T) {
mt1 := newTestingMerkle(t, 140)
defer mt1.db.Close()
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err := mt1.Add(big.NewInt(1), big.NewInt(119))
assert.Nil(t, err)
err = mt1.Add(big.NewInt(2), big.NewInt(229))
assert.Nil(t, err)
err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(9876), big.NewInt(10))
assert.Nil(t, err)
_, err = mt1.Update(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
_, err = mt1.Update(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
_, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
assert.Nil(t, err)
assert.Equal(t, mt1.Root(), mt2.Root())
}
func TestGenerateAndVerifyProof128(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
require.Nil(t, err)
defer mt.db.Close()
for i := 0; i < 128; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
assert.Nil(t, err)
assert.Equal(t, "0", v.String())
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
}
func TestTreeLimit(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 5)
require.Nil(t, err)
defer mt.db.Close()
for i := 0; i < 16; i++ {
err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
assert.Nil(t, err)
}
// here the tree is full, should not allow to add more data as reaches the maximum number of levels
err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
assert.NotNil(t, err)
assert.Equal(t, ErrReachedMaxLevel, err)
}
func TestSiblingsFromProof(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
require.Nil(t, err)
defer mt.db.Close()
for i := 0; i < 64; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil {
t.Fatal(err)
}
siblings := SiblingsFromProof(proof)
assert.Equal(t, 6, len(siblings))
assert.Equal(t,
"d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b",
siblings[0].Hex())
assert.Equal(t,
"9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16",
siblings[1].Hex())
assert.Equal(t,
"de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27",
siblings[2].Hex())
assert.Equal(t,
"5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317",
siblings[3].Hex())
assert.Equal(t,
"77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122",
siblings[4].Hex())
assert.Equal(t,
"943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07",
siblings[5].Hex())
}
func TestVerifyProofCases(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.DB().Close()
for i := 0; i < 8; i++ {
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
t.Fatal(err)
}
}
// Existence proof
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
for i := 8; i < 32; i++ {
proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
assert.Nil(t, err)
if debug {
fmt.Println(i, proof)
}
}
// Non-existence proof, empty aux
proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, false)
// assert.True(t, proof.nodeAux == nil)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
// Non-existence proof, diff. node aux
proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, false)
assert.True(t, proof.NodeAux != nil)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
}
func TestVerifyProofFalse(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.DB().Close()
for i := 0; i < 8; i++ {
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
t.Fatal(err)
}
}
// Invalid existence proof (node used for verification doesn't
// correspond to node in the proof)
proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
// Invalid non-existence proof (Non-existence proof, diff. node aux)
proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, proof.Existence, true)
// Now we change the proof from existence to non-existence, and add e's
// data as auxiliary node.
proof.Existence = false
proof.NodeAux = &NodeAux{Key: NewHashFromBigInt(big.NewInt(int64(4))),
Value: NewHashFromBigInt(big.NewInt(4))}
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
}
func TestGraphViz(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
_ = mt.Add(big.NewInt(1), big.NewInt(0))
_ = mt.Add(big.NewInt(2), big.NewInt(0))
_ = mt.Add(big.NewInt(3), big.NewInt(0))
_ = mt.Add(big.NewInt(4), big.NewInt(0))
_ = mt.Add(big.NewInt(5), big.NewInt(0))
_ = mt.Add(big.NewInt(100), big.NewInt(0))
// mt.PrintGraphViz(nil)
expected := `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box]
"56332309..." -> {"18483622..." "20902180..."}
"18483622..." -> {"75768243..." "16893244..."}
"75768243..." -> {"empty0" "21857056..."}
"empty0" [style=dashed,label=0];
"21857056..." -> {"51072523..." "empty1"}
"empty1" [style=dashed,label=0];
"51072523..." -> {"17311038..." "empty2"}
"empty2" [style=dashed,label=0];
"17311038..." -> {"69499803..." "21008290..."}
"69499803..." [style=filled];
"21008290..." [style=filled];
"16893244..." [style=filled];
"20902180..." -> {"12496585..." "18055627..."}
"12496585..." -> {"19374975..." "15739329..."}
"19374975..." [style=filled];
"15739329..." [style=filled];
"18055627..." [style=filled];
}
`
w := bytes.NewBufferString("")
err = mt.GraphViz(w, nil)
assert.Nil(t, err)
assert.Equal(t, []byte(expected), w.Bytes())
}
func TestDelete(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
// mt.PrintGraphViz(nil)
err = mt.Delete(big.NewInt(33))
// mt.PrintGraphViz(nil)
assert.Nil(t, err)
assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1234))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestDelete2(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
expectedRoot := mt.Root()
k := big.NewInt(8)
v := big.NewInt(0)
err := mt.Add(k, v)
require.Nil(t, err)
err = mt.Delete(big.NewInt(8))
assert.Nil(t, err)
assert.Equal(t, expectedRoot, mt.Root())
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
for i := 0; i < 8; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(0)
if err := mt2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete3(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete4(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
err := mt.Add(big.NewInt(1), big.NewInt(1))
assert.Nil(t, err)
err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err)
err = mt2.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDelete5(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close()
err = mt2.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, mt2.Root(), mt.Root())
}
func TestDeleteNonExistingKeys(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(33))
assert.Nil(t, err)
err = mt.Delete(big.NewInt(33))
assert.Equal(t, ErrKeyNotFound, err)
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
err = mt.Delete(big.NewInt(33))
assert.Equal(t, ErrKeyNotFound, err)
}
func TestDumpLeafsImportLeafs(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
require.Nil(t, err)
defer mt.db.Close()
q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
for i := 0; i < 10; i++ {
// use numbers near under Q
k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
v := big.NewInt(0)
err = mt.Add(k, v)
require.Nil(t, err)
// use numbers near above 0
k = big.NewInt(int64(i))
err = mt.Add(k, v)
require.Nil(t, err)
}
d, err := mt.DumpLeafs(nil)
assert.Nil(t, err)
mt2, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
require.Nil(t, err)
defer mt2.db.Close()
err = mt2.ImportDumpedLeafs(d)
assert.Nil(t, err)
assert.Equal(t, mt.Root(), mt2.Root())
}
func TestAddAndGetCircomProof(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "0", cpp.OldRoot.String())
assert.Equal(t, "13578938...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "1", cpp.NewKey.String())
assert.Equal(t, "2", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "13578938...", cpp.OldRoot.String())
assert.Equal(t, "54123936...", cpp.NewRoot.String())
assert.Equal(t, "1", cpp.OldKey.String())
assert.Equal(t, "2", cpp.OldValue.String())
assert.Equal(t, "33", cpp.NewKey.String())
assert.Equal(t, "44", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
assert.Nil(t, err)
assert.Equal(t, "54123936...", cpp.OldRoot.String())
assert.Equal(t, "50943640...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "55", cpp.NewKey.String())
assert.Equal(t, "66", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
}
func TestUpdateCircomProcessorProof(t *testing.T) {
mt := newTestingMerkle(t, 10)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
assert.Equal(t, "39010880...", cpp.OldRoot.String())
assert.Equal(t, "18587862...", cpp.NewRoot.String())
assert.Equal(t, "10", cpp.OldKey.String())
assert.Equal(t, "20", cpp.OldValue.String())
assert.Equal(t, "10", cpp.NewKey.String())
assert.Equal(t, "1024", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t,
"[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]",
fmt.Sprintf("%v", cpp.Siblings))
}
func TestSmtVerifier(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 4)
assert.Nil(t, err)
err = mt.Add(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err)
cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
assert.Nil(t, err)
jCvp, err := json.Marshal(cvp)
assert.Nil(t, err)
// expect siblings to be '[]', instead of 'null'
expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
err = mt.Add(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err)
err = mt.Add(big.NewInt(3), big.NewInt(33))
assert.Nil(t, err)
err = mt.Add(big.NewInt(4), big.NewInt(44))
assert.Nil(t, err)
cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
assert.Nil(t, err)
jCvp, err = json.Marshal(cvp)
assert.Nil(t, err)
// Test vectors generated using https://github.com/iden3/circomlib smt.js
// Expect siblings with the extra 0 that the circom circuits need
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
assert.Nil(t, err)
jCvp, err = json.Marshal(cvp)
assert.Nil(t, err)
// Test vectors generated using https://github.com/iden3/circomlib smt.js
// Without the extra 0 that the circom circuits need, but that are not
// needed at a smart contract verification
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
assert.Equal(t, expected, string(jCvp))
}
func TestTypesMarshalers(t *testing.T) {
// test Hash marshalers
h, err := NewHashFromString("42")
assert.Nil(t, err)
s, err := json.Marshal(h)
assert.Nil(t, err)
var h2 *Hash
err = json.Unmarshal(s, &h2)
assert.Nil(t, err)
assert.Equal(t, h, h2)
// create CircomProcessorProof
mt := newTestingMerkle(t, 10)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
// test CircomProcessorProof marshalers
b, err := json.Marshal(&cpp)
assert.Nil(t, err)
var cpp2 *CircomProcessorProof
err = json.Unmarshal(b, &cpp2)
assert.Nil(t, err)
assert.Equal(t, cpp, cpp2)
}

+ 165
- 0
proof.go

@ -0,0 +1,165 @@
package merkletree
import (
"bytes"
"fmt"
"math/big"
)
// Proof defines the required elements for a MT proof of existence or
// non-existence.
type Proof struct {
// existence indicates wether this is a proof of existence or
// non-existence.
Existence bool
// depth indicates how deep in the tree the proof goes.
depth uint
// notempties is a bitmap of non-empty Siblings found in Siblings.
notempties [ElemBytesLen - proofFlagsLen]byte
// Siblings is a list of non-empty sibling keys.
Siblings []*Hash
NodeAux *NodeAux
}
// NewProofFromBytes parses a byte array into a Proof.
func NewProofFromBytes(bs []byte) (*Proof, error) {
if len(bs) < ElemBytesLen {
return nil, ErrInvalidProofBytes
}
p := &Proof{}
if (bs[0] & 0x01) == 0 {
p.Existence = true
}
p.depth = uint(bs[1])
copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
siblingBytes := bs[ElemBytesLen:]
sibIdx := 0
for i := uint(0); i < p.depth; i++ {
if TestBitBigEndian(p.notempties[:], i) {
if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
return nil, ErrInvalidProofBytes
}
var sib Hash
copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
p.Siblings = append(p.Siblings, &sib)
sibIdx++
}
}
if !p.Existence && ((bs[0] & 0x02) != 0) {
p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
if len(nodeAuxBytes) != 2*ElemBytesLen {
return nil, ErrInvalidProofBytes
}
copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
}
return p, nil
}
// Bytes serializes a Proof into a byte array.
func (p *Proof) Bytes() []byte {
bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
if p.NodeAux != nil {
bsLen += 2 * ElemBytesLen //nolint:gomnd
}
bs := make([]byte, bsLen)
if !p.Existence {
bs[0] |= 0x01
}
bs[1] = byte(p.depth)
copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
for i, k := range p.Siblings {
copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
}
if p.NodeAux != nil {
bs[0] |= 0x02
copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
}
return bs
}
// SiblingsFromProof returns all the siblings of the proof.
func SiblingsFromProof(proof *Proof) []*Hash {
sibIdx := 0
siblings := []*Hash{}
for lvl := 0; lvl < int(proof.depth); lvl++ {
if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
siblings = append(siblings, proof.Siblings[sibIdx])
sibIdx++
} else {
siblings = append(siblings, &HashZero)
}
}
return siblings
}
// AllSiblings returns all the siblings of the proof.
func (p *Proof) AllSiblings() []*Hash {
return SiblingsFromProof(p)
}
// VerifyProof verifies the Merkle Proof for the entry and root.
func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
rootFromProof, err := RootFromProof(proof, k, v)
if err != nil {
return false
}
return bytes.Equal(rootKey[:], rootFromProof[:])
}
// RootFromProof calculates the root that would correspond to a tree whose
// siblings are the ones in the proof with the leaf hashing to hIndex and
// hValue.
func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
kHash := NewHashFromBigInt(k)
vHash := NewHashFromBigInt(v)
sibIdx := len(proof.Siblings) - 1
var err error
var midKey *Hash
if proof.Existence {
midKey, err = LeafKey(kHash, vHash)
if err != nil {
return nil, err
}
} else {
if proof.NodeAux == nil {
midKey = &HashZero
} else {
if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
return nil,
fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
}
midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
if err != nil {
return nil, err
}
}
}
path := getPath(int(proof.depth), kHash[:])
var siblingKey *Hash
for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
siblingKey = proof.Siblings[sibIdx]
sibIdx--
} else {
siblingKey = &HashZero
}
if path[lvl] {
midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
if err != nil {
return nil, err
}
} else {
midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
if err != nil {
return nil, err
}
}
}
return midKey, nil
}

+ 68
- 0
utils.go

@ -1,6 +1,9 @@
package merkletree package merkletree
import ( import (
"encoding/binary"
"fmt"
"io"
"math/big" "math/big"
"github.com/iden3/go-iden3-crypto/poseidon" "github.com/iden3/go-iden3-crypto/poseidon"
@ -56,3 +59,68 @@ func SwapEndianness(b []byte) []byte {
} }
return o return o
} }
func checkKVLen(kLen, vLen int) error {
if kLen > 0xff {
return fmt.Errorf("len(k) %d > 0xff", kLen)
}
if vLen > 0xffff {
return fmt.Errorf("len(v) %d > 0xffff", vLen)
}
return nil
}
func serializeKV(w io.Writer, k, v []byte) error {
if err := checkKVLen(len(k), len(v)); err != nil {
return err
}
kH := byte(len(k))
vH := Uint16ToBytes(uint16(len(v)))
_, err := w.Write([]byte{kH})
if err != nil {
return err
}
_, err = w.Write(vH)
if err != nil {
return err
}
_, err = w.Write(k)
if err != nil {
return err
}
_, err = w.Write(v)
if err != nil {
return err
}
return nil
}
func deserializeKV(r io.Reader) ([]byte, []byte, error) {
header := make([]byte, 3)
_, err := io.ReadFull(r, header)
if err != nil {
return nil, nil, err
}
kLen := int(header[0])
vLen := int(BytesToUint16(header[1:]))
kv := make([]byte, kLen+vLen)
_, err = io.ReadFull(r, kv)
if err == io.EOF {
return nil, nil, io.ErrUnexpectedEOF
} else if err != nil {
return nil, nil, err
}
return kv[:kLen], kv[kLen:], nil
}
// Uint16ToBytes returns a byte array from a uint16
func Uint16ToBytes(u uint16) []byte {
var b [2]byte
binary.LittleEndian.PutUint16(b[:], u)
return b[:]
}
// BytesToUint16 returns a uint16 from a byte array
func BytesToUint16(b []byte) uint16 {
return binary.LittleEndian.Uint16(b[:2])
}

Loading…
Cancel
Save