Browse Source

Add computeHashes at virtual tree

master
arnaucube 3 years ago
parent
commit
66f6ae14bb
2 changed files with 118 additions and 6 deletions
  1. +67
    -6
      vt.go
  2. +51
    -0
      vt_test.go

+ 67
- 6
vt.go

@ -2,7 +2,6 @@
// without computing any hash. With the idea of once all the leafs are placed in
// their positions, the hashes can be computed, avoiding computing a node hash
// more than one time.
//nolint:unused,deadcode
package arbo
import (
@ -18,7 +17,7 @@ type node struct {
k []byte
v []byte
path []bool
// h []byte
h []byte
}
type params struct {
@ -60,6 +59,18 @@ func (t *vt) add(k, v []byte) error {
return nil
}
// computeHashes should be called after all the vt.add is used, once all the
// leafs are in the tree
func (t *vt) computeHashes() ([][2][]byte, error) {
var pairs [][2][]byte
var err error
pairs, err = t.root.computeHashes(t.params, pairs)
if err != nil {
return pairs, err
}
return pairs, nil
}
func newLeafNode(p *params, k, v []byte) *node {
keyPath := make([]byte, p.hashFunction.Len())
copy(keyPath[:], k)
@ -150,10 +161,8 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod
return fmt.Errorf("max virtual level %d", currLvl)
}
// if oldLeaf.path[currLvl+1] != newLeaf.path[currLvl+1] {
if oldLeaf.path[currLvl] != newLeaf.path[currLvl] {
// reached divergence in next level
// if newLeaf.path[currLvl+1] {
if newLeaf.path[currLvl] {
n.l = oldLeaf
n.r = newLeaf
@ -181,10 +190,60 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod
return nil
}
func (n *node) computeHashes() ([]kv, error) {
return nil, nil
// returns an array of key-values to store in the db
func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) {
if pairs == nil {
pairs = [][2][]byte{}
}
var err error
t := n.typ()
switch t {
case vtLeaf:
leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v)
if err != nil {
return pairs, err
}
n.h = leafKey
kv := [2][]byte{leafKey, leafValue}
pairs = append(pairs, kv)
case vtMid:
if n.l != nil {
pairs, err = n.l.computeHashes(p, pairs)
if err != nil {
return pairs, err
}
} else {
n.l = &node{
h: p.emptyHash,
}
}
if n.r != nil {
pairs, err = n.r.computeHashes(p, pairs)
if err != nil {
return pairs, err
}
} else {
n.r = &node{
h: p.emptyHash,
}
}
// once the sub nodes are computed, can compute the current node
// hash
k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h)
if err != nil {
return nil, err
}
n.h = k
kv := [2][]byte{k, v}
pairs = append(pairs, kv)
default:
return nil, fmt.Errorf("ERR TMP") // TODO
}
return pairs, nil
}
//nolint:unused
func (t *vt) graphviz(w io.Writer) error {
fmt.Fprintf(w, `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box]
@ -196,6 +255,7 @@ node [fontname=Monospace,fontsize=10,shape=box]
return nil
}
//nolint:unused
func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
nChars := 4 // TODO move to global constant
if n == nil {
@ -254,6 +314,7 @@ func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) {
return nEmpties, nil
}
//nolint:unused
func (t *vt) printGraphviz() error {
w := bytes.NewBufferString("")
fmt.Fprintf(w,

+ 51
- 0
vt_test.go

@ -0,0 +1,51 @@
package arbo
import (
"math/big"
"testing"
qt "github.com/frankban/quicktest"
)
func TestVirtualTree(t *testing.T) {
c := qt.New(t)
vTree := newVT(10, HashFunctionSha256)
c.Assert(vTree.root, qt.IsNil)
k := BigIntToBytes(big.NewInt(1))
v := BigIntToBytes(big.NewInt(2))
err := vTree.add(k, v)
c.Assert(err, qt.IsNil)
// check values
c.Assert(vTree.root.k, qt.DeepEquals, k)
c.Assert(vTree.root.v, qt.DeepEquals, v)
// compute hashes
pairs, err := vTree.computeHashes()
c.Assert(err, qt.IsNil)
c.Assert(len(pairs), qt.Equals, 1)
rootBI := BytesToBigInt(vTree.root.h)
c.Assert(rootBI.String(), qt.Equals,
"46910109172468462938850740851377282682950237270676610513794735904325820156367")
k = BigIntToBytes(big.NewInt(33))
v = BigIntToBytes(big.NewInt(44))
err = vTree.add(k, v)
c.Assert(err, qt.IsNil)
// compute hashes
pairs, err = vTree.computeHashes()
c.Assert(err, qt.IsNil)
c.Assert(len(pairs), qt.Equals, 8)
// err = vTree.printGraphviz()
// c.Assert(err, qt.IsNil)
rootBI = BytesToBigInt(vTree.root.h)
c.Assert(rootBI.String(), qt.Equals,
"59481735341404520835410489183267411392292882901306595567679529387376287440550")
c.Assert(err, qt.IsNil)
}

Loading…
Cancel
Save