From c26b23c544d96eb1ed3a3c9d0e3a65c16fd32295 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 12 Apr 2021 21:26:26 +0200 Subject: [PATCH] Add Mutex, integrate tx into Tree struct --- tree.go | 94 ++++++++++++++++++++++++++++++---------------------- tree_test.go | 31 ++++++++++++++++- 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/tree.go b/tree.go index 23d80bc..8194a6b 100644 --- a/tree.go +++ b/tree.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + "sync" "sync/atomic" "time" @@ -46,6 +47,8 @@ var ( // Tree defines the struct that implements the MerkleTree functionalities type Tree struct { + sync.RWMutex + tx db.Tx db db.Storage lastAccess int64 // in unix time maxLevels int @@ -60,7 +63,7 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} t.updateAccessTime() - root, err := t.dbGet(nil, dbKeyRoot) + root, err := t.dbGet(dbKeyRoot) if err == db.ErrNotFound { // store new root 0 tx, err := t.db.NewTx() @@ -106,24 +109,28 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { len(keys), len(values)) } - tx, err := t.db.NewTx() + t.Lock() + defer t.Unlock() + + var err error + t.tx, err = t.db.NewTx() if err != nil { return nil, err } var indexes []int for i := 0; i < len(keys); i++ { - tx, err = t.add(tx, keys[i], values[i]) + err = t.add(keys[i], values[i]) if err != nil { indexes = append(indexes, i) } } // store root to db - if err := tx.Put(dbKeyRoot, t.root); err != nil { + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { return indexes, err } - if err := tx.Commit(); err != nil { + if err := t.tx.Commit(); err != nil { return nil, err } return indexes, nil @@ -134,23 +141,28 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { // compatibility). func (t *Tree) Add(k, v []byte) error { t.updateAccessTime() - tx, err := t.db.NewTx() + + t.Lock() + defer t.Unlock() + + var err error + t.tx, err = t.db.NewTx() if err != nil { return err } - tx, err = t.add(tx, k, v) + err = t.add(k, v) if err != nil { return err } // store root to db - if err := tx.Put(dbKeyRoot, t.root); err != nil { + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { return err } - return tx.Commit() + return t.tx.Commit() } -func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) { +func (t *Tree) add(k, v []byte) error { // TODO check validity of key & value (for the Tree.HashFunction type) keyPath := make([]byte, t.hashFunction.Len()) @@ -159,36 +171,36 @@ func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) { path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false) + _, _, siblings, err := t.down(k, t.root, siblings, path, 0, false) if err != nil { - return tx, err + return err } leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v) if err != nil { - return tx, err + return err } - if err := tx.Put(leafKey, leafValue); err != nil { - return tx, err + if err := t.tx.Put(leafKey, leafValue); err != nil { + return err } // go up to the root if len(siblings) == 0 { t.root = leafKey - return tx, nil + return nil } - root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1) + root, err := t.up(leafKey, siblings, path, len(siblings)-1) if err != nil { - return tx, err + return err } t.root = root - return tx, nil + return nil } // down goes down to the leaf recursively -func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte, +func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) ( []byte, []byte, [][]byte, error) { if l > t.maxLevels-1 { @@ -201,7 +213,7 @@ func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte, // empty value return currKey, emptyValue, siblings, nil } - currValue, err = t.dbGet(tx, currKey) + currValue, err = t.dbGet(currKey) if err != nil { return nil, nil, nil, err } @@ -244,12 +256,12 @@ func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte, // right lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, lChild) - return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf) + return t.down(newKey, rChild, siblings, path, l+1, getLeaf) } // left lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, rChild) - return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf) + return t.down(newKey, lChild, siblings, path, l+1, getLeaf) default: return nil, nil, nil, fmt.Errorf("invalid value") } @@ -281,7 +293,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, } // up goes up recursively updating the intermediate nodes -func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) { +func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) { var k, v []byte var err error if path[l] { @@ -296,7 +308,7 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ( } } // store k-v to db - if err = tx.Put(k, v); err != nil { + if err = t.tx.Put(k, v); err != nil { return nil, err } @@ -305,7 +317,7 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ( return k, nil } - return t.up(tx, k, siblings, path, l-1) + return t.up(k, siblings, path, l-1) } func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { @@ -377,7 +389,11 @@ func getPath(numLevels int, k []byte) []bool { func (t *Tree) Update(k, v []byte) error { t.updateAccessTime() - tx, err := t.db.NewTx() + t.Lock() + defer t.Unlock() + + var err error + t.tx, err = t.db.NewTx() if err != nil { return err } @@ -387,7 +403,7 @@ func (t *Tree) Update(k, v []byte) error { path := getPath(t.maxLevels, keyPath) var siblings [][]byte - _, valueAtBottom, siblings, err := t.down(tx, k, t.root, siblings, path, 0, true) + _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true) if err != nil { return err } @@ -401,26 +417,26 @@ func (t *Tree) Update(k, v []byte) error { return err } - if err := tx.Put(leafKey, leafValue); err != nil { + if err := t.tx.Put(leafKey, leafValue); err != nil { return err } // go up to the root if len(siblings) == 0 { t.root = leafKey - return tx.Commit() + return t.tx.Commit() } - root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1) + root, err := t.up(leafKey, siblings, path, len(siblings)-1) if err != nil { return err } t.root = root // store root to db - if err := tx.Put(dbKeyRoot, t.root); err != nil { + if err := t.tx.Put(dbKeyRoot, t.root); err != nil { return err } - return tx.Commit() + return t.tx.Commit() } // GenProof generates a MerkleTree proof for the given key. If the key exists in @@ -434,7 +450,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, error) { path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true) + _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true) if err != nil { return nil, err } @@ -533,7 +549,7 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true) + _, value, _, err := t.down(k, t.root, siblings, path, 0, true) if err != nil { return nil, nil, err } @@ -581,13 +597,13 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, return false, nil } -func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) { +func (t *Tree) dbGet(k []byte) ([]byte, error) { v, err := t.db.Get(k) if err == nil { return v, nil } - if tx != nil { - return tx.Get(k) + if t.tx != nil { + return t.tx.Get(k) } return nil, db.ErrNotFound } @@ -600,7 +616,7 @@ func (t *Tree) Iterate(f func([]byte, []byte)) error { } func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { - v, err := t.dbGet(nil, k) + v, err := t.dbGet(k) if err != nil { return err } diff --git a/tree_test.go b/tree_test.go index bb67954..6412e72 100644 --- a/tree_test.go +++ b/tree_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "math/big" "testing" + "time" qt "github.com/frankban/quicktest" "github.com/iden3/go-merkletree/db/memory" @@ -194,7 +195,7 @@ func TestUpdate(t *testing.T) { c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11))) } -func TestAux(t *testing.T) { +func TestAux(t *testing.T) { // TMP c := qt.New(t) tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) c.Assert(err, qt.IsNil) @@ -293,6 +294,34 @@ func TestDumpAndImportDump(t *testing.T) { "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08") } +func TestRWMutex(t *testing.T) { + c := qt.New(t) + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() + + var keys, values [][]byte + for i := 0; i < 1000; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(0)) + keys = append(keys, k) + values = append(values, v) + } + go func() { + _, err = tree.AddBatch(keys, values) + if err != nil { + panic(err) + } + }() + + time.Sleep(500 * time.Millisecond) + k := BigIntToBytes(big.NewInt(int64(99999))) + v := BigIntToBytes(big.NewInt(int64(99999))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } +} + func BenchmarkAdd(b *testing.B) { // prepare inputs var ks, vs [][]byte