mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 11:36:47 +01:00
Add(k, v), add tests compatible with circomlib
This commit is contained in:
4
README.md
Normal file
4
README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# go-merkletree
|
||||
MerkleTree compatible with version from [circomlib](https://github.com/iden3/circomlib).
|
||||
|
||||
Adaptation of the merkletree from https://github.com/iden3/go-iden3-core/tree/v0.0.8
|
||||
205
merkletree.go
205
merkletree.go
@@ -1,12 +1,14 @@
|
||||
package merkletree
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"math/big"
|
||||
"sync"
|
||||
|
||||
"github.com/iden3/go-iden3-core/common"
|
||||
"github.com/iden3/go-iden3-core/db"
|
||||
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -95,3 +97,206 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
||||
func (mt *MerkleTree) Root() *Hash {
|
||||
return mt.rootKey
|
||||
}
|
||||
|
||||
func (mt *MerkleTree) Add(k, v *big.Int) error {
|
||||
// verify that the MerkleTree is writable
|
||||
if !mt.writable {
|
||||
return ErrNotWritable
|
||||
}
|
||||
|
||||
// verfy that the ElemBytes are valid and fit inside the Finite Field.
|
||||
if !cryptoUtils.CheckBigIntInField(k) {
|
||||
return errors.New("Key not inside the Finite Field")
|
||||
}
|
||||
if !cryptoUtils.CheckBigIntInField(v) {
|
||||
return errors.New("Value not inside the Finite Field")
|
||||
}
|
||||
|
||||
tx, err := mt.db.NewTx()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mt.Lock()
|
||||
defer mt.Unlock()
|
||||
|
||||
kHash := NewHashFromBigInt(k)
|
||||
vHash := NewHashFromBigInt(v)
|
||||
newNodeLeaf := NewNodeLeaf(kHash, vHash)
|
||||
path := getPath(mt.maxLevels, kHash[:])
|
||||
|
||||
newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mt.rootKey = newRootKey
|
||||
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pushLeaf recursively pushes an existing oldLeaf down until its path diverges
|
||||
// from newLeaf, at which point both leafs are stored, all while updating the
|
||||
// path.
|
||||
func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
|
||||
lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
|
||||
if lvl > mt.maxLevels-2 {
|
||||
return nil, ErrReachedMaxLevel
|
||||
}
|
||||
var newNodeMiddle *Node
|
||||
if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
|
||||
nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pathNewLeaf[lvl] {
|
||||
newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
|
||||
} else {
|
||||
newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
|
||||
}
|
||||
return mt.addNode(tx, newNodeMiddle)
|
||||
} else {
|
||||
oldLeafKey, err := oldLeaf.Key()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newLeafKey, err := newLeaf.Key()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pathNewLeaf[lvl] {
|
||||
newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
|
||||
} else {
|
||||
newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
|
||||
}
|
||||
// We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
|
||||
_, err = mt.addNode(tx, newLeaf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mt.addNode(tx, newNodeMiddle)
|
||||
}
|
||||
}
|
||||
|
||||
// addLeaf recursively adds a newLeaf in the MT while updating the path.
|
||||
func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
|
||||
lvl int, path []bool) (*Hash, error) {
|
||||
var err error
|
||||
var nextKey *Hash
|
||||
if lvl > mt.maxLevels-1 {
|
||||
return nil, ErrReachedMaxLevel
|
||||
}
|
||||
n, err := mt.GetNode(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch n.Type {
|
||||
case NodeTypeEmpty:
|
||||
// We can add newLeaf now
|
||||
return mt.addNode(tx, newLeaf)
|
||||
case NodeTypeLeaf:
|
||||
nKey := n.Entry[0]
|
||||
// Check if leaf node found contains the leaf node we are trying to add
|
||||
newLeafKey := newLeaf.Entry[0]
|
||||
if bytes.Equal(nKey[:], newLeafKey[:]) {
|
||||
return nil, ErrEntryIndexAlreadyExists
|
||||
}
|
||||
pathOldLeaf := getPath(mt.maxLevels, nKey[:])
|
||||
// We need to push newLeaf down until its path diverges from n's path
|
||||
return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
|
||||
case NodeTypeMiddle:
|
||||
// We need to go deeper, continue traversing the tree, left or right depending on path
|
||||
var newNodeMiddle *Node
|
||||
if path[lvl] {
|
||||
nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
|
||||
newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
|
||||
} else {
|
||||
nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
|
||||
newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Update the node to reflect the modified child
|
||||
return mt.addNode(tx, newNodeMiddle)
|
||||
default:
|
||||
return nil, ErrInvalidNodeFound
|
||||
}
|
||||
}
|
||||
|
||||
// addNode adds a node into the MT. Empty nodes are not stored in the tree;
|
||||
// they are all the same and assumed to always exist.
|
||||
func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
|
||||
// verify that the MerkleTree is writable
|
||||
if !mt.writable {
|
||||
return nil, ErrNotWritable
|
||||
}
|
||||
if n.Type == NodeTypeEmpty {
|
||||
return n.Key()
|
||||
}
|
||||
k, err := n.Key()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := n.Value()
|
||||
// Check that the node key doesn't already exist
|
||||
if _, err := tx.Get(k[:]); err == nil {
|
||||
return nil, ErrNodeKeyAlreadyExists
|
||||
}
|
||||
tx.Put(k[:], v)
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// dbGet is a helper function to get the node of a key from the internal
|
||||
// storage.
|
||||
func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
|
||||
if bytes.Equal(k, HashZero[:]) {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
value, err := mt.db.Get(k)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if len(value) < 2 {
|
||||
return 0, nil, ErrInvalidDBValue
|
||||
}
|
||||
nodeType := value[0]
|
||||
nodeBytes := value[1:]
|
||||
|
||||
return NodeType(nodeType), nodeBytes, nil
|
||||
}
|
||||
|
||||
// dbInsert is a helper function to insert a node into a key in an open db
|
||||
// transaction.
|
||||
func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
|
||||
v := append([]byte{byte(t)}, data...)
|
||||
tx.Put(k, v)
|
||||
}
|
||||
|
||||
// GetNode gets a node by key from the MT. Empty nodes are not stored in the
|
||||
// tree; they are all the same and assumed to always exist.
|
||||
func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
|
||||
if bytes.Equal(key[:], HashZero[:]) {
|
||||
return NewNodeEmpty(), nil
|
||||
}
|
||||
nBytes, err := mt.db.Get(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewNodeFromBytes(nBytes)
|
||||
}
|
||||
|
||||
// getPath returns the binary path, from the root to the leaf.
|
||||
func getPath(numLevels int, k []byte) []bool {
|
||||
path := make([]bool, numLevels)
|
||||
for n := 0; n < numLevels; n++ {
|
||||
path[n] = common.TestBit(k[:], uint(n))
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
28
merkletree_test.go
Normal file
28
merkletree_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package merkletree
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-iden3-core/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewTree(t *testing.T) {
|
||||
mt, err := NewMerkleTree(db.NewMemoryStorage(), 10)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
err = mt.Add(big.NewInt(1), big.NewInt(2))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
|
||||
|
||||
err = mt.Add(big.NewInt(33), big.NewInt(44))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
|
||||
|
||||
err = mt.Add(big.NewInt(1234), big.NewInt(9876))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
|
||||
}
|
||||
Reference in New Issue
Block a user