Browse Source

Add tree.AddBatch (using db.Tx)

master
arnaucube 3 years ago
parent
commit
cf572f628e
2 changed files with 101 additions and 35 deletions
  1. +79
    -33
      tree.go
  2. +22
    -2
      tree_test.go

+ 79
- 33
tree.go

@ -58,7 +58,7 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
t.updateAccessTime() t.updateAccessTime()
root, err := t.db.Get(dbKeyRoot)
root, err := t.dbGet(nil, dbKeyRoot)
if err == db.ErrNotFound { if err == db.ErrNotFound {
// store new root 0 // store new root 0
tx, err := t.db.NewTx() tx, err := t.db.NewTx()
@ -66,12 +66,10 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error
return nil, err return nil, err
} }
t.root = make([]byte, t.hashFunction.Len()) // empty t.root = make([]byte, t.hashFunction.Len()) // empty
err = tx.Put(dbKeyRoot, t.root)
if err != nil {
if err = tx.Put(dbKeyRoot, t.root); err != nil {
return nil, err return nil, err
} }
err = tx.Commit()
if err != nil {
if err = tx.Commit(); err != nil {
return nil, err return nil, err
} }
return &t, err return &t, err
@ -96,17 +94,59 @@ func (t *Tree) Root() []byte {
return t.root return t.root
} }
// AddBatch adds a batch of key-values to the Tree. This method is optimized to
// do some internal parallelization. Returns an array containing the indexes of
// the keys failed to add.
// AddBatch adds a batch of key-values to the Tree. This method will be
// optimized to do some internal parallelization. Returns an array containing
// the indexes of the keys failed to add.
func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
return nil, fmt.Errorf("unimplemented")
if len(keys) != len(values) {
return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
len(keys), len(values))
}
tx, err := t.db.NewTx()
if err != nil {
return nil, err
}
var indexes []int
for i := 0; i < len(keys); i++ {
tx, err = t.add(tx, keys[i], values[i])
if err != nil {
indexes = append(indexes, i)
}
}
// store root to db
if err := tx.Put(dbKeyRoot, t.root); err != nil {
return indexes, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return indexes, nil
} }
// Add inserts the key-value into the Tree.
// If the inputs come from a *big.Int, is expected that are represented by a
// Little-Endian byte array (for circom compatibility).
// Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
// is expected that are represented by a Little-Endian byte array (for circom
// compatibility).
func (t *Tree) Add(k, v []byte) error { func (t *Tree) Add(k, v []byte) error {
tx, err := t.db.NewTx()
if err != nil {
return err
}
tx, err = t.add(tx, k, v)
if err != nil {
return err
}
// store root to db
if err := tx.Put(dbKeyRoot, t.root); err != nil {
return err
}
return tx.Commit()
}
func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
// TODO check validity of key & value (for the Tree.HashFunction type) // TODO check validity of key & value (for the Tree.HashFunction type)
keyPath := make([]byte, t.hashFunction.Len()) keyPath := make([]byte, t.hashFunction.Len())
@ -115,41 +155,37 @@ func (t *Tree) Add(k, v []byte) error {
path := getPath(t.maxLevels, keyPath) path := getPath(t.maxLevels, keyPath)
// go down to the leaf // go down to the leaf
var siblings [][]byte var siblings [][]byte
_, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
_, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false)
if err != nil { if err != nil {
return err
return tx, err
} }
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v) leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
if err != nil { if err != nil {
return err
return tx, err
} }
tx, err := t.db.NewTx()
if err != nil {
return err
}
if err := tx.Put(leafKey, leafValue); err != nil { if err := tx.Put(leafKey, leafValue); err != nil {
return err
return tx, err
} }
// go up to the root // go up to the root
if len(siblings) == 0 { if len(siblings) == 0 {
t.root = leafKey t.root = leafKey
return tx.Commit()
return tx, nil
} }
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1) root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
if err != nil { if err != nil {
return err
return tx, err
} }
t.root = root t.root = root
// store root to db
return tx.Commit()
return tx, nil
} }
// down goes down to the leaf recursively // down goes down to the leaf recursively
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) (
func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
path []bool, l int, getLeaf bool) (
[]byte, []byte, [][]byte, error) { []byte, []byte, [][]byte, error) {
if l > t.maxLevels-1 { if l > t.maxLevels-1 {
return nil, nil, nil, fmt.Errorf("max level") return nil, nil, nil, fmt.Errorf("max level")
@ -161,7 +197,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
// empty value // empty value
return currKey, emptyValue, siblings, nil return currKey, emptyValue, siblings, nil
} }
currValue, err = t.db.Get(currKey)
currValue, err = t.dbGet(tx, currKey)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -204,12 +240,12 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
// right // right
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, lChild) siblings = append(siblings, lChild)
return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf)
} }
// left // left
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, rChild) siblings = append(siblings, rChild)
return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf)
default: default:
return nil, nil, nil, fmt.Errorf("invalid value") return nil, nil, nil, fmt.Errorf("invalid value")
} }
@ -225,7 +261,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
} }
if oldPath[l] == newPath[l] { if oldPath[l] == newPath[l] {
emptyKey := make([]byte, t.hashFunction.Len()) // empty
emptyKey := make([]byte, t.hashFunction.Len())
siblings = append(siblings, emptyKey) siblings = append(siblings, emptyKey)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
@ -256,8 +292,7 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
} }
} }
// store k-v to db // store k-v to db
err = tx.Put(k, v)
if err != nil {
if err = tx.Put(k, v); err != nil {
return nil, err return nil, err
} }
@ -343,7 +378,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, error) {
path := getPath(t.maxLevels, keyPath) path := getPath(t.maxLevels, keyPath)
// go down to the leaf // go down to the leaf
var siblings [][]byte var siblings [][]byte
_, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
_, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -440,7 +475,7 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
path := getPath(t.maxLevels, keyPath) path := getPath(t.maxLevels, keyPath)
// go down to the leaf // go down to the leaf
var siblings [][]byte var siblings [][]byte
_, value, _, err := t.down(k, t.root, siblings, path, 0, true)
_, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -487,3 +522,14 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
} }
return false, nil return false, nil
} }
func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) {
v, err := t.db.Get(k)
if err == nil {
return v, nil
}
if tx != nil {
return tx.Get(k)
}
return nil, db.ErrNotFound
}

+ 22
- 2
tree_test.go

@ -58,11 +58,11 @@ func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
assert.Equal(t, testVectors[3], rootBI.String()) assert.Equal(t, testVectors[3], rootBI.String())
} }
func TestAdd1000(t *testing.T) {
func TestAddBatch(t *testing.T) {
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err) require.Nil(t, err)
defer tree.db.Close() defer tree.db.Close()
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
k := BigIntToBytes(big.NewInt(int64(i))) k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(0)) v := BigIntToBytes(big.NewInt(0))
@ -75,6 +75,26 @@ func TestAdd1000(t *testing.T) {
assert.Equal(t, assert.Equal(t,
"296519252211642170490407814696803112091039265640052570497930797516015811235", "296519252211642170490407814696803112091039265640052570497930797516015811235",
rootBI.String()) rootBI.String())
tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree2.db.Close()
var keys, values [][]byte
for i := 0; i < 1000; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(0))
keys = append(keys, k)
values = append(values, v)
}
indexes, err := tree2.AddBatch(keys, values)
assert.Nil(t, err)
assert.Equal(t, 0, len(indexes))
rootBI = BytesToBigInt(tree2.Root())
assert.Equal(t,
"296519252211642170490407814696803112091039265640052570497930797516015811235",
rootBI.String())
} }
func TestAddDifferentOrder(t *testing.T) { func TestAddDifferentOrder(t *testing.T) {

Loading…
Cancel
Save