mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 19:46:43 +01:00
WIP. Implementation of Postgres Storage for MerkleTree. Changes in how storage works in general.
This commit is contained in:
75
db/db.go
75
db/db.go
@@ -1,75 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is used by the implementations of the interface db.Storage for
|
||||
// when a key is not found in the storage
|
||||
var ErrNotFound = errors.New("key not found")
|
||||
|
||||
// Storage is the interface that defines the methods for the storage used in
|
||||
// the merkletree. Examples of the interface implementation can be found at
|
||||
// db/memory and db/leveldb directories.
|
||||
type Storage interface {
|
||||
NewTx() (Tx, error)
|
||||
WithPrefix(prefix []byte) Storage
|
||||
Get([]byte) ([]byte, error)
|
||||
List(int) ([]KV, error)
|
||||
Close()
|
||||
Iterate(func([]byte, []byte) (bool, error)) error
|
||||
}
|
||||
|
||||
// Tx is the interface that defines the methods for the db transaction used in
|
||||
// the merkletree storage. Examples of the interface implementation can be
|
||||
// found at db/memory and db/leveldb directories.
|
||||
type Tx interface {
|
||||
// Get retreives the value for the given key
|
||||
// looking first in the content of the Tx, and
|
||||
// then into the content of the Storage
|
||||
Get([]byte) ([]byte, error)
|
||||
// Put sets the key & value into the Tx
|
||||
Put(k, v []byte) error
|
||||
// Add adds the given Tx into the Tx
|
||||
Add(Tx) error
|
||||
Commit() error
|
||||
Close()
|
||||
}
|
||||
|
||||
// KV contains a key (K) and a value (V)
|
||||
type KV struct {
|
||||
K []byte
|
||||
V []byte
|
||||
}
|
||||
|
||||
// KvMap is a key-value map between a sha256 byte array hash, and a KV struct
|
||||
type KvMap map[[sha256.Size]byte]KV
|
||||
|
||||
// Get retreives the value respective to a key from the KvMap
|
||||
func (m KvMap) Get(k []byte) ([]byte, bool) {
|
||||
v, ok := m[sha256.Sum256(k)]
|
||||
return v.V, ok
|
||||
}
|
||||
|
||||
// Put stores a key and a value in the KvMap
|
||||
func (m KvMap) Put(k, v []byte) {
|
||||
m[sha256.Sum256(k)] = KV{k, v}
|
||||
}
|
||||
|
||||
// Concat concatenates arrays of bytes
|
||||
func Concat(vs ...[]byte) []byte {
|
||||
var b bytes.Buffer
|
||||
for _, v := range vs {
|
||||
b.Write(v)
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
// Clone clones a byte array into a new byte array
|
||||
func Clone(b0 []byte) []byte {
|
||||
b1 := make([]byte, len(b0))
|
||||
copy(b1, b0)
|
||||
return b1
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package leveldb
|
||||
|
||||
import (
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/iden3/go-merkletree"
|
||||
"github.com/syndtr/goleveldb/leveldb"
|
||||
"github.com/syndtr/goleveldb/leveldb/errors"
|
||||
"github.com/syndtr/goleveldb/leveldb/opt"
|
||||
@@ -17,7 +17,7 @@ type Storage struct {
|
||||
// StorageTx implements the db.Tx interface
|
||||
type StorageTx struct {
|
||||
*Storage
|
||||
cache db.KvMap
|
||||
cache merkletree.KvMap
|
||||
}
|
||||
|
||||
// NewLevelDbStorage returns a new Storage
|
||||
@@ -33,20 +33,20 @@ func NewLevelDbStorage(path string, errorIfMissing bool) (*Storage, error) {
|
||||
}
|
||||
|
||||
// WithPrefix implements the method WithPrefix of the interface db.Storage
|
||||
func (l *Storage) WithPrefix(prefix []byte) db.Storage {
|
||||
return &Storage{l.ldb, db.Concat(l.prefix, prefix)}
|
||||
func (l *Storage) WithPrefix(prefix []byte) merkletree.Storage {
|
||||
return &Storage{l.ldb, merkletree.Concat(l.prefix, prefix)}
|
||||
}
|
||||
|
||||
// NewTx implements the method NewTx of the interface db.Storage
|
||||
func (l *Storage) NewTx() (db.Tx, error) {
|
||||
return &StorageTx{l, make(db.KvMap)}, nil
|
||||
func (l *Storage) NewTx() (merkletree.Tx, error) {
|
||||
return &StorageTx{l, make(merkletree.KvMap)}, nil
|
||||
}
|
||||
|
||||
// Get retreives a value from a key in the db.Storage
|
||||
// Get retrieves a value from a key in the db.Storage
|
||||
func (l *Storage) Get(key []byte) ([]byte, error) {
|
||||
v, err := l.ldb.Get(db.Concat(l.prefix, key[:]), nil)
|
||||
v, err := l.ldb.Get(merkletree.Concat(l.prefix, key[:]), nil)
|
||||
if err == errors.ErrNotFound {
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (l *Storage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
||||
func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
fullkey := db.Concat(tx.prefix, key)
|
||||
fullkey := merkletree.Concat(tx.prefix, key)
|
||||
|
||||
if value, ok := tx.cache.Get(fullkey); ok {
|
||||
return value, nil
|
||||
@@ -84,7 +84,7 @@ func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
|
||||
value, err := tx.ldb.Get(fullkey, nil)
|
||||
if err == errors.ErrNotFound {
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
|
||||
return value, err
|
||||
@@ -92,12 +92,12 @@ func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
|
||||
// Put saves a key:value into the db.Storage
|
||||
func (tx *StorageTx) Put(k, v []byte) error {
|
||||
tx.cache.Put(db.Concat(tx.prefix, k[:]), v)
|
||||
tx.cache.Put(merkletree.Concat(tx.prefix, k[:]), v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add implements the method Add of the interface db.Tx
|
||||
func (tx *StorageTx) Add(atx db.Tx) error {
|
||||
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
||||
ldbtx := atx.(*StorageTx)
|
||||
for _, v := range ldbtx.cache {
|
||||
tx.cache.Put(v.K, v.V)
|
||||
@@ -134,10 +134,10 @@ func (l *Storage) LevelDB() *leveldb.DB {
|
||||
}
|
||||
|
||||
// List implements the method List of the interface db.Storage
|
||||
func (l *Storage) List(limit int) ([]db.KV, error) {
|
||||
ret := []db.KV{}
|
||||
func (l *Storage) List(limit int) ([]merkletree.KV, error) {
|
||||
ret := []merkletree.KV{}
|
||||
err := l.Iterate(func(key []byte, value []byte) (bool, error) {
|
||||
ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)})
|
||||
ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: merkletree.Clone(value)})
|
||||
if len(ret) == limit {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package leveldb
|
||||
|
||||
import (
|
||||
"github.com/iden3/go-merkletree"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/iden3/go-merkletree/db/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var rmDirs []string
|
||||
|
||||
func levelDbStorage(t *testing.T) db.Storage {
|
||||
func levelDbStorage(t *testing.T) merkletree.Storage {
|
||||
dir, err := ioutil.TempDir("", "db")
|
||||
rmDirs = append(rmDirs, dir)
|
||||
if err != nil {
|
||||
@@ -37,7 +37,7 @@ func TestLevelDb(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLevelDbInterface(t *testing.T) {
|
||||
var db db.Storage //nolint:gosimple
|
||||
var db merkletree.Storage //nolint:gosimple
|
||||
|
||||
dir, err := ioutil.TempDir("", "db")
|
||||
require.Nil(t, err)
|
||||
|
||||
@@ -2,64 +2,72 @@ package memory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/iden3/go-merkletree"
|
||||
"sort"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
)
|
||||
|
||||
// Storage implements the db.Storage interface
|
||||
type Storage struct {
|
||||
prefix []byte
|
||||
kv db.KvMap
|
||||
prefix []byte
|
||||
kv merkletree.KvMap
|
||||
currentRoot *merkletree.Hash
|
||||
}
|
||||
|
||||
// StorageTx implements the db.Tx interface
|
||||
type StorageTx struct {
|
||||
s *Storage
|
||||
kv db.KvMap
|
||||
s *Storage
|
||||
kv merkletree.KvMap
|
||||
currentRoot *merkletree.Hash
|
||||
}
|
||||
|
||||
// NewMemoryStorage returns a new Storage
|
||||
func NewMemoryStorage() *Storage {
|
||||
kvmap := make(db.KvMap)
|
||||
return &Storage{[]byte{}, kvmap}
|
||||
kvmap := make(merkletree.KvMap)
|
||||
return &Storage{[]byte{}, kvmap, nil}
|
||||
}
|
||||
|
||||
// WithPrefix implements the method WithPrefix of the interface db.Storage
|
||||
func (m *Storage) WithPrefix(prefix []byte) db.Storage {
|
||||
return &Storage{db.Concat(m.prefix, prefix), m.kv}
|
||||
func (m *Storage) WithPrefix(prefix []byte) merkletree.Storage {
|
||||
return &Storage{merkletree.Concat(m.prefix, prefix), m.kv, nil}
|
||||
}
|
||||
|
||||
// NewTx implements the method NewTx of the interface db.Storage
|
||||
func (m *Storage) NewTx() (db.Tx, error) {
|
||||
return &StorageTx{m, make(db.KvMap)}, nil
|
||||
func (m *Storage) NewTx() (merkletree.Tx, error) {
|
||||
return &StorageTx{m, make(merkletree.KvMap), nil}, nil
|
||||
}
|
||||
|
||||
// Get retreives a value from a key in the db.Storage
|
||||
func (m *Storage) Get(key []byte) ([]byte, error) {
|
||||
if v, ok := m.kv.Get(db.Concat(m.prefix, key[:])); ok {
|
||||
return v, nil
|
||||
// Get retrieves a value from a key in the db.Storage
|
||||
func (m *Storage) Get(key []byte) (*merkletree.Node, error) {
|
||||
if v, ok := m.kv.Get(merkletree.Concat(m.prefix, key[:])); ok {
|
||||
return &v, nil
|
||||
}
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
|
||||
func (m *Storage) GetRoot() (*merkletree.Hash, error) {
|
||||
if m.currentRoot != nil {
|
||||
return m.currentRoot, nil
|
||||
}
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
|
||||
// Iterate implements the method Iterate of the interface db.Storage
|
||||
func (m *Storage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
||||
kvs := make([]db.KV, 0)
|
||||
func (m *Storage) Iterate(f func([]byte, *merkletree.Node) (bool, error)) error {
|
||||
kvs := make([]merkletree.KV, 0)
|
||||
for _, v := range m.kv {
|
||||
if len(v.K) < len(m.prefix) ||
|
||||
!bytes.Equal(v.K[:len(m.prefix)], m.prefix) {
|
||||
continue
|
||||
}
|
||||
localkey := v.K[len(m.prefix):]
|
||||
kvs = append(kvs, db.KV{K: localkey, V: v.V})
|
||||
kvs = append(kvs, merkletree.KV{K: localkey, V: v.V})
|
||||
}
|
||||
sort.SliceStable(kvs, func(i, j int) bool {
|
||||
return bytes.Compare(kvs[i].K, kvs[j].K) < 0
|
||||
})
|
||||
|
||||
for _, kv := range kvs {
|
||||
if cont, err := f(kv.K, kv.V); err != nil {
|
||||
if cont, err := f(kv.K, &kv.V); err != nil {
|
||||
return err
|
||||
} else if !cont {
|
||||
break
|
||||
@@ -69,20 +77,37 @@ func (m *Storage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
||||
}
|
||||
|
||||
// Get implements the method Get of the interface db.Tx
|
||||
func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
if v, ok := tx.kv.Get(db.Concat(tx.s.prefix, key)); ok {
|
||||
return v, nil
|
||||
func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
|
||||
if v, ok := tx.kv.Get(merkletree.Concat(tx.s.prefix, key)); ok {
|
||||
return &v, nil
|
||||
}
|
||||
if v, ok := tx.s.kv.Get(db.Concat(tx.s.prefix, key)); ok {
|
||||
return v, nil
|
||||
if v, ok := tx.s.kv.Get(merkletree.Concat(tx.s.prefix, key)); ok {
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
|
||||
// Put implements the method Put of the interface db.Tx
|
||||
func (tx *StorageTx) Put(k, v []byte) error {
|
||||
tx.kv.Put(db.Concat(tx.s.prefix, k), v)
|
||||
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
|
||||
tx.kv.Put(merkletree.Concat(tx.s.prefix, k), *v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
|
||||
if tx.currentRoot != nil {
|
||||
hash := merkletree.Hash{}
|
||||
copy(tx.currentRoot[:], hash[:])
|
||||
return &hash, nil
|
||||
}
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
|
||||
// SetRoot sets a hash of merkle tree root in the interface db.Tx
|
||||
func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
|
||||
root := &merkletree.Hash{}
|
||||
copy(root[:], hash[:])
|
||||
tx.currentRoot = root
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,7 +121,7 @@ func (tx *StorageTx) Commit() error {
|
||||
}
|
||||
|
||||
// Add implements the method Add of the interface db.Tx
|
||||
func (tx *StorageTx) Add(atx db.Tx) error {
|
||||
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
||||
mstx := atx.(*StorageTx)
|
||||
for _, v := range mstx.kv {
|
||||
tx.kv.Put(v.K, v.V)
|
||||
@@ -114,10 +139,10 @@ func (m *Storage) Close() {
|
||||
}
|
||||
|
||||
// List implements the method List of the interface db.Storage
|
||||
func (m *Storage) List(limit int) ([]db.KV, error) {
|
||||
ret := []db.KV{}
|
||||
err := m.Iterate(func(key []byte, value []byte) (bool, error) {
|
||||
ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)})
|
||||
func (m *Storage) List(limit int) ([]merkletree.KV, error) {
|
||||
ret := []merkletree.KV{}
|
||||
err := m.Iterate(func(key []byte, value *merkletree.Node) (bool, error) {
|
||||
ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: *value})
|
||||
if len(ret) == limit {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"github.com/iden3/go-merkletree"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/iden3/go-merkletree/db/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryStorageInterface(t *testing.T) {
|
||||
var db db.Storage //nolint:gosimple
|
||||
var db merkletree.Storage //nolint:gosimple
|
||||
|
||||
db = NewMemoryStorage()
|
||||
require.NotNil(t, db)
|
||||
|
||||
@@ -2,7 +2,7 @@ package pebble
|
||||
|
||||
import (
|
||||
"github.com/cockroachdb/pebble"
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/iden3/go-merkletree"
|
||||
)
|
||||
|
||||
// Storage implements the db.Storage interface
|
||||
@@ -30,20 +30,20 @@ func NewPebbleStorage(path string, errorIfMissing bool) (*Storage, error) {
|
||||
}
|
||||
|
||||
// WithPrefix implements the method WithPrefix of the interface db.Storage
|
||||
func (p *Storage) WithPrefix(prefix []byte) db.Storage {
|
||||
return &Storage{p.pdb, db.Concat(p.prefix, prefix)}
|
||||
func (p *Storage) WithPrefix(prefix []byte) merkletree.Storage {
|
||||
return &Storage{p.pdb, merkletree.Concat(p.prefix, prefix)}
|
||||
}
|
||||
|
||||
// NewTx implements the method NewTx of the interface db.Storage
|
||||
func (p *Storage) NewTx() (db.Tx, error) {
|
||||
func (p *Storage) NewTx() (merkletree.Tx, error) {
|
||||
return &StorageTx{p, p.pdb.NewIndexedBatch()}, nil
|
||||
}
|
||||
|
||||
// Get retreives a value from a key in the db.Storage
|
||||
func (p *Storage) Get(key []byte) ([]byte, error) {
|
||||
v, closer, err := p.pdb.Get(db.Concat(p.prefix, key[:]))
|
||||
v, closer, err := p.pdb.Get(merkletree.Concat(p.prefix, key[:]))
|
||||
if err == pebble.ErrNotFound {
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -104,11 +104,11 @@ func (p *Storage) Iterate(f func([]byte, []byte) (bool, error)) (err error) {
|
||||
func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
fullkey := db.Concat(tx.prefix, key)
|
||||
fullkey := merkletree.Concat(tx.prefix, key)
|
||||
|
||||
v, closer, err := tx.batch.Get(fullkey)
|
||||
if err == pebble.ErrNotFound {
|
||||
return nil, db.ErrNotFound
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -119,11 +119,11 @@ func (tx *StorageTx) Get(key []byte) ([]byte, error) {
|
||||
|
||||
// Put saves a key:value into the db.Storage
|
||||
func (tx *StorageTx) Put(k, v []byte) error {
|
||||
return tx.batch.Set(db.Concat(tx.prefix, k[:]), v, nil)
|
||||
return tx.batch.Set(merkletree.Concat(tx.prefix, k[:]), v, nil)
|
||||
}
|
||||
|
||||
// Add implements the method Add of the interface db.Tx
|
||||
func (tx *StorageTx) Add(atx db.Tx) error {
|
||||
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
||||
patx := atx.(*StorageTx)
|
||||
return tx.batch.Apply(patx.batch, nil)
|
||||
}
|
||||
@@ -151,10 +151,10 @@ func (p *Storage) Pebble() *pebble.DB {
|
||||
}
|
||||
|
||||
// List implements the method List of the interface db.Storage
|
||||
func (p *Storage) List(limit int) ([]db.KV, error) {
|
||||
ret := []db.KV{}
|
||||
func (p *Storage) List(limit int) ([]merkletree.KV, error) {
|
||||
ret := []merkletree.KV{}
|
||||
err := p.Iterate(func(key []byte, value []byte) (bool, error) {
|
||||
ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)})
|
||||
ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: merkletree.Clone(value)})
|
||||
if len(ret) == limit {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package pebble
|
||||
|
||||
import (
|
||||
"github.com/iden3/go-merkletree"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/iden3/go-merkletree/db/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var rmDirs []string
|
||||
|
||||
func pebbleStorage(t *testing.T) db.Storage {
|
||||
func pebbleStorage(t *testing.T) merkletree.Storage {
|
||||
dir, err := ioutil.TempDir("", "db")
|
||||
rmDirs = append(rmDirs, dir)
|
||||
if err != nil {
|
||||
@@ -37,7 +37,7 @@ func TestPebble(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPebbleInterface(t *testing.T) {
|
||||
var db db.Storage //nolint:gosimple
|
||||
var db merkletree.Storage //nolint:gosimple
|
||||
|
||||
dir, err := ioutil.TempDir("", "db")
|
||||
require.Nil(t, err)
|
||||
|
||||
312
db/sql/sql.go
Normal file
312
db/sql/sql.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/iden3/go-merkletree"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// TODO: upsert or insert?
|
||||
const upsertStmt = `INSERT INTO mt_nodes (mt_id, key, type, child_l, child_r, entry) VALUES ($1, $2, $3, $4, $5, $6) ` +
|
||||
`ON CONFLICT (mt_id, key) DO UPDATE SET type = $3, child_l = $4, child_r = $5, entry = $6`
|
||||
|
||||
const updateRootStmt = `INSERT INTO mt_roots (mt_id, key) VALUES ($1, $2) ` +
|
||||
`ON CONFLICT (mt_id) DO UPDATE SET key = $2`
|
||||
|
||||
// Storage implements the db.Storage interface
|
||||
type Storage struct {
|
||||
db *sqlx.DB
|
||||
mtId uint64
|
||||
currentVersion uint64
|
||||
currentRoot *merkletree.Hash
|
||||
}
|
||||
|
||||
// StorageTx implements the db.Tx interface
|
||||
type StorageTx struct {
|
||||
*Storage
|
||||
tx *sqlx.Tx
|
||||
cache merkletree.KvMap
|
||||
currentRoot *merkletree.Hash
|
||||
}
|
||||
|
||||
type NodeItem struct {
|
||||
MTId uint64 `db:"mt_id"`
|
||||
Key []byte `db:"key"`
|
||||
// Type is the type of node in the tree.
|
||||
Type byte `db:"type"`
|
||||
// ChildL is the left child of a middle node.
|
||||
ChildL []byte `db:"child_l"`
|
||||
// ChildR is the right child of a middle node.
|
||||
ChildR []byte `db:"child_r"`
|
||||
// Entry is the data stored in a leaf node.
|
||||
Entry []byte `db:"entry"`
|
||||
CreatedAt *uint64 `db:"created_at"`
|
||||
DeletedAt *uint64 `db:"deleted_at"`
|
||||
}
|
||||
|
||||
type RootItem struct {
|
||||
MTId uint64 `db:"mt_id"`
|
||||
Key []byte `db:"key"`
|
||||
CreatedAt *uint64 `db:"created_at"`
|
||||
DeletedAt *uint64 `db:"deleted_at"`
|
||||
}
|
||||
|
||||
// NewSqlStorage returns a new Storage
|
||||
func NewSqlStorage(db *sqlx.DB, errorIfMissing bool) (*Storage, error) {
|
||||
return &Storage{db: db}, nil
|
||||
}
|
||||
|
||||
// WithPrefix implements the method WithPrefix of the interface db.Storage
|
||||
func (s *Storage) WithPrefix(prefix []byte) merkletree.Storage {
|
||||
//return &Storage{db: s.db, prefix: merkletree.Concat(s.prefix, prefix)}
|
||||
// TODO: remove WithPrefix method
|
||||
mtId := s.mtId<<4 | binary.LittleEndian.Uint64(prefix)
|
||||
return &Storage{db: s.db, mtId: mtId}
|
||||
}
|
||||
|
||||
// NewTx implements the method NewTx of the interface db.Storage
|
||||
func (s *Storage) NewTx() (merkletree.Tx, error) {
|
||||
tx, err := s.db.Beginx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from a key in the db.Storage
|
||||
func (s *Storage) Get(key []byte) (*merkletree.Node, error) {
|
||||
item := NodeItem{}
|
||||
err := s.db.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", s.mtId, key)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node, err := item.Node()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// GetRoot retrieves a merkle tree root hash in the interface db.Tx
|
||||
func (s *Storage) GetRoot() (*merkletree.Hash, error) {
|
||||
var root merkletree.Hash
|
||||
|
||||
if s.currentRoot != nil {
|
||||
copy(root[:], s.currentRoot[:])
|
||||
return &root, nil
|
||||
}
|
||||
|
||||
item := RootItem{}
|
||||
err := s.db.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", s.mtId)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(root[:], item.Key[:])
|
||||
return &root, nil
|
||||
}
|
||||
|
||||
// Iterate implements the method Iterate of the interface db.Storage
|
||||
func (s *Storage) Iterate(f func([]byte, *merkletree.Node) (bool, error)) error {
|
||||
items := []NodeItem{}
|
||||
|
||||
err := s.db.Select(&items, "SELECT * FROM mt_nodes WHERE key WHERE mt_id = $1", s.mtId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range items {
|
||||
k := v.Key[:]
|
||||
n, err := v.Node()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cont, err := f(k, n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !cont {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from a key in the interface db.Tx
|
||||
func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
|
||||
//fullKey := append(tx.mtId, key...)
|
||||
fullKey := key
|
||||
if value, ok := tx.cache.Get(fullKey); ok {
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
item := NodeItem{}
|
||||
err := tx.tx.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", tx.mtId, key)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node, err := item.Node()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// Put saves a key:value into the db.Storage
|
||||
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
|
||||
//fullKey := append(tx.mtId, k...)
|
||||
fullKey := k
|
||||
tx.cache.Put(fullKey, *v)
|
||||
fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRoot retrieves a merkle tree root hash in the interface db.Tx
|
||||
func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
|
||||
var root merkletree.Hash
|
||||
|
||||
if tx.currentRoot != nil {
|
||||
copy(root[:], tx.currentRoot[:])
|
||||
return &root, nil
|
||||
}
|
||||
|
||||
item := RootItem{}
|
||||
err := tx.tx.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", tx.mtId)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, merkletree.ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(root[:], item.Key[:])
|
||||
return &root, nil
|
||||
}
|
||||
|
||||
// SetRoot sets a hash of merkle tree root in the interface db.Tx
|
||||
func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
|
||||
root := &merkletree.Hash{}
|
||||
copy(root[:], hash[:])
|
||||
tx.currentRoot = root
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add implements the method Add of the interface db.Tx
|
||||
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
||||
dbtx := atx.(*StorageTx)
|
||||
//if !bytes.Equal(tx.prefix, dbtx.prefix) {
|
||||
// // TODO: change cache to store prefix too!
|
||||
// return errors.New("adding StorageTx with different prefix is not implemented")
|
||||
//}
|
||||
if tx.mtId != dbtx.mtId {
|
||||
// TODO: change cache to store prefix too!
|
||||
return errors.New("adding StorageTx with different prefix is not implemented")
|
||||
}
|
||||
for _, v := range dbtx.cache {
|
||||
tx.cache.Put(v.K, v.V)
|
||||
}
|
||||
tx.currentRoot = dbtx.currentRoot
|
||||
return nil
|
||||
}
|
||||
|
||||
// Commit implements the method Commit of the interface db.Tx
|
||||
func (tx *StorageTx) Commit() error {
|
||||
// execute a query on the server
|
||||
fmt.Printf("Commit\n")
|
||||
for _, v := range tx.cache {
|
||||
fmt.Printf("key %x, value %+v\n", v.K, v.V)
|
||||
node := v.V
|
||||
|
||||
var childL []byte
|
||||
if node.ChildL != nil {
|
||||
childL = append(childL, node.ChildL[:]...)
|
||||
}
|
||||
|
||||
var childR []byte
|
||||
if node.ChildR != nil {
|
||||
childR = append(childR, node.ChildR[:]...)
|
||||
}
|
||||
|
||||
var entry []byte
|
||||
if node.Entry[0] != nil && node.Entry[1] != nil {
|
||||
entry = append(node.Entry[0][:], node.Entry[1][:]...)
|
||||
}
|
||||
|
||||
key, err := node.Key()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.tx.Exec(updateRootStmt, tx.mtId, tx.currentRoot[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.cache = nil
|
||||
return tx.tx.Commit()
|
||||
}
|
||||
|
||||
// Close implements the method Close of the interface db.Tx
|
||||
func (tx *StorageTx) Close() {
|
||||
//tx.tx.Rollback()
|
||||
tx.cache = nil
|
||||
}
|
||||
|
||||
// Close implements the method Close of the interface db.Storage
|
||||
func (s *Storage) Close() {
|
||||
err := s.db.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// List implements the method List of the interface db.Storage
|
||||
func (s *Storage) List(limit int) ([]merkletree.KV, error) {
|
||||
ret := []merkletree.KV{}
|
||||
err := s.Iterate(func(key []byte, value *merkletree.Node) (bool, error) {
|
||||
ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: *value})
|
||||
if len(ret) == limit {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (item *NodeItem) Node() (*merkletree.Node, error) {
|
||||
node := merkletree.Node{
|
||||
Type: merkletree.NodeType(item.Type),
|
||||
}
|
||||
if item.ChildL != nil {
|
||||
node.ChildL = &merkletree.Hash{}
|
||||
copy(node.ChildL[:], item.ChildL[:])
|
||||
}
|
||||
if item.ChildR != nil {
|
||||
node.ChildR = &merkletree.Hash{}
|
||||
copy(node.ChildR[:], item.ChildR[:])
|
||||
}
|
||||
if len(item.Entry) > 0 {
|
||||
if len(item.Entry) != 2*merkletree.ElemBytesLen {
|
||||
return nil, merkletree.ErrNodeBytesBadSize
|
||||
}
|
||||
node.Entry = [2]*merkletree.Hash{{}, {}}
|
||||
copy(node.Entry[0][:], item.Entry[0:32])
|
||||
copy(node.Entry[1][:], item.Entry[32:64])
|
||||
}
|
||||
return &node, nil
|
||||
}
|
||||
816
db/sql/sql_test.go
Normal file
816
db/sql/sql_test.go
Normal file
@@ -0,0 +1,816 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/iden3/go-iden3-crypto/constants"
|
||||
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
|
||||
"github.com/iden3/go-merkletree"
|
||||
"github.com/iden3/go-merkletree/db/memory"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/big"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
host := os.Getenv("PGHOST")
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
port, _ := strconv.Atoi(os.Getenv("PGPORT"))
|
||||
if port == 0 {
|
||||
port = 5432
|
||||
}
|
||||
user := os.Getenv("PGUSER")
|
||||
if user == "" {
|
||||
user = "user"
|
||||
}
|
||||
password := os.Getenv("PGPASSWORD")
|
||||
if password == "" {
|
||||
panic("No PGPASSWORD envvar specified")
|
||||
}
|
||||
dbname := os.Getenv("PGDATABASE")
|
||||
if dbname == "" {
|
||||
dbname = "test"
|
||||
}
|
||||
|
||||
psqlconn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
host,
|
||||
port,
|
||||
user,
|
||||
password,
|
||||
dbname,
|
||||
)
|
||||
dbx, err := sqlx.Connect("postgres", psqlconn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// clear MerkleTree table
|
||||
dbx.Exec("TRUNCATE TABLE mt_roots")
|
||||
dbx.Exec("TRUNCATE TABLE mt_nodes")
|
||||
|
||||
sto, err := NewSqlStorage(dbx, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
})
|
||||
|
||||
return sto
|
||||
}
|
||||
|
||||
func TestSql(t *testing.T) {
|
||||
//sto := sqlStorage(t)
|
||||
//t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
|
||||
// test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
|
||||
//})
|
||||
//t.Run("TestStorageInsertGet", func(t *testing.T) {
|
||||
// test.TestStorageInsertGet(t, sqlStorage(t))
|
||||
//})
|
||||
//test.TestStorageWithPrefix(t, sqlStorage(t))
|
||||
//test.TestConcatTx(t, sqlStorage(t))
|
||||
//test.TestList(t, sqlStorage(t))
|
||||
//test.TestIterate(t, sqlStorage(t))
|
||||
}
|
||||
|
||||
var debug = false
|
||||
|
||||
type Fatalable interface {
|
||||
Fatal(args ...interface{})
|
||||
}
|
||||
|
||||
func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
|
||||
sto := sqlStorage(f)
|
||||
|
||||
mt, err := merkletree.NewMerkleTree(sto, maxLevels)
|
||||
if err != nil {
|
||||
f.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
return mt
|
||||
}
|
||||
|
||||
func TestHashParsers(t *testing.T) {
|
||||
h0 := merkletree.NewHashFromBigInt(big.NewInt(0))
|
||||
assert.Equal(t, "0", h0.String())
|
||||
h1 := merkletree.NewHashFromBigInt(big.NewInt(1))
|
||||
assert.Equal(t, "1", h1.String())
|
||||
h10 := merkletree.NewHashFromBigInt(big.NewInt(10))
|
||||
assert.Equal(t, "10", h10.String())
|
||||
|
||||
h7l := merkletree.NewHashFromBigInt(big.NewInt(1234567))
|
||||
assert.Equal(t, "1234567", h7l.String())
|
||||
h8l := merkletree.NewHashFromBigInt(big.NewInt(12345678))
|
||||
assert.Equal(t, "12345678...", h8l.String())
|
||||
|
||||
b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll
|
||||
assert.True(t, ok)
|
||||
h := merkletree.NewHashFromBigInt(b)
|
||||
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll
|
||||
assert.Equal(t, "49322979...", h.String())
|
||||
assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
|
||||
|
||||
b1, err := merkletree.NewBigIntFromHashBytes(b.Bytes())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
|
||||
|
||||
b2, err := merkletree.NewHashFromBytes(b.Bytes())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, b.String(), b2.BigInt().String())
|
||||
|
||||
h2, err := merkletree.NewHashFromHex(h.Hex())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h, h2)
|
||||
_, err = merkletree.NewHashFromHex("0x12")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// check limits
|
||||
a := new(big.Int).Sub(constants.Q, big.NewInt(1))
|
||||
testHashParsers(t, a)
|
||||
a = big.NewInt(int64(1))
|
||||
testHashParsers(t, a)
|
||||
}
|
||||
|
||||
func testHashParsers(t *testing.T, a *big.Int) {
|
||||
require.True(t, cryptoUtils.CheckBigIntInField(a))
|
||||
h := merkletree.NewHashFromBigInt(a)
|
||||
assert.Equal(t, a, h.BigInt())
|
||||
hFromBytes, err := merkletree.NewHashFromBytes(h.Bytes())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h, hFromBytes)
|
||||
assert.Equal(t, a, hFromBytes.BigInt())
|
||||
assert.Equal(t, a.String(), hFromBytes.BigInt().String())
|
||||
hFromHex, err := merkletree.NewHashFromHex(h.Hex())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h, hFromHex)
|
||||
|
||||
aBIFromHBytes, err := merkletree.NewBigIntFromHashBytes(h.Bytes())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, a, aBIFromHBytes)
|
||||
assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
|
||||
}
|
||||
|
||||
func TestNewTree(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
mt, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 10)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
err = mt.Add(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(44), v)
|
||||
|
||||
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
|
||||
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
|
||||
}
|
||||
|
||||
func TestAddDifferentOrder(t *testing.T) {
|
||||
mt1 := newTestingMerkle(t, 140)
|
||||
for i := 0; i < 16; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt1.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
for i := 16 - 1; i >= 0; i-- {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt2.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
|
||||
assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
|
||||
}
|
||||
|
||||
func TestAddRepeatedIndex(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
k := big.NewInt(int64(3))
|
||||
v := big.NewInt(int64(12))
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := mt.Add(k, v)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, merkletree.ErrEntryIndexAlreadyExists, err)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(int64(i * 2))
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
k, v, _, err := mt.Get(big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(10), k)
|
||||
assert.Equal(t, big.NewInt(20), v)
|
||||
|
||||
k, v, _, err = mt.Get(big.NewInt(15))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(15), k)
|
||||
assert.Equal(t, big.NewInt(30), v)
|
||||
|
||||
k, v, _, err = mt.Get(big.NewInt(16))
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, merkletree.ErrKeyNotFound, err)
|
||||
assert.Equal(t, "0", k.String())
|
||||
assert.Equal(t, "0", v.String())
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(int64(i * 2))
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
_, v, _, err := mt.Get(big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(20), v)
|
||||
|
||||
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
|
||||
assert.Nil(t, err)
|
||||
_, v, _, err = mt.Get(big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(1024), v)
|
||||
|
||||
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
|
||||
assert.Equal(t, merkletree.ErrKeyNotFound, err)
|
||||
|
||||
}
|
||||
|
||||
func TestUpdate2(t *testing.T) {
|
||||
mt1 := newTestingMerkle(t, 140)
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
|
||||
err := mt1.Add(big.NewInt(1), big.NewInt(119))
|
||||
assert.Nil(t, err)
|
||||
err = mt1.Add(big.NewInt(2), big.NewInt(229))
|
||||
assert.Nil(t, err)
|
||||
err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mt2.Add(big.NewInt(1), big.NewInt(11))
|
||||
assert.Nil(t, err)
|
||||
err = mt2.Add(big.NewInt(2), big.NewInt(22))
|
||||
assert.Nil(t, err)
|
||||
err = mt2.Add(big.NewInt(9876), big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = mt1.Update(big.NewInt(1), big.NewInt(11))
|
||||
assert.Nil(t, err)
|
||||
_, err = mt1.Update(big.NewInt(2), big.NewInt(22))
|
||||
assert.Nil(t, err)
|
||||
_, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, mt1.Root(), mt2.Root())
|
||||
}
|
||||
|
||||
func TestGenerateAndVerifyProof128(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
for i := 0; i < 128; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", v.String())
|
||||
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
|
||||
}
|
||||
|
||||
func TestTreeLimit(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 5)
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
err := mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
// here the tree is full, should not allow to add more data as reaches the maximum number of levels
|
||||
err := mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, merkletree.ErrReachedMaxLevel, err)
|
||||
}
|
||||
|
||||
func TestSiblingsFromProof(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
for i := 0; i < 64; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
siblings := merkletree.SiblingsFromProof(proof)
|
||||
assert.Equal(t, 6, len(siblings))
|
||||
assert.Equal(t,
|
||||
"d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b",
|
||||
siblings[0].Hex())
|
||||
assert.Equal(t,
|
||||
"9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16",
|
||||
siblings[1].Hex())
|
||||
assert.Equal(t,
|
||||
"de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27",
|
||||
siblings[2].Hex())
|
||||
assert.Equal(t,
|
||||
"5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317",
|
||||
siblings[3].Hex())
|
||||
assert.Equal(t,
|
||||
"77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122",
|
||||
siblings[4].Hex())
|
||||
assert.Equal(t,
|
||||
"943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07",
|
||||
siblings[5].Hex())
|
||||
}
|
||||
|
||||
func TestVerifyProofCases(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
defer mt.DB().Close()
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Existence proof
|
||||
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, proof.Existence, true)
|
||||
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
|
||||
assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
|
||||
|
||||
for i := 8; i < 32; i++ {
|
||||
proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
|
||||
assert.Nil(t, err)
|
||||
if debug {
|
||||
fmt.Println(i, proof)
|
||||
}
|
||||
}
|
||||
// Non-existence proof, empty aux
|
||||
proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, proof.Existence, false)
|
||||
// assert.True(t, proof.nodeAux == nil)
|
||||
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
|
||||
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
|
||||
|
||||
// Non-existence proof, diff. node aux
|
||||
proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, proof.Existence, false)
|
||||
assert.True(t, proof.NodeAux != nil)
|
||||
assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
|
||||
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
|
||||
}
|
||||
|
||||
func TestVerifyProofFalse(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
defer mt.DB().Close()
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Invalid existence proof (node used for verification doesn't
|
||||
// correspond to node in the proof)
|
||||
proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, proof.Existence, true)
|
||||
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
|
||||
|
||||
// Invalid non-existence proof (Non-existence proof, diff. node aux)
|
||||
proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, proof.Existence, true)
|
||||
// Now we change the proof from existence to non-existence, and add e's
|
||||
// data as auxiliary node.
|
||||
proof.Existence = false
|
||||
proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))),
|
||||
Value: merkletree.NewHashFromBigInt(big.NewInt(4))}
|
||||
assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
|
||||
}
|
||||
|
||||
func TestGraphViz(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
_ = mt.Add(big.NewInt(1), big.NewInt(0))
|
||||
_ = mt.Add(big.NewInt(2), big.NewInt(0))
|
||||
_ = mt.Add(big.NewInt(3), big.NewInt(0))
|
||||
_ = mt.Add(big.NewInt(4), big.NewInt(0))
|
||||
_ = mt.Add(big.NewInt(5), big.NewInt(0))
|
||||
_ = mt.Add(big.NewInt(100), big.NewInt(0))
|
||||
|
||||
// mt.PrintGraphViz(nil)
|
||||
|
||||
expected := `digraph hierarchy {
|
||||
node [fontname=Monospace,fontsize=10,shape=box]
|
||||
"56332309..." -> {"18483622..." "20902180..."}
|
||||
"18483622..." -> {"75768243..." "16893244..."}
|
||||
"75768243..." -> {"empty0" "21857056..."}
|
||||
"empty0" [style=dashed,label=0];
|
||||
"21857056..." -> {"51072523..." "empty1"}
|
||||
"empty1" [style=dashed,label=0];
|
||||
"51072523..." -> {"17311038..." "empty2"}
|
||||
"empty2" [style=dashed,label=0];
|
||||
"17311038..." -> {"69499803..." "21008290..."}
|
||||
"69499803..." [style=filled];
|
||||
"21008290..." [style=filled];
|
||||
"16893244..." [style=filled];
|
||||
"20902180..." -> {"12496585..." "18055627..."}
|
||||
"12496585..." -> {"19374975..." "15739329..."}
|
||||
"19374975..." [style=filled];
|
||||
"15739329..." [style=filled];
|
||||
"18055627..." [style=filled];
|
||||
}
|
||||
`
|
||||
w := bytes.NewBufferString("")
|
||||
err := mt.GraphViz(w, nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, []byte(expected), w.Bytes())
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
// mt.PrintGraphViz(nil)
|
||||
|
||||
err = mt.Delete(big.NewInt(33))
|
||||
// mt.PrintGraphViz(nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Delete(big.NewInt(1234))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
}
|
||||
|
||||
func TestDelete2(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
for i := 0; i < 8; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
expectedRoot := mt.Root()
|
||||
|
||||
k := big.NewInt(8)
|
||||
v := big.NewInt(0)
|
||||
err := mt.Add(k, v)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = mt.Delete(big.NewInt(8))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectedRoot, mt.Root())
|
||||
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
for i := 0; i < 8; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(0)
|
||||
if err := mt2.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, mt2.Root(), mt.Root())
|
||||
}
|
||||
|
||||
func TestDelete3(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mt.Add(big.NewInt(2), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
err = mt2.Add(big.NewInt(2), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, mt2.Root(), mt.Root())
|
||||
}
|
||||
|
||||
func TestDelete4(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mt.Add(big.NewInt(2), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mt.Add(big.NewInt(3), big.NewInt(3))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
err = mt2.Add(big.NewInt(2), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
err = mt2.Add(big.NewInt(3), big.NewInt(3))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, mt2.Root(), mt.Root())
|
||||
}
|
||||
|
||||
func TestDelete5(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
|
||||
|
||||
mt2 := newTestingMerkle(t, 140)
|
||||
err = mt2.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, mt2.Root(), mt.Root())
|
||||
}
|
||||
|
||||
func TestDeleteNonExistingKeys(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = mt.Delete(big.NewInt(33))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Delete(big.NewInt(33))
|
||||
assert.Equal(t, merkletree.ErrKeyNotFound, err)
|
||||
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
err = mt.Delete(big.NewInt(33))
|
||||
assert.Equal(t, merkletree.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestDumpLeafsImportLeafs(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 140)
|
||||
|
||||
q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
|
||||
for i := 0; i < 10; i++ {
|
||||
// use numbers near under Q
|
||||
k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
|
||||
v := big.NewInt(0)
|
||||
err := mt.Add(k, v)
|
||||
require.Nil(t, err)
|
||||
|
||||
// use numbers near above 0
|
||||
k = big.NewInt(int64(i))
|
||||
err = mt.Add(k, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
d, err := mt.DumpLeafs(nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mt2, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 140)
|
||||
require.Nil(t, err)
|
||||
err = mt2.ImportDumpedLeafs(d)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, mt.Root(), mt2.Root())
|
||||
}
|
||||
|
||||
func TestAddAndGetCircomProof(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", cpp.OldRoot.String())
|
||||
assert.Equal(t, "13578938...", cpp.NewRoot.String())
|
||||
assert.Equal(t, "0", cpp.OldKey.String())
|
||||
assert.Equal(t, "0", cpp.OldValue.String())
|
||||
assert.Equal(t, "1", cpp.NewKey.String())
|
||||
assert.Equal(t, "2", cpp.NewValue.String())
|
||||
assert.Equal(t, true, cpp.IsOld0)
|
||||
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||
|
||||
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "13578938...", cpp.OldRoot.String())
|
||||
assert.Equal(t, "54123936...", cpp.NewRoot.String())
|
||||
assert.Equal(t, "1", cpp.OldKey.String())
|
||||
assert.Equal(t, "2", cpp.OldValue.String())
|
||||
assert.Equal(t, "33", cpp.NewKey.String())
|
||||
assert.Equal(t, "44", cpp.NewValue.String())
|
||||
assert.Equal(t, false, cpp.IsOld0)
|
||||
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||
|
||||
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "54123936...", cpp.OldRoot.String())
|
||||
assert.Equal(t, "50943640...", cpp.NewRoot.String())
|
||||
assert.Equal(t, "0", cpp.OldKey.String())
|
||||
assert.Equal(t, "0", cpp.OldValue.String())
|
||||
assert.Equal(t, "55", cpp.NewKey.String())
|
||||
assert.Equal(t, "66", cpp.NewValue.String())
|
||||
assert.Equal(t, true, cpp.IsOld0)
|
||||
assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||
}
|
||||
|
||||
func TestUpdateCircomProcessorProof(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 10)
|
||||
|
||||
for i := 0; i < 16; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(int64(i * 2))
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
_, v, _, err := mt.Get(big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(20), v)
|
||||
|
||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "39010880...", cpp.OldRoot.String())
|
||||
assert.Equal(t, "18587862...", cpp.NewRoot.String())
|
||||
assert.Equal(t, "10", cpp.OldKey.String())
|
||||
assert.Equal(t, "20", cpp.OldValue.String())
|
||||
assert.Equal(t, "10", cpp.NewKey.String())
|
||||
assert.Equal(t, "1024", cpp.NewValue.String())
|
||||
assert.Equal(t, false, cpp.IsOld0)
|
||||
assert.Equal(t,
|
||||
"[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]",
|
||||
fmt.Sprintf("%v", cpp.Siblings))
|
||||
}
|
||||
|
||||
func TestSmtVerifier(t *testing.T) {
|
||||
mt := newTestingMerkle(t, 4)
|
||||
|
||||
err := mt.Add(big.NewInt(1), big.NewInt(11))
|
||||
assert.Nil(t, err)
|
||||
|
||||
cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
|
||||
assert.Nil(t, err)
|
||||
jCvp, err := json.Marshal(cvp)
|
||||
assert.Nil(t, err)
|
||||
// expect siblings to be '[]', instead of 'null'
|
||||
expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
|
||||
|
||||
assert.Equal(t, expected, string(jCvp))
|
||||
err = mt.Add(big.NewInt(2), big.NewInt(22))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Add(big.NewInt(3), big.NewInt(33))
|
||||
assert.Nil(t, err)
|
||||
err = mt.Add(big.NewInt(4), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
|
||||
cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
jCvp, err = json.Marshal(cvp)
|
||||
assert.Nil(t, err)
|
||||
// Test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
// Expect siblings with the extra 0 that the circom circuits need
|
||||
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
|
||||
assert.Equal(t, expected, string(jCvp))
|
||||
|
||||
cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
jCvp, err = json.Marshal(cvp)
|
||||
assert.Nil(t, err)
|
||||
// Test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
// Without the extra 0 that the circom circuits need, but that are not
|
||||
// needed at a smart contract verification
|
||||
expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
|
||||
assert.Equal(t, expected, string(jCvp))
|
||||
}
|
||||
|
||||
func TestTypesMarshalers(t *testing.T) {
|
||||
// test Hash marshalers
|
||||
h, err := merkletree.NewHashFromString("42")
|
||||
assert.Nil(t, err)
|
||||
s, err := json.Marshal(h)
|
||||
assert.Nil(t, err)
|
||||
var h2 *merkletree.Hash
|
||||
err = json.Unmarshal(s, &h2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h, h2)
|
||||
|
||||
// create CircomProcessorProof
|
||||
mt := newTestingMerkle(t, 10)
|
||||
for i := 0; i < 16; i++ {
|
||||
k := big.NewInt(int64(i))
|
||||
v := big.NewInt(int64(i * 2))
|
||||
if err := mt.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
_, v, _, err := mt.Get(big.NewInt(10))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(20), v)
|
||||
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
|
||||
assert.Nil(t, err)
|
||||
|
||||
// test CircomProcessorProof marshalers
|
||||
b, err := json.Marshal(&cpp)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cpp2 *merkletree.CircomProcessorProof
|
||||
err = json.Unmarshal(b, &cpp2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, cpp, cpp2)
|
||||
}
|
||||
114
db/test/test.go
114
db/test/test.go
@@ -2,49 +2,61 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"github.com/iden3/go-merkletree"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestReturnKnownErrIfNotExists checks that the implementation of the
|
||||
// db.Storage interface returns the expected error in the case that the value
|
||||
// is not found
|
||||
func TestReturnKnownErrIfNotExists(t *testing.T, sto db.Storage) {
|
||||
func TestReturnKnownErrIfNotExists(t *testing.T, sto merkletree.Storage) {
|
||||
k := []byte("key")
|
||||
|
||||
tx, err := sto.NewTx()
|
||||
//defer func() {
|
||||
// tx.Close()
|
||||
// sto.Close()
|
||||
//}()
|
||||
|
||||
assert.Nil(t, err)
|
||||
_, err = tx.Get(k)
|
||||
assert.EqualError(t, err, db.ErrNotFound.Error())
|
||||
assert.EqualError(t, err, merkletree.ErrNotFound.Error())
|
||||
}
|
||||
|
||||
// TestStorageInsertGet checks that the implementation of the db.Storage
|
||||
// interface behaves as expected
|
||||
func TestStorageInsertGet(t *testing.T, sto db.Storage) {
|
||||
func TestStorageInsertGet(t *testing.T, sto merkletree.Storage) {
|
||||
key := []byte("key")
|
||||
value := []byte("data")
|
||||
value := merkletree.Hash{1, 1, 1, 1}
|
||||
|
||||
tx, err := sto.NewTx()
|
||||
//defer func() {
|
||||
// tx.Close()
|
||||
// sto.Close()
|
||||
//}()
|
||||
assert.Nil(t, err)
|
||||
err = tx.Put(key, value)
|
||||
node := merkletree.NewNodeMiddle(&value, &value)
|
||||
err = tx.Put(key, node)
|
||||
assert.Nil(t, err)
|
||||
v, err := tx.Get(key)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, value, v)
|
||||
assert.Equal(t, value, *v.ChildL)
|
||||
assert.Equal(t, value, *v.ChildR)
|
||||
assert.Nil(t, tx.Commit())
|
||||
|
||||
tx, err = sto.NewTx()
|
||||
assert.Nil(t, err)
|
||||
v, err = tx.Get(key)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, value, v)
|
||||
assert.Equal(t, value, *v.ChildL)
|
||||
assert.Equal(t, value, *v.ChildR)
|
||||
}
|
||||
|
||||
// TestStorageWithPrefix checks that the implementation of the db.Storage
|
||||
// interface behaves as expected for the WithPrefix method
|
||||
func TestStorageWithPrefix(t *testing.T, sto db.Storage) {
|
||||
func TestStorageWithPrefix(t *testing.T, sto merkletree.Storage) {
|
||||
k := []byte{9}
|
||||
|
||||
sto1 := sto.WithPrefix([]byte{1})
|
||||
@@ -54,39 +66,41 @@ func TestStorageWithPrefix(t *testing.T, sto db.Storage) {
|
||||
|
||||
sto1tx, err := sto1.NewTx()
|
||||
assert.Nil(t, err)
|
||||
err = sto1tx.Put(k, []byte{4, 5, 6})
|
||||
node := merkletree.NewNodeLeaf(&merkletree.Hash{1, 2, 3}, &merkletree.Hash{4, 5, 6})
|
||||
err = sto1tx.Put(k, node)
|
||||
assert.Nil(t, err)
|
||||
v1, err := sto1tx.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v1, []byte{4, 5, 6})
|
||||
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1])
|
||||
assert.Nil(t, sto1tx.Commit())
|
||||
|
||||
sto2tx, err := sto2.NewTx()
|
||||
assert.Nil(t, err)
|
||||
err = sto2tx.Put(k, []byte{8, 9})
|
||||
node.Entry[1] = &merkletree.Hash{9, 10}
|
||||
err = sto2tx.Put(k, node)
|
||||
assert.Nil(t, err)
|
||||
v2, err := sto2tx.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v2, []byte{8, 9})
|
||||
assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1])
|
||||
assert.Nil(t, sto2tx.Commit())
|
||||
|
||||
// check outside tx
|
||||
|
||||
v1, err = sto1.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v1, []byte{4, 5, 6})
|
||||
assert.Equal(t, merkletree.Hash{4, 5, 6}, *v1.Entry[1])
|
||||
|
||||
v2, err = sto2.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v2, []byte{8, 9})
|
||||
assert.Equal(t, merkletree.Hash{9, 10}, *v2.Entry[1])
|
||||
}
|
||||
|
||||
// TestIterate checks that the implementation of the db.Storage interface
|
||||
// behaves as expected for the Iterate method
|
||||
func TestIterate(t *testing.T, sto db.Storage) {
|
||||
r := []db.KV{}
|
||||
lister := func(k []byte, v []byte) (bool, error) {
|
||||
r = append(r, db.KV{K: db.Clone(k), V: db.Clone(v)})
|
||||
func TestIterate(t *testing.T, sto merkletree.Storage) {
|
||||
r := []merkletree.KV{}
|
||||
lister := func(k []byte, v *merkletree.Node) (bool, error) {
|
||||
r = append(r, merkletree.KV{K: merkletree.Clone(k), V: *v})
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -96,44 +110,44 @@ func TestIterate(t *testing.T, sto db.Storage) {
|
||||
assert.Equal(t, 0, len(r))
|
||||
|
||||
sto1tx, _ := sto1.NewTx()
|
||||
err = sto1tx.Put([]byte{1}, []byte{4})
|
||||
err = sto1tx.Put([]byte{1}, merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5}))
|
||||
assert.Nil(t, err)
|
||||
err = sto1tx.Put([]byte{2}, []byte{5})
|
||||
err = sto1tx.Put([]byte{2}, merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6}))
|
||||
assert.Nil(t, err)
|
||||
err = sto1tx.Put([]byte{3}, []byte{6})
|
||||
err = sto1tx.Put([]byte{3}, merkletree.NewNodeMiddle(&merkletree.Hash{6}, &merkletree.Hash{7}))
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, sto1tx.Commit())
|
||||
|
||||
sto2 := sto.WithPrefix([]byte{2})
|
||||
sto2tx, _ := sto2.NewTx()
|
||||
err = sto2tx.Put([]byte{1}, []byte{7})
|
||||
err = sto2tx.Put([]byte{1}, merkletree.NewNodeMiddle(&merkletree.Hash{7}, &merkletree.Hash{8}))
|
||||
assert.Nil(t, err)
|
||||
err = sto2tx.Put([]byte{2}, []byte{8})
|
||||
err = sto2tx.Put([]byte{2}, merkletree.NewNodeMiddle(&merkletree.Hash{8}, &merkletree.Hash{9}))
|
||||
assert.Nil(t, err)
|
||||
err = sto2tx.Put([]byte{3}, []byte{9})
|
||||
err = sto2tx.Put([]byte{3}, merkletree.NewNodeMiddle(&merkletree.Hash{9}, &merkletree.Hash{10}))
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, sto2tx.Commit())
|
||||
|
||||
r = []db.KV{}
|
||||
r = []merkletree.KV{}
|
||||
err = sto1.Iterate(lister)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(r))
|
||||
assert.Equal(t, db.KV{K: []byte{1}, V: []byte{4}}, r[0])
|
||||
assert.Equal(t, db.KV{K: []byte{2}, V: []byte{5}}, r[1])
|
||||
assert.Equal(t, db.KV{K: []byte{3}, V: []byte{6}}, r[2])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})}, r[0])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})}, r[1])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{3}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{6}, &merkletree.Hash{7})}, r[2])
|
||||
|
||||
r = []db.KV{}
|
||||
r = []merkletree.KV{}
|
||||
err = sto2.Iterate(lister)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(r))
|
||||
assert.Equal(t, db.KV{K: []byte{1}, V: []byte{7}}, r[0])
|
||||
assert.Equal(t, db.KV{K: []byte{2}, V: []byte{8}}, r[1])
|
||||
assert.Equal(t, db.KV{K: []byte{3}, V: []byte{9}}, r[2])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{7}, &merkletree.Hash{8})}, r[0])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{8}, &merkletree.Hash{9})}, r[1])
|
||||
assert.Equal(t, merkletree.KV{K: []byte{3}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{9}, &merkletree.Hash{10})}, r[2])
|
||||
}
|
||||
|
||||
// TestConcatTx checks that the implementation of the db.Storage interface
|
||||
// behaves as expected
|
||||
func TestConcatTx(t *testing.T, sto db.Storage) {
|
||||
func TestConcatTx(t *testing.T, sto merkletree.Storage) {
|
||||
k := []byte{9}
|
||||
|
||||
sto1 := sto.WithPrefix([]byte{1})
|
||||
@@ -145,13 +159,13 @@ func TestConcatTx(t *testing.T, sto db.Storage) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = sto1tx.Put(k, []byte{4, 5, 6})
|
||||
err = sto1tx.Put(k, merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9}))
|
||||
assert.Nil(t, err)
|
||||
sto2tx, err := sto2.NewTx()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = sto2tx.Put(k, []byte{8, 9})
|
||||
err = sto2tx.Put(k, merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11}))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = sto1tx.Add(sto2tx)
|
||||
@@ -162,50 +176,50 @@ func TestConcatTx(t *testing.T, sto db.Storage) {
|
||||
|
||||
v1, err := sto1.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v1, []byte{4, 5, 6})
|
||||
assert.Equal(t, v1, merkletree.NewNodeLeaf(&merkletree.Hash{4, 5, 6}, &merkletree.Hash{7, 8, 9}))
|
||||
|
||||
v2, err := sto2.Get(k)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, v2, []byte{8, 9})
|
||||
assert.Equal(t, v2, merkletree.NewNodeLeaf(&merkletree.Hash{8, 9}, &merkletree.Hash{10, 11}))
|
||||
}
|
||||
|
||||
// TestList checks that the implementation of the db.Storage interface behaves
|
||||
// as expected
|
||||
func TestList(t *testing.T, sto db.Storage) {
|
||||
func TestList(t *testing.T, sto merkletree.Storage) {
|
||||
sto1 := sto.WithPrefix([]byte{1})
|
||||
r1, err := sto1.List(100)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(r1))
|
||||
|
||||
sto1tx, _ := sto1.NewTx()
|
||||
err = sto1tx.Put([]byte{1}, []byte{4})
|
||||
err = sto1tx.Put([]byte{1}, merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5}))
|
||||
assert.Nil(t, err)
|
||||
err = sto1tx.Put([]byte{2}, []byte{5})
|
||||
err = sto1tx.Put([]byte{2}, merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6}))
|
||||
assert.Nil(t, err)
|
||||
err = sto1tx.Put([]byte{3}, []byte{6})
|
||||
err = sto1tx.Put([]byte{3}, merkletree.NewNodeMiddle(&merkletree.Hash{6}, &merkletree.Hash{7}))
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, sto1tx.Commit())
|
||||
|
||||
sto2 := sto.WithPrefix([]byte{2})
|
||||
sto2tx, _ := sto2.NewTx()
|
||||
err = sto2tx.Put([]byte{1}, []byte{7})
|
||||
err = sto2tx.Put([]byte{1}, merkletree.NewNodeMiddle(&merkletree.Hash{7}, &merkletree.Hash{8}))
|
||||
assert.Nil(t, err)
|
||||
err = sto2tx.Put([]byte{2}, []byte{8})
|
||||
err = sto2tx.Put([]byte{2}, merkletree.NewNodeMiddle(&merkletree.Hash{8}, &merkletree.Hash{9}))
|
||||
assert.Nil(t, err)
|
||||
err = sto2tx.Put([]byte{3}, []byte{9})
|
||||
err = sto2tx.Put([]byte{3}, merkletree.NewNodeMiddle(&merkletree.Hash{9}, &merkletree.Hash{10}))
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, sto2tx.Commit())
|
||||
|
||||
r, err := sto1.List(100)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(r))
|
||||
assert.Equal(t, r[0], db.KV{K: []byte{1}, V: []byte{4}})
|
||||
assert.Equal(t, r[1], db.KV{K: []byte{2}, V: []byte{5}})
|
||||
assert.Equal(t, r[2], db.KV{K: []byte{3}, V: []byte{6}})
|
||||
assert.Equal(t, r[0], merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})})
|
||||
assert.Equal(t, r[1], merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})})
|
||||
assert.Equal(t, r[2], merkletree.KV{K: []byte{3}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{6}, &merkletree.Hash{7})})
|
||||
|
||||
r, err = sto1.List(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(r))
|
||||
assert.Equal(t, r[0], db.KV{K: []byte{1}, V: []byte{4}})
|
||||
assert.Equal(t, r[1], db.KV{K: []byte{2}, V: []byte{5}})
|
||||
assert.Equal(t, r[0], merkletree.KV{K: []byte{1}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{4}, &merkletree.Hash{5})})
|
||||
assert.Equal(t, r[1], merkletree.KV{K: []byte{2}, V: *merkletree.NewNodeMiddle(&merkletree.Hash{5}, &merkletree.Hash{6})})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user