diff --git a/README.md b/README.md new file mode 100644 index 0000000..6af6993 --- /dev/null +++ b/README.md @@ -0,0 +1,4 @@ +# go-merkletree +MerkleTree compatible with version from [circomlib](https://github.com/iden3/circomlib). + +Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8 diff --git a/merkletree.go b/merkletree.go index 434fad3..693ca43 100644 --- a/merkletree.go +++ b/merkletree.go @@ -1,12 +1,14 @@ package merkletree import ( + "bytes" "errors" "math/big" "sync" "github.com/iden3/go-iden3-core/common" "github.com/iden3/go-iden3-core/db" + cryptoUtils "github.com/iden3/go-iden3-crypto/utils" ) const ( @@ -95,3 +97,206 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { func (mt *MerkleTree) Root() *Hash { return mt.rootKey } + +func (mt *MerkleTree) Add(k, v *big.Int) error { + // verify that the MerkleTree is writable + if !mt.writable { + return ErrNotWritable + } + + // verfy that the ElemBytes are valid and fit inside the Finite Field. + if !cryptoUtils.CheckBigIntInField(k) { + return errors.New("Key not inside the Finite Field") + } + if !cryptoUtils.CheckBigIntInField(v) { + return errors.New("Value not inside the Finite Field") + } + + tx, err := mt.db.NewTx() + if err != nil { + return err + } + mt.Lock() + defer mt.Unlock() + + kHash := NewHashFromBigInt(k) + vHash := NewHashFromBigInt(v) + newNodeLeaf := NewNodeLeaf(kHash, vHash) + path := getPath(mt.maxLevels, kHash[:]) + + newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path) + if err != nil { + return err + } + mt.rootKey = newRootKey + mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// pushLeaf recursively pushes an existing oldLeaf down until its path diverges +// from newLeaf, at which point both leafs are stored, all while updating the +// path. +func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node, + lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) { + if lvl > mt.maxLevels-2 { + return nil, ErrReachedMaxLevel + } + var newNodeMiddle *Node + if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper! + nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf) + if err != nil { + return nil, err + } + if pathNewLeaf[lvl] { + newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right + } else { + newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left + } + return mt.addNode(tx, newNodeMiddle) + } else { + oldLeafKey, err := oldLeaf.Key() + if err != nil { + return nil, err + } + newLeafKey, err := newLeaf.Key() + if err != nil { + return nil, err + } + + if pathNewLeaf[lvl] { + newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey) + } else { + newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey) + } + // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree. + _, err = mt.addNode(tx, newLeaf) + if err != nil { + return nil, err + } + return mt.addNode(tx, newNodeMiddle) + } +} + +// addLeaf recursively adds a newLeaf in the MT while updating the path. +func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash, + lvl int, path []bool) (*Hash, error) { + var err error + var nextKey *Hash + if lvl > mt.maxLevels-1 { + return nil, ErrReachedMaxLevel + } + n, err := mt.GetNode(key) + if err != nil { + return nil, err + } + switch n.Type { + case NodeTypeEmpty: + // We can add newLeaf now + return mt.addNode(tx, newLeaf) + case NodeTypeLeaf: + nKey := n.Entry[0] + // Check if leaf node found contains the leaf node we are trying to add + newLeafKey := newLeaf.Entry[0] + if bytes.Equal(nKey[:], newLeafKey[:]) { + return nil, ErrEntryIndexAlreadyExists + } + pathOldLeaf := getPath(mt.maxLevels, nKey[:]) + // We need to push newLeaf down until its path diverges from n's path + return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf) + case NodeTypeMiddle: + // We need to go deeper, continue traversing the tree, left or right depending on path + var newNodeMiddle *Node + if path[lvl] { + nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right + newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey) + } else { + nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left + newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR) + } + if err != nil { + return nil, err + } + // Update the node to reflect the modified child + return mt.addNode(tx, newNodeMiddle) + default: + return nil, ErrInvalidNodeFound + } +} + +// addNode adds a node into the MT. Empty nodes are not stored in the tree; +// they are all the same and assumed to always exist. +func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { + // verify that the MerkleTree is writable + if !mt.writable { + return nil, ErrNotWritable + } + if n.Type == NodeTypeEmpty { + return n.Key() + } + k, err := n.Key() + if err != nil { + return nil, err + } + v := n.Value() + // Check that the node key doesn't already exist + if _, err := tx.Get(k[:]); err == nil { + return nil, ErrNodeKeyAlreadyExists + } + tx.Put(k[:], v) + return k, nil +} + +// dbGet is a helper function to get the node of a key from the internal +// storage. +func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) { + if bytes.Equal(k, HashZero[:]) { + return 0, nil, nil + } + + value, err := mt.db.Get(k) + if err != nil { + return 0, nil, err + } + + if len(value) < 2 { + return 0, nil, ErrInvalidDBValue + } + nodeType := value[0] + nodeBytes := value[1:] + + return NodeType(nodeType), nodeBytes, nil +} + +// dbInsert is a helper function to insert a node into a key in an open db +// transaction. +func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) { + v := append([]byte{byte(t)}, data...) + tx.Put(k, v) +} + +// GetNode gets a node by key from the MT. Empty nodes are not stored in the +// tree; they are all the same and assumed to always exist. +func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) { + if bytes.Equal(key[:], HashZero[:]) { + return NewNodeEmpty(), nil + } + nBytes, err := mt.db.Get(key[:]) + if err != nil { + return nil, err + } + return NewNodeFromBytes(nBytes) +} + +// getPath returns the binary path, from the root to the leaf. +func getPath(numLevels int, k []byte) []bool { + path := make([]bool, numLevels) + for n := 0; n < numLevels; n++ { + path[n] = common.TestBit(k[:], uint(n)) + } + return path +} diff --git a/merkletree_test.go b/merkletree_test.go new file mode 100644 index 0000000..ae923fb --- /dev/null +++ b/merkletree_test.go @@ -0,0 +1,28 @@ +package merkletree + +import ( + "math/big" + "testing" + + "github.com/iden3/go-iden3-core/db" + "github.com/stretchr/testify/assert" +) + +func TestNewTree(t *testing.T) { + mt, err := NewMerkleTree(db.NewMemoryStorage(), 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + err = mt.Add(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String()) + + err = mt.Add(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String()) + + err = mt.Add(big.NewInt(1234), big.NewInt(9876)) + assert.Nil(t, err) + assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String()) +}