From 66f6ae14bbd0eb1dbb79b5158a2312b13b20d6e4 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 17 May 2021 22:16:08 +0200 Subject: [PATCH] Add computeHashes at virtual tree --- vt.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++----- vt_test.go | 51 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 vt_test.go diff --git a/vt.go b/vt.go index 5970db8..ae5d80b 100644 --- a/vt.go +++ b/vt.go @@ -2,7 +2,6 @@ // without computing any hash. With the idea of once all the leafs are placed in // their positions, the hashes can be computed, avoiding computing a node hash // more than one time. -//nolint:unused,deadcode package arbo import ( @@ -18,7 +17,7 @@ type node struct { k []byte v []byte path []bool - // h []byte + h []byte } type params struct { @@ -60,6 +59,18 @@ func (t *vt) add(k, v []byte) error { return nil } +// computeHashes should be called after all the vt.add is used, once all the +// leafs are in the tree +func (t *vt) computeHashes() ([][2][]byte, error) { + var pairs [][2][]byte + var err error + pairs, err = t.root.computeHashes(t.params, pairs) + if err != nil { + return pairs, err + } + return pairs, nil +} + func newLeafNode(p *params, k, v []byte) *node { keyPath := make([]byte, p.hashFunction.Len()) copy(keyPath[:], k) @@ -150,10 +161,8 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod return fmt.Errorf("max virtual level %d", currLvl) } - // if oldLeaf.path[currLvl+1] != newLeaf.path[currLvl+1] { if oldLeaf.path[currLvl] != newLeaf.path[currLvl] { // reached divergence in next level - // if newLeaf.path[currLvl+1] { if newLeaf.path[currLvl] { n.l = oldLeaf n.r = newLeaf @@ -181,10 +190,60 @@ func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *nod return nil } -func (n *node) computeHashes() ([]kv, error) { - return nil, nil +// returns an array of key-values to store in the db +func (n *node) computeHashes(p *params, pairs [][2][]byte) ([][2][]byte, error) { + if pairs == nil { + pairs = [][2][]byte{} + } + var err error + t := n.typ() + switch t { + case vtLeaf: + leafKey, leafValue, err := newLeafValue(p.hashFunction, n.k, n.v) + if err != nil { + return pairs, err + } + n.h = leafKey + kv := [2][]byte{leafKey, leafValue} + pairs = append(pairs, kv) + case vtMid: + if n.l != nil { + pairs, err = n.l.computeHashes(p, pairs) + if err != nil { + return pairs, err + } + } else { + n.l = &node{ + h: p.emptyHash, + } + } + if n.r != nil { + pairs, err = n.r.computeHashes(p, pairs) + if err != nil { + return pairs, err + } + } else { + n.r = &node{ + h: p.emptyHash, + } + } + // once the sub nodes are computed, can compute the current node + // hash + k, v, err := newIntermediate(p.hashFunction, n.l.h, n.r.h) + if err != nil { + return nil, err + } + n.h = k + kv := [2][]byte{k, v} + pairs = append(pairs, kv) + default: + return nil, fmt.Errorf("ERR TMP") // TODO + } + + return pairs, nil } +//nolint:unused func (t *vt) graphviz(w io.Writer) error { fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] @@ -196,6 +255,7 @@ node [fontname=Monospace,fontsize=10,shape=box] return nil } +//nolint:unused func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) { nChars := 4 // TODO move to global constant if n == nil { @@ -254,6 +314,7 @@ func (n *node) graphviz(w io.Writer, p *params, nEmpties int) (int, error) { return nEmpties, nil } +//nolint:unused func (t *vt) printGraphviz() error { w := bytes.NewBufferString("") fmt.Fprintf(w, diff --git a/vt_test.go b/vt_test.go new file mode 100644 index 0000000..779d728 --- /dev/null +++ b/vt_test.go @@ -0,0 +1,51 @@ +package arbo + +import ( + "math/big" + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestVirtualTree(t *testing.T) { + c := qt.New(t) + vTree := newVT(10, HashFunctionSha256) + + c.Assert(vTree.root, qt.IsNil) + + k := BigIntToBytes(big.NewInt(1)) + v := BigIntToBytes(big.NewInt(2)) + err := vTree.add(k, v) + c.Assert(err, qt.IsNil) + + // check values + c.Assert(vTree.root.k, qt.DeepEquals, k) + c.Assert(vTree.root.v, qt.DeepEquals, v) + + // compute hashes + pairs, err := vTree.computeHashes() + c.Assert(err, qt.IsNil) + c.Assert(len(pairs), qt.Equals, 1) + + rootBI := BytesToBigInt(vTree.root.h) + c.Assert(rootBI.String(), qt.Equals, + "46910109172468462938850740851377282682950237270676610513794735904325820156367") + + k = BigIntToBytes(big.NewInt(33)) + v = BigIntToBytes(big.NewInt(44)) + err = vTree.add(k, v) + c.Assert(err, qt.IsNil) + + // compute hashes + pairs, err = vTree.computeHashes() + c.Assert(err, qt.IsNil) + c.Assert(len(pairs), qt.Equals, 8) + + // err = vTree.printGraphviz() + // c.Assert(err, qt.IsNil) + + rootBI = BytesToBigInt(vTree.root.h) + c.Assert(rootBI.String(), qt.Equals, + "59481735341404520835410489183267411392292882901306595567679529387376287440550") + c.Assert(err, qt.IsNil) +}