mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-08 15:01:29 +01:00
Add Mutex, integrate tx into Tree struct
This commit is contained in:
94
tree.go
94
tree.go
@@ -17,6 +17,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -46,6 +47,8 @@ var (
|
|||||||
|
|
||||||
// Tree defines the struct that implements the MerkleTree functionalities
|
// Tree defines the struct that implements the MerkleTree functionalities
|
||||||
type Tree struct {
|
type Tree struct {
|
||||||
|
sync.RWMutex
|
||||||
|
tx db.Tx
|
||||||
db db.Storage
|
db db.Storage
|
||||||
lastAccess int64 // in unix time
|
lastAccess int64 // in unix time
|
||||||
maxLevels int
|
maxLevels int
|
||||||
@@ -60,7 +63,7 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
|
|||||||
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
|
|
||||||
root, err := t.dbGet(nil, dbKeyRoot)
|
root, err := t.dbGet(dbKeyRoot)
|
||||||
if err == db.ErrNotFound {
|
if err == db.ErrNotFound {
|
||||||
// store new root 0
|
// store new root 0
|
||||||
tx, err := t.db.NewTx()
|
tx, err := t.db.NewTx()
|
||||||
@@ -106,24 +109,28 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
|
|||||||
len(keys), len(values))
|
len(keys), len(values))
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := t.db.NewTx()
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
t.tx, err = t.db.NewTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var indexes []int
|
var indexes []int
|
||||||
for i := 0; i < len(keys); i++ {
|
for i := 0; i < len(keys); i++ {
|
||||||
tx, err = t.add(tx, keys[i], values[i])
|
err = t.add(keys[i], values[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
indexes = append(indexes, i)
|
indexes = append(indexes, i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// store root to db
|
// store root to db
|
||||||
if err := tx.Put(dbKeyRoot, t.root); err != nil {
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return indexes, err
|
return indexes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := t.tx.Commit(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return indexes, nil
|
return indexes, nil
|
||||||
@@ -134,23 +141,28 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
|
|||||||
// compatibility).
|
// compatibility).
|
||||||
func (t *Tree) Add(k, v []byte) error {
|
func (t *Tree) Add(k, v []byte) error {
|
||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
tx, err := t.db.NewTx()
|
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
t.tx, err = t.db.NewTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err = t.add(tx, k, v)
|
err = t.add(k, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// store root to db
|
// store root to db
|
||||||
if err := tx.Put(dbKeyRoot, t.root); err != nil {
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return t.tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
|
func (t *Tree) add(k, v []byte) error {
|
||||||
// TODO check validity of key & value (for the Tree.HashFunction type)
|
// TODO check validity of key & value (for the Tree.HashFunction type)
|
||||||
|
|
||||||
keyPath := make([]byte, t.hashFunction.Len())
|
keyPath := make([]byte, t.hashFunction.Len())
|
||||||
@@ -159,36 +171,36 @@ func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
|
|||||||
path := getPath(t.maxLevels, keyPath)
|
path := getPath(t.maxLevels, keyPath)
|
||||||
// go down to the leaf
|
// go down to the leaf
|
||||||
var siblings [][]byte
|
var siblings [][]byte
|
||||||
_, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false)
|
_, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tx, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
|
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tx, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Put(leafKey, leafValue); err != nil {
|
if err := t.tx.Put(leafKey, leafValue); err != nil {
|
||||||
return tx, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// go up to the root
|
// go up to the root
|
||||||
if len(siblings) == 0 {
|
if len(siblings) == 0 {
|
||||||
t.root = leafKey
|
t.root = leafKey
|
||||||
return tx, nil
|
return nil
|
||||||
}
|
}
|
||||||
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
|
root, err := t.up(leafKey, siblings, path, len(siblings)-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tx, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.root = root
|
t.root = root
|
||||||
return tx, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// down goes down to the leaf recursively
|
// down goes down to the leaf recursively
|
||||||
func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
|
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
|
||||||
path []bool, l int, getLeaf bool) (
|
path []bool, l int, getLeaf bool) (
|
||||||
[]byte, []byte, [][]byte, error) {
|
[]byte, []byte, [][]byte, error) {
|
||||||
if l > t.maxLevels-1 {
|
if l > t.maxLevels-1 {
|
||||||
@@ -201,7 +213,7 @@ func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
|
|||||||
// empty value
|
// empty value
|
||||||
return currKey, emptyValue, siblings, nil
|
return currKey, emptyValue, siblings, nil
|
||||||
}
|
}
|
||||||
currValue, err = t.dbGet(tx, currKey)
|
currValue, err = t.dbGet(currKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -244,12 +256,12 @@ func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
|
|||||||
// right
|
// right
|
||||||
lChild, rChild := readIntermediateChilds(currValue)
|
lChild, rChild := readIntermediateChilds(currValue)
|
||||||
siblings = append(siblings, lChild)
|
siblings = append(siblings, lChild)
|
||||||
return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf)
|
return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
|
||||||
}
|
}
|
||||||
// left
|
// left
|
||||||
lChild, rChild := readIntermediateChilds(currValue)
|
lChild, rChild := readIntermediateChilds(currValue)
|
||||||
siblings = append(siblings, rChild)
|
siblings = append(siblings, rChild)
|
||||||
return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf)
|
return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
|
||||||
default:
|
default:
|
||||||
return nil, nil, nil, fmt.Errorf("invalid value")
|
return nil, nil, nil, fmt.Errorf("invalid value")
|
||||||
}
|
}
|
||||||
@@ -281,7 +293,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// up goes up recursively updating the intermediate nodes
|
// up goes up recursively updating the intermediate nodes
|
||||||
func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
|
func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
|
||||||
var k, v []byte
|
var k, v []byte
|
||||||
var err error
|
var err error
|
||||||
if path[l] {
|
if path[l] {
|
||||||
@@ -296,7 +308,7 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// store k-v to db
|
// store k-v to db
|
||||||
if err = tx.Put(k, v); err != nil {
|
if err = t.tx.Put(k, v); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -305,7 +317,7 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
|
|||||||
return k, nil
|
return k, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.up(tx, k, siblings, path, l-1)
|
return t.up(k, siblings, path, l-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
|
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
|
||||||
@@ -377,7 +389,11 @@ func getPath(numLevels int, k []byte) []bool {
|
|||||||
func (t *Tree) Update(k, v []byte) error {
|
func (t *Tree) Update(k, v []byte) error {
|
||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
|
|
||||||
tx, err := t.db.NewTx()
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
t.tx, err = t.db.NewTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -387,7 +403,7 @@ func (t *Tree) Update(k, v []byte) error {
|
|||||||
path := getPath(t.maxLevels, keyPath)
|
path := getPath(t.maxLevels, keyPath)
|
||||||
|
|
||||||
var siblings [][]byte
|
var siblings [][]byte
|
||||||
_, valueAtBottom, siblings, err := t.down(tx, k, t.root, siblings, path, 0, true)
|
_, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -401,26 +417,26 @@ func (t *Tree) Update(k, v []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Put(leafKey, leafValue); err != nil {
|
if err := t.tx.Put(leafKey, leafValue); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// go up to the root
|
// go up to the root
|
||||||
if len(siblings) == 0 {
|
if len(siblings) == 0 {
|
||||||
t.root = leafKey
|
t.root = leafKey
|
||||||
return tx.Commit()
|
return t.tx.Commit()
|
||||||
}
|
}
|
||||||
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
|
root, err := t.up(leafKey, siblings, path, len(siblings)-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.root = root
|
t.root = root
|
||||||
// store root to db
|
// store root to db
|
||||||
if err := tx.Put(dbKeyRoot, t.root); err != nil {
|
if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return t.tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
||||||
@@ -434,7 +450,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, error) {
|
|||||||
path := getPath(t.maxLevels, keyPath)
|
path := getPath(t.maxLevels, keyPath)
|
||||||
// go down to the leaf
|
// go down to the leaf
|
||||||
var siblings [][]byte
|
var siblings [][]byte
|
||||||
_, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true)
|
_, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -533,7 +549,7 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
|
|||||||
path := getPath(t.maxLevels, keyPath)
|
path := getPath(t.maxLevels, keyPath)
|
||||||
// go down to the leaf
|
// go down to the leaf
|
||||||
var siblings [][]byte
|
var siblings [][]byte
|
||||||
_, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true)
|
_, value, _, err := t.down(k, t.root, siblings, path, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -581,13 +597,13 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) {
|
func (t *Tree) dbGet(k []byte) ([]byte, error) {
|
||||||
v, err := t.db.Get(k)
|
v, err := t.db.Get(k)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
if tx != nil {
|
if t.tx != nil {
|
||||||
return tx.Get(k)
|
return t.tx.Get(k)
|
||||||
}
|
}
|
||||||
return nil, db.ErrNotFound
|
return nil, db.ErrNotFound
|
||||||
}
|
}
|
||||||
@@ -600,7 +616,7 @@ func (t *Tree) Iterate(f func([]byte, []byte)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
|
func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
|
||||||
v, err := t.dbGet(nil, k)
|
v, err := t.dbGet(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
31
tree_test.go
31
tree_test.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
qt "github.com/frankban/quicktest"
|
qt "github.com/frankban/quicktest"
|
||||||
"github.com/iden3/go-merkletree/db/memory"
|
"github.com/iden3/go-merkletree/db/memory"
|
||||||
@@ -194,7 +195,7 @@ func TestUpdate(t *testing.T) {
|
|||||||
c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
|
c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAux(t *testing.T) {
|
func TestAux(t *testing.T) { // TMP
|
||||||
c := qt.New(t)
|
c := qt.New(t)
|
||||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||||
c.Assert(err, qt.IsNil)
|
c.Assert(err, qt.IsNil)
|
||||||
@@ -293,6 +294,34 @@ func TestDumpAndImportDump(t *testing.T) {
|
|||||||
"0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08")
|
"0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRWMutex(t *testing.T) {
|
||||||
|
c := qt.New(t)
|
||||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||||
|
c.Assert(err, qt.IsNil)
|
||||||
|
defer tree.db.Close()
|
||||||
|
|
||||||
|
var keys, values [][]byte
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||||
|
v := BigIntToBytes(big.NewInt(0))
|
||||||
|
keys = append(keys, k)
|
||||||
|
values = append(values, v)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
_, err = tree.AddBatch(keys, values)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
k := BigIntToBytes(big.NewInt(int64(99999)))
|
||||||
|
v := BigIntToBytes(big.NewInt(int64(99999)))
|
||||||
|
if err := tree.Add(k, v); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkAdd(b *testing.B) {
|
func BenchmarkAdd(b *testing.B) {
|
||||||
// prepare inputs
|
// prepare inputs
|
||||||
var ks, vs [][]byte
|
var ks, vs [][]byte
|
||||||
|
|||||||
Reference in New Issue
Block a user