From 75e24244a1e3a8d2c1dbf2cd32a61d0355358b60 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Thu, 23 Jul 2020 22:27:38 +0200 Subject: [PATCH] Small fixes --- README.md | 6 ++++++ db/leveldb/leveldb.go | 24 ++++++++++++------------ db/memory/memory.go | 18 +++++++++--------- merkletree.go | 8 +++++++- merkletree_test.go | 1 + 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 424e881..7dd7952 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,12 @@ func TestExampleMerkleTree(t *testing.T) { assert.Nil(t, err) fmt.Println(mt.Root().String()) + v, err := mt.Get(key) + asseert.Equal(t, value, v) + + value = big.NewInt(3) + err = mt.Update(key, value) + proof, err := mt.GenerateProof(key, nil) assert.Nil(t, err) diff --git a/db/leveldb/leveldb.go b/db/leveldb/leveldb.go index 78e8ec8..a56061c 100644 --- a/db/leveldb/leveldb.go +++ b/db/leveldb/leveldb.go @@ -24,7 +24,7 @@ type LevelDbStorageTx struct { cache db.KvMap } -// NewLevelStorage returns a new LevelDbStorage +// NewLevelDbStorage returns a new LevelDbStorage func NewLevelDbStorage(path string, errorIfMissing bool) (*LevelDbStorage, error) { o := &opt.Options{ ErrorIfMissing: errorIfMissing, @@ -113,16 +113,16 @@ func (l *LevelDbStorage) Iterate(f func([]byte, []byte) (bool, error)) error { } // Get retreives a value from a key in the interface db.Tx -func (l *LevelDbStorageTx) Get(key []byte) ([]byte, error) { +func (tx *LevelDbStorageTx) Get(key []byte) ([]byte, error) { var err error - fullkey := db.Concat(l.prefix, key) + fullkey := db.Concat(tx.prefix, key) - if value, ok := l.cache.Get(fullkey); ok { + if value, ok := tx.cache.Get(fullkey); ok { return value, nil } - value, err := l.ldb.Get(fullkey, nil) + value, err := tx.ldb.Get(fullkey, nil) if err == errors.ErrNotFound { return nil, db.ErrNotFound } @@ -130,7 +130,7 @@ func (l *LevelDbStorageTx) Get(key []byte) ([]byte, error) { return value, err } -// Insert saves a key:value into the db.Storage +// Put saves a key:value into the db.Storage func (tx *LevelDbStorageTx) Put(k, v []byte) { tx.cache.Put(db.Concat(tx.prefix, k[:]), v) } @@ -144,19 +144,19 @@ func (tx *LevelDbStorageTx) Add(atx db.Tx) { } // Commit implements the method Commit of the interface db.Tx -func (l *LevelDbStorageTx) Commit() error { +func (tx *LevelDbStorageTx) Commit() error { var batch leveldb.Batch - for _, v := range l.cache { + for _, v := range tx.cache { batch.Put(v.K, v.V) } - l.cache = nil - return l.ldb.Write(&batch, nil) + tx.cache = nil + return tx.ldb.Write(&batch, nil) } // Close implements the method Close of the interface db.Tx -func (l *LevelDbStorageTx) Close() { - l.cache = nil +func (tx *LevelDbStorageTx) Close() { + tx.cache = nil } // Close implements the method Close of the interface db.Storage diff --git a/db/memory/memory.go b/db/memory/memory.go index 6301b56..76002e2 100644 --- a/db/memory/memory.go +++ b/db/memory/memory.go @@ -26,7 +26,7 @@ func NewMemoryStorage() *MemoryStorage { } // Info implements the method Info of the interface db.Storage -func (l *MemoryStorage) Info() string { +func (m *MemoryStorage) Info() string { return "in-memory" } @@ -41,21 +41,21 @@ func (m *MemoryStorage) NewTx() (db.Tx, error) { } // Get retreives a value from a key in the db.Storage -func (l *MemoryStorage) Get(key []byte) ([]byte, error) { - if v, ok := l.kv.Get(db.Concat(l.prefix, key[:])); ok { +func (m *MemoryStorage) Get(key []byte) ([]byte, error) { + if v, ok := m.kv.Get(db.Concat(m.prefix, key[:])); ok { return v, nil } return nil, db.ErrNotFound } // Iterate implements the method Iterate of the interface db.Storage -func (l *MemoryStorage) Iterate(f func([]byte, []byte) (bool, error)) error { +func (m *MemoryStorage) Iterate(f func([]byte, []byte) (bool, error)) error { kvs := make([]db.KV, 0) - for _, v := range l.kv { - if len(v.K) < len(l.prefix) || !bytes.Equal(v.K[:len(l.prefix)], l.prefix) { + for _, v := range m.kv { + if len(v.K) < len(m.prefix) || !bytes.Equal(v.K[:len(m.prefix)], m.prefix) { continue } - localkey := v.K[len(l.prefix):] + localkey := v.K[len(m.prefix):] kvs = append(kvs, db.KV{K: localkey, V: v.V}) } @@ -116,9 +116,9 @@ func (m *MemoryStorage) Close() { } // List implements the method List of the interface db.Storage -func (l *MemoryStorage) List(limit int) ([]db.KV, error) { +func (m *MemoryStorage) List(limit int) ([]db.KV, error) { ret := []db.KV{} - err := l.Iterate(func(key []byte, value []byte) (bool, error) { + err := m.Iterate(func(key []byte, value []byte) (bool, error) { ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)}) if len(ret) == limit { return false, nil diff --git a/merkletree.go b/merkletree.go index 66b702a..fc3a92a 100644 --- a/merkletree.go +++ b/merkletree.go @@ -139,6 +139,7 @@ func (mt *MerkleTree) Root() *Hash { return mt.rootKey } +// Snapshot returns a read-only copy of the MerkleTree func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) { mt.RLock() defer mt.RUnlock() @@ -383,6 +384,9 @@ func (mt *MerkleTree) Update(k, v *big.Int) error { // update leaf and upload to the root newNodeLeaf := NewNodeLeaf(kHash, vHash) _, err := mt.addNode(tx, newNodeLeaf) + if err != nil { + return err + } newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings) if err != nil { return err @@ -676,7 +680,7 @@ func (p *Proof) Bytes() []byte { return bs } -// SiblingsFromProof returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits. +// SiblingsFromProof returns all the siblings of the proof. func SiblingsFromProof(proof *Proof) []*Hash { sibIdx := 0 var siblings []*Hash @@ -691,10 +695,12 @@ func SiblingsFromProof(proof *Proof) []*Hash { return siblings } +// AllSiblings returns all the siblings of the proof. func (p *Proof) AllSiblings() []*Hash { return SiblingsFromProof(p) } +// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits. func (p *Proof) AllSiblingsCircom(levels int) []*big.Int { siblings := p.AllSiblings() // Add the rest of empty levels to the siblings diff --git a/merkletree_test.go b/merkletree_test.go index 4825de8..7e15f8a 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -152,6 +152,7 @@ func TestUpdate(t *testing.T) { assert.Equal(t, big.NewInt(20), v) err = mt.Update(big.NewInt(10), big.NewInt(1024)) + assert.Nil(t, err) v, err = mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(1024), v)