|
|
@ -1,12 +1,14 @@ |
|
|
|
package merkletree |
|
|
|
|
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"errors" |
|
|
|
"math/big" |
|
|
|
"sync" |
|
|
|
|
|
|
|
"github.com/iden3/go-iden3-core/common" |
|
|
|
"github.com/iden3/go-iden3-core/db" |
|
|
|
cryptoUtils "github.com/iden3/go-iden3-crypto/utils" |
|
|
|
) |
|
|
|
|
|
|
|
const ( |
|
|
@ -95,3 +97,206 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { |
|
|
|
func (mt *MerkleTree) Root() *Hash { |
|
|
|
return mt.rootKey |
|
|
|
} |
|
|
|
|
|
|
|
func (mt *MerkleTree) Add(k, v *big.Int) error { |
|
|
|
// verify that the MerkleTree is writable
|
|
|
|
if !mt.writable { |
|
|
|
return ErrNotWritable |
|
|
|
} |
|
|
|
|
|
|
|
// verfy that the ElemBytes are valid and fit inside the Finite Field.
|
|
|
|
if !cryptoUtils.CheckBigIntInField(k) { |
|
|
|
return errors.New("Key not inside the Finite Field") |
|
|
|
} |
|
|
|
if !cryptoUtils.CheckBigIntInField(v) { |
|
|
|
return errors.New("Value not inside the Finite Field") |
|
|
|
} |
|
|
|
|
|
|
|
tx, err := mt.db.NewTx() |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
mt.Lock() |
|
|
|
defer mt.Unlock() |
|
|
|
|
|
|
|
kHash := NewHashFromBigInt(k) |
|
|
|
vHash := NewHashFromBigInt(v) |
|
|
|
newNodeLeaf := NewNodeLeaf(kHash, vHash) |
|
|
|
path := getPath(mt.maxLevels, kHash[:]) |
|
|
|
|
|
|
|
newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
mt.rootKey = newRootKey |
|
|
|
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) |
|
|
|
|
|
|
|
if err := tx.Commit(); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
// pushLeaf recursively pushes an existing oldLeaf down until its path diverges
|
|
|
|
// from newLeaf, at which point both leafs are stored, all while updating the
|
|
|
|
// path.
|
|
|
|
func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node, |
|
|
|
lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) { |
|
|
|
if lvl > mt.maxLevels-2 { |
|
|
|
return nil, ErrReachedMaxLevel |
|
|
|
} |
|
|
|
var newNodeMiddle *Node |
|
|
|
if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
|
|
|
|
nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
if pathNewLeaf[lvl] { |
|
|
|
newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
|
|
|
|
} else { |
|
|
|
newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
|
|
|
|
} |
|
|
|
return mt.addNode(tx, newNodeMiddle) |
|
|
|
} else { |
|
|
|
oldLeafKey, err := oldLeaf.Key() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
newLeafKey, err := newLeaf.Key() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
if pathNewLeaf[lvl] { |
|
|
|
newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey) |
|
|
|
} else { |
|
|
|
newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey) |
|
|
|
} |
|
|
|
// We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
|
|
|
|
_, err = mt.addNode(tx, newLeaf) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
return mt.addNode(tx, newNodeMiddle) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// addLeaf recursively adds a newLeaf in the MT while updating the path.
|
|
|
|
func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash, |
|
|
|
lvl int, path []bool) (*Hash, error) { |
|
|
|
var err error |
|
|
|
var nextKey *Hash |
|
|
|
if lvl > mt.maxLevels-1 { |
|
|
|
return nil, ErrReachedMaxLevel |
|
|
|
} |
|
|
|
n, err := mt.GetNode(key) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
switch n.Type { |
|
|
|
case NodeTypeEmpty: |
|
|
|
// We can add newLeaf now
|
|
|
|
return mt.addNode(tx, newLeaf) |
|
|
|
case NodeTypeLeaf: |
|
|
|
nKey := n.Entry[0] |
|
|
|
// Check if leaf node found contains the leaf node we are trying to add
|
|
|
|
newLeafKey := newLeaf.Entry[0] |
|
|
|
if bytes.Equal(nKey[:], newLeafKey[:]) { |
|
|
|
return nil, ErrEntryIndexAlreadyExists |
|
|
|
} |
|
|
|
pathOldLeaf := getPath(mt.maxLevels, nKey[:]) |
|
|
|
// We need to push newLeaf down until its path diverges from n's path
|
|
|
|
return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf) |
|
|
|
case NodeTypeMiddle: |
|
|
|
// We need to go deeper, continue traversing the tree, left or right depending on path
|
|
|
|
var newNodeMiddle *Node |
|
|
|
if path[lvl] { |
|
|
|
nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
|
|
|
|
newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey) |
|
|
|
} else { |
|
|
|
nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
|
|
|
|
newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR) |
|
|
|
} |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
// Update the node to reflect the modified child
|
|
|
|
return mt.addNode(tx, newNodeMiddle) |
|
|
|
default: |
|
|
|
return nil, ErrInvalidNodeFound |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// addNode adds a node into the MT. Empty nodes are not stored in the tree;
|
|
|
|
// they are all the same and assumed to always exist.
|
|
|
|
func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { |
|
|
|
// verify that the MerkleTree is writable
|
|
|
|
if !mt.writable { |
|
|
|
return nil, ErrNotWritable |
|
|
|
} |
|
|
|
if n.Type == NodeTypeEmpty { |
|
|
|
return n.Key() |
|
|
|
} |
|
|
|
k, err := n.Key() |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
v := n.Value() |
|
|
|
// Check that the node key doesn't already exist
|
|
|
|
if _, err := tx.Get(k[:]); err == nil { |
|
|
|
return nil, ErrNodeKeyAlreadyExists |
|
|
|
} |
|
|
|
tx.Put(k[:], v) |
|
|
|
return k, nil |
|
|
|
} |
|
|
|
|
|
|
|
// dbGet is a helper function to get the node of a key from the internal
|
|
|
|
// storage.
|
|
|
|
func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) { |
|
|
|
if bytes.Equal(k, HashZero[:]) { |
|
|
|
return 0, nil, nil |
|
|
|
} |
|
|
|
|
|
|
|
value, err := mt.db.Get(k) |
|
|
|
if err != nil { |
|
|
|
return 0, nil, err |
|
|
|
} |
|
|
|
|
|
|
|
if len(value) < 2 { |
|
|
|
return 0, nil, ErrInvalidDBValue |
|
|
|
} |
|
|
|
nodeType := value[0] |
|
|
|
nodeBytes := value[1:] |
|
|
|
|
|
|
|
return NodeType(nodeType), nodeBytes, nil |
|
|
|
} |
|
|
|
|
|
|
|
// dbInsert is a helper function to insert a node into a key in an open db
|
|
|
|
// transaction.
|
|
|
|
func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) { |
|
|
|
v := append([]byte{byte(t)}, data...) |
|
|
|
tx.Put(k, v) |
|
|
|
} |
|
|
|
|
|
|
|
// GetNode gets a node by key from the MT. Empty nodes are not stored in the
|
|
|
|
// tree; they are all the same and assumed to always exist.
|
|
|
|
func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) { |
|
|
|
if bytes.Equal(key[:], HashZero[:]) { |
|
|
|
return NewNodeEmpty(), nil |
|
|
|
} |
|
|
|
nBytes, err := mt.db.Get(key[:]) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
return NewNodeFromBytes(nBytes) |
|
|
|
} |
|
|
|
|
|
|
|
// getPath returns the binary path, from the root to the leaf.
|
|
|
|
func getPath(numLevels int, k []byte) []bool { |
|
|
|
path := make([]bool, numLevels) |
|
|
|
for n := 0; n < numLevels; n++ { |
|
|
|
path[n] = common.TestBit(k[:], uint(n)) |
|
|
|
} |
|
|
|
return path |
|
|
|
} |