diff --git a/merkletree.go b/merkletree.go index 17e5b53..67cfbd9 100644 --- a/merkletree.go +++ b/merkletree.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "math/big" "sync" @@ -491,3 +492,95 @@ func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) { } return midKey, nil } + +// walk is a helper recursive function to iterate over all tree branches +func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error { + n, err := mt.GetNode(key) + if err != nil { + return err + } + switch n.Type { + case NodeTypeEmpty: + f(n) + case NodeTypeLeaf: + f(n) + case NodeTypeMiddle: + f(n) + if err := mt.walk(n.ChildL, f); err != nil { + return err + } + if err := mt.walk(n.ChildR, f); err != nil { + return err + } + default: + return ErrInvalidNodeFound + } + return nil +} + +// Walk iterates over all the branches of a MerkleTree with the given rootKey +// if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree. +// For each node, it calls the f function given in the parameters. +// See some examples of the Walk function usage in the merkletree_test.go +// test functions: TestMTWalk, TestMTWalkGraphViz, TestMTWalkDumpClaims +func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error { + if rootKey == nil { + rootKey = mt.Root() + } + err := mt.walk(rootKey, f) + return err +} + +// GraphViz uses Walk function to generate a string GraphViz representation of the +// tree and writes it to w +func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error { + fmt.Fprintf(w, `digraph hierarchy { +node [fontname=Monospace,fontsize=10,shape=box] +`) + cnt := 0 + var errIn error + err := mt.Walk(rootKey, func(n *Node) { + k, err := n.Key() + if err != nil { + errIn = err + } + switch n.Type { + case NodeTypeEmpty: + case NodeTypeLeaf: + fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.BigInt().String()) + case NodeTypeMiddle: + lr := [2]string{n.ChildL.BigInt().String(), n.ChildR.BigInt().String()} + for i := range lr { + if lr[i] == "0" { + lr[i] = fmt.Sprintf("empty%v", cnt) + fmt.Fprintf(w, "\"%v\" [style=dashed,label=0];\n", lr[i]) + cnt++ + } + } + fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.BigInt().String(), lr[0], lr[1]) + default: + } + }) + fmt.Fprintf(w, "}\n") + if errIn != nil { + return errIn + } + return err +} + +// PrintGraphViz prints directly the GraphViz() output +func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error { + if rootKey == nil { + rootKey = mt.Root() + } + w := bytes.NewBufferString("") + fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n") + err := mt.GraphViz(w, nil) + if err != nil { + return err + } + fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n") + + fmt.Println(w) + return nil +}