Browse Source

Merge pull request #7 from iden3/circomproofs

Add CircomProofs for Addition & Update
fix/hash-parsers
a_bennassar 4 years ago
committed by GitHub
parent
commit
dc656fdd32
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 318 additions and 126 deletions
  1. +1
    -1
      go.mod
  2. +5
    -4
      go.sum
  3. +172
    -31
      merkletree.go
  4. +131
    -57
      merkletree_test.go
  5. +9
    -33
      utils.go

+ 1
- 1
go.mod

@ -5,7 +5,7 @@ go 1.14
require ( require (
github.com/cockroachdb/pebble v0.0.0-20200814004841-77c18adb0ee3 github.com/cockroachdb/pebble v0.0.0-20200814004841-77c18adb0ee3
github.com/iden3/go-iden3-core v0.0.8 github.com/iden3/go-iden3-core v0.0.8
github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf
github.com/iden3/go-iden3-crypto v0.0.6-0.20200819064831-09d161e9f670
github.com/sirupsen/logrus v1.5.0 github.com/sirupsen/logrus v1.5.0
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d

+ 5
- 4
go.sum

@ -110,7 +110,6 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2-0.20190517061210-b285ee9cfc6c/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2-0.20190517061210-b285ee9cfc6c/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws= github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws=
github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
@ -132,10 +131,13 @@ github.com/iden3/go-circom-prover-verifier v0.0.1/go.mod h1:1FkpX4nUXxYcY2fpzqd2
github.com/iden3/go-circom-witnesscalc v0.0.1/go.mod h1:xjT1BlFZDBioHOlbD75SmZZLC1d1AfOycqbSa/1QRJU= github.com/iden3/go-circom-witnesscalc v0.0.1/go.mod h1:xjT1BlFZDBioHOlbD75SmZZLC1d1AfOycqbSa/1QRJU=
github.com/iden3/go-iden3-core v0.0.8 h1:PLw7iCiX7Pw1dqBkR+JaLQWqB5RKd+vgu25UBdvFXGQ= github.com/iden3/go-iden3-core v0.0.8 h1:PLw7iCiX7Pw1dqBkR+JaLQWqB5RKd+vgu25UBdvFXGQ=
github.com/iden3/go-iden3-core v0.0.8/go.mod h1:URNjIhMql6sEbWubIGrjJdw5wHCE1Pk1XghxjBOtA3s= github.com/iden3/go-iden3-core v0.0.8/go.mod h1:URNjIhMql6sEbWubIGrjJdw5wHCE1Pk1XghxjBOtA3s=
github.com/iden3/go-iden3-crypto v0.0.5 h1:inCSm5a+ry+nbpVTL/9+m6UcIwSv6nhUm0tnIxEbcps=
github.com/iden3/go-iden3-crypto v0.0.5/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8= github.com/iden3/go-iden3-crypto v0.0.5/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf h1:/7L5dEqctuzJY2g8OEQct+1Y+n2sMKyd4JoYhw2jy1s= github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf h1:/7L5dEqctuzJY2g8OEQct+1Y+n2sMKyd4JoYhw2jy1s=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8= github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200818162919-3364756c2ca6 h1:8B0nnJejnuZZYmbg1MkyrbeQWGkhHTXOyTEU+htfpR8=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200818162919-3364756c2ca6/go.mod h1:oBgthFLboAWi9feaBUFy7OxEcyn9vA1khHSL/WwWFyg=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200819064831-09d161e9f670 h1:gNBFu/WnRfNn+xywE04fgCWSHlb6wr0nIIll9i4R2fc=
github.com/iden3/go-iden3-crypto v0.0.6-0.20200819064831-09d161e9f670/go.mod h1:oBgthFLboAWi9feaBUFy7OxEcyn9vA1khHSL/WwWFyg=
github.com/iden3/go-wasm3 v0.0.1/go.mod h1:j+TcAB94Dfrjlu5kJt83h2OqAU+oyNUTwNZnQyII1sI= github.com/iden3/go-wasm3 v0.0.1/go.mod h1:j+TcAB94Dfrjlu5kJt83h2OqAU+oyNUTwNZnQyII1sI=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY=
@ -232,7 +234,6 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@ -289,6 +290,7 @@ golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 h1:DYfZAGf2WMFjMxbgTjaC+2HC7NkNAQs+6Q8b9WEB/F4= golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 h1:DYfZAGf2WMFjMxbgTjaC+2HC7NkNAQs+6Q8b9WEB/F4=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -321,7 +323,6 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD
gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 172
- 31
merkletree.go

@ -71,6 +71,9 @@ func (h Hash) Hex() string {
// BigInt returns the *big.Int representation of the *Hash // BigInt returns the *big.Int representation of the *Hash
func (h *Hash) BigInt() *big.Int { func (h *Hash) BigInt() *big.Int {
if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
return big.NewInt(0)
}
return new(big.Int).SetBytes(common.SwapEndianness(h[:])) return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
} }
@ -139,6 +142,11 @@ func (mt *MerkleTree) Root() *Hash {
return mt.rootKey return mt.rootKey
} }
// MaxLevels returns the MT maximum level
func (mt *MerkleTree) MaxLevels() int {
return mt.maxLevels
}
// Snapshot returns a read-only copy of the MerkleTree // Snapshot returns a read-only copy of the MerkleTree
func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) { func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
mt.RLock() mt.RLock()
@ -191,6 +199,38 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return nil return nil
} }
// AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
var cp CircomProcessorProof
cp.Fnc = 2
cp.OldRoot = mt.rootKey
gettedK, gettedV, siblings, err := mt.Get(k)
if err != nil && err != ErrKeyNotFound {
return nil, err
}
cp.OldKey = NewHashFromBigInt(gettedK)
cp.OldValue = NewHashFromBigInt(gettedV)
if bytes.Equal(cp.OldKey[:], HashZero[:]) {
cp.IsOld0 = true
}
_, _, siblings, err = mt.Get(k)
if err != nil && err != ErrKeyNotFound {
return nil, err
}
cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
err = mt.Add(k, v)
if err != nil {
return nil, err
}
cp.NewKey = NewHashFromBigInt(k)
cp.NewValue = NewHashFromBigInt(v)
cp.NewRoot = mt.rootKey
return &cp, nil
}
// pushLeaf recursively pushes an existing oldLeaf down until its path diverges // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
// from newLeaf, at which point both leafs are stored, all while updating the // from newLeaf, at which point both leafs are stored, all while updating the
// path. // path.
@ -262,7 +302,8 @@ func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
// We need to push newLeaf down until its path diverges from n's path // We need to push newLeaf down until its path diverges from n's path
return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf) return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
case NodeTypeMiddle: case NodeTypeMiddle:
// We need to go deeper, continue traversing the tree, left or right depending on path
// We need to go deeper, continue traversing the tree, left or
// right depending on path
var newNodeMiddle *Node var newNodeMiddle *Node
if path[lvl] { if path[lvl] {
nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
@ -305,62 +346,66 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
} }
// Get returns the value of the leaf for the given key // Get returns the value of the leaf for the given key
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, error) {
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
// verfy that k is valid and fit inside the Finite Field. // verfy that k is valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) { if !cryptoUtils.CheckBigIntInField(k) {
return nil, errors.New("Key not inside the Finite Field")
return nil, nil, nil, errors.New("Key not inside the Finite Field")
} }
kHash := NewHashFromBigInt(k) kHash := NewHashFromBigInt(k)
path := getPath(mt.maxLevels, kHash[:]) path := getPath(mt.maxLevels, kHash[:])
nextKey := mt.rootKey nextKey := mt.rootKey
var siblings []*Hash
for i := 0; i < mt.maxLevels; i++ { for i := 0; i < mt.maxLevels; i++ {
n, err := mt.GetNode(nextKey) n, err := mt.GetNode(nextKey)
if err != nil { if err != nil {
return nil, err
return nil, nil, nil, err
} }
switch n.Type { switch n.Type {
case NodeTypeEmpty: case NodeTypeEmpty:
return nil, ErrKeyNotFound
return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
case NodeTypeLeaf: case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) { if bytes.Equal(kHash[:], n.Entry[0][:]) {
return n.Entry[1].BigInt(), nil
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
} else { } else {
return nil, ErrKeyNotFound
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
} }
case NodeTypeMiddle: case NodeTypeMiddle:
if path[i] { if path[i] {
nextKey = n.ChildR nextKey = n.ChildR
siblings = append(siblings, n.ChildL)
} else { } else {
nextKey = n.ChildL nextKey = n.ChildL
siblings = append(siblings, n.ChildR)
} }
default: default:
return nil, ErrInvalidNodeFound
return nil, nil, nil, ErrInvalidNodeFound
} }
} }
return nil, ErrKeyNotFound
return nil, nil, nil, ErrKeyNotFound
} }
// Update updates the value of a specified key in the MerkleTree, and updates // Update updates the value of a specified key in the MerkleTree, and updates
// the path from the leaf to the Root with the new values.
func (mt *MerkleTree) Update(k, v *big.Int) error {
// the path from the leaf to the Root with the new values. Returns the
// CircomProcessorProof.
func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
// verify that the MerkleTree is writable // verify that the MerkleTree is writable
if !mt.writable { if !mt.writable {
return ErrNotWritable
return nil, ErrNotWritable
} }
// verfy that k & are valid and fit inside the Finite Field. // verfy that k & are valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) { if !cryptoUtils.CheckBigIntInField(k) {
return errors.New("Key not inside the Finite Field")
return nil, errors.New("Key not inside the Finite Field")
} }
if !cryptoUtils.CheckBigIntInField(v) { if !cryptoUtils.CheckBigIntInField(v) {
return errors.New("Key not inside the Finite Field")
return nil, errors.New("Key not inside the Finite Field")
} }
tx, err := mt.db.NewTx() tx, err := mt.db.NewTx()
if err != nil { if err != nil {
return err
return nil, err
} }
mt.Lock() mt.Lock()
defer mt.Unlock() defer mt.Unlock()
@ -369,33 +414,46 @@ func (mt *MerkleTree) Update(k, v *big.Int) error {
vHash := NewHashFromBigInt(v) vHash := NewHashFromBigInt(v)
path := getPath(mt.maxLevels, kHash[:]) path := getPath(mt.maxLevels, kHash[:])
var cp CircomProcessorProof
cp.Fnc = 1
cp.OldRoot = mt.rootKey
cp.OldKey = kHash
cp.NewKey = kHash
cp.NewValue = vHash
nextKey := mt.rootKey nextKey := mt.rootKey
var siblings []*Hash var siblings []*Hash
for i := 0; i < mt.maxLevels; i++ { for i := 0; i < mt.maxLevels; i++ {
n, err := mt.GetNode(nextKey) n, err := mt.GetNode(nextKey)
if err != nil { if err != nil {
return err
return nil, err
} }
switch n.Type { switch n.Type {
case NodeTypeEmpty: case NodeTypeEmpty:
return ErrKeyNotFound
return nil, ErrKeyNotFound
case NodeTypeLeaf: case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) { if bytes.Equal(kHash[:], n.Entry[0][:]) {
cp.OldValue = n.Entry[1]
cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
// update leaf and upload to the root // update leaf and upload to the root
newNodeLeaf := NewNodeLeaf(kHash, vHash) newNodeLeaf := NewNodeLeaf(kHash, vHash)
_, err := mt.addNode(tx, newNodeLeaf) _, err := mt.addNode(tx, newNodeLeaf)
if err != nil { if err != nil {
return err
return nil, err
} }
newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings) newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
if err != nil { if err != nil {
return err
return nil, err
} }
mt.rootKey = newRootKey mt.rootKey = newRootKey
mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
return tx.Commit()
cp.NewRoot = newRootKey
if err := tx.Commit(); err != nil {
return nil, err
}
return &cp, nil
} else { } else {
return ErrKeyNotFound
return nil, ErrKeyNotFound
} }
case NodeTypeMiddle: case NodeTypeMiddle:
if path[i] { if path[i] {
@ -406,11 +464,11 @@ func (mt *MerkleTree) Update(k, v *big.Int) error {
siblings = append(siblings, n.ChildR) siblings = append(siblings, n.ChildR)
} }
default: default:
return ErrInvalidNodeFound
return nil, ErrInvalidNodeFound
} }
} }
return ErrKeyNotFound
return nil, ErrKeyNotFound
} }
// Delete removes the specified Key from the MerkleTree and updates the path // Delete removes the specified Key from the MerkleTree and updates the path
@ -700,7 +758,18 @@ func (p *Proof) AllSiblings() []*Hash {
return SiblingsFromProof(p) return SiblingsFromProof(p)
} }
// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
// CircomSiblingsFromSiblings returns the full siblings compatible with circom
func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
// Add the rest of empty levels to the siblings
for i := len(siblings); i < levels; i++ {
siblings = append(siblings, &HashZero)
}
siblings = append(siblings, &HashZero) // add extra level for circom compatibility
return siblings
}
// AllSiblingsCircom returns all the siblings of the proof. This function is
// used to generate the siblings input for the circom circuits.
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int { func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
siblings := p.AllSiblings() siblings := p.AllSiblings()
// Add the rest of empty levels to the siblings // Add the rest of empty levels to the siblings
@ -715,10 +784,82 @@ func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
return siblingsBigInt return siblingsBigInt
} }
// CircomProcessorProof defines the ProcessorProof compatible with circom. Is
// the data of the proof between the transition from one state to another.
type CircomProcessorProof struct {
OldRoot *Hash
NewRoot *Hash
Siblings []*Hash
OldKey *Hash
OldValue *Hash
NewKey *Hash
NewValue *Hash
IsOld0 bool
Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete
}
// String returns a human readable string representation of the
// CircomProcessorProof
func (p CircomProcessorProof) String() string {
buf := bytes.NewBufferString("{")
fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
fmt.Fprintf(buf, " Siblings: [\n ")
for _, s := range p.Siblings {
fmt.Fprintf(buf, "%v, ", s)
}
fmt.Fprintf(buf, "\n ],\n")
fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
fmt.Fprintf(buf, "}\n")
return buf.String()
}
// CircomVerifierProof defines the VerifierProof compatible with circom. Is the
// data of the proof that a certain leaf exists in the MerkleTree.
type CircomVerifierProof struct {
Root *Hash
Siblings []*big.Int
OldKey *Hash
OldValue *Hash
IsOld0 bool
Key *Hash
Value *Hash
Fnc int // 0: inclusion, 1: non inclusion
}
// GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
// key in the MerkleTree. If the rootKey is nil, the current merkletree root
// is used.
func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int, rootKey *Hash) (*CircomVerifierProof, error) {
p, v, err := mt.GenerateProof(k, rootKey)
if err != nil || err != ErrKeyNotFound {
return nil, err
}
var cp CircomVerifierProof
cp.Root = mt.rootKey
cp.Siblings = p.AllSiblingsCircom(mt.maxLevels)
cp.OldKey = &HashZero
cp.OldValue = &HashZero
cp.Key = NewHashFromBigInt(k)
cp.Value = NewHashFromBigInt(v)
if p.Existence {
cp.Fnc = 0 // inclusion
} else {
cp.Fnc = 1 // non inclusion
}
return &cp, nil
}
// GenerateProof generates the proof of existence (or non-existence) of an // GenerateProof generates the proof of existence (or non-existence) of an
// Entry's hash Index for a Merkle Tree given the root. // Entry's hash Index for a Merkle Tree given the root.
// If the rootKey is nil, the current merkletree root is used // If the rootKey is nil, the current merkletree root is used
func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, *big.Int, error) {
p := &Proof{} p := &Proof{}
var siblingKey *Hash var siblingKey *Hash
@ -731,19 +872,19 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ { for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
n, err := mt.GetNode(nextKey) n, err := mt.GetNode(nextKey)
if err != nil { if err != nil {
return nil, err
return nil, nil, err
} }
switch n.Type { switch n.Type {
case NodeTypeEmpty: case NodeTypeEmpty:
return p, nil
return p, big.NewInt(0), nil
case NodeTypeLeaf: case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) { if bytes.Equal(kHash[:], n.Entry[0][:]) {
p.Existence = true p.Existence = true
return p, nil
return p, n.Entry[1].BigInt(), nil
} else { } else {
// We found a leaf whose entry didn't match hIndex // We found a leaf whose entry didn't match hIndex
p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]} p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
return p, nil
return p, n.Entry[1].BigInt(), nil
} }
case NodeTypeMiddle: case NodeTypeMiddle:
if path[p.depth] { if path[p.depth] {
@ -754,14 +895,14 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
siblingKey = n.ChildR siblingKey = n.ChildR
} }
default: default:
return nil, ErrInvalidNodeFound
return nil, nil, ErrInvalidNodeFound
} }
if !bytes.Equal(siblingKey[:], HashZero[:]) { if !bytes.Equal(siblingKey[:], HashZero[:]) {
common.SetBitBigEndian(p.notempties[:], uint(p.depth)) common.SetBitBigEndian(p.notempties[:], uint(p.depth))
p.Siblings = append(p.Siblings, siblingKey) p.Siblings = append(p.Siblings, siblingKey)
} }
} }
return nil, ErrKeyNotFound
return nil, nil, ErrKeyNotFound
} }
// VerifyProof verifies the Merkle Proof for the entry and root. // VerifyProof verifies the Merkle Proof for the entry and root.

+ 131
- 57
merkletree_test.go

@ -56,18 +56,19 @@ func TestNewTree(t *testing.T) {
// test vectors generated using https://github.com/iden3/circomlib smt.js // test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2)) err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(33), big.NewInt(44)) err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(1234), big.NewInt(9876)) err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
proof, err := mt.GenerateProof(big.NewInt(33), nil)
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
@ -95,7 +96,7 @@ func TestAddDifferentOrder(t *testing.T) {
} }
assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex()) assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
assert.Equal(t, "0967b777d660e54aa3a0f0f3405bb962504d3d69d6b930146cd212dff9913bee", mt1.Root().Hex())
assert.Equal(t, "0630b27c6f8c7d36d144369ab1ac408552b544ebe96ad642bad6a94a96258e26", mt1.Root().Hex())
} }
func TestAddRepeatedIndex(t *testing.T) { func TestAddRepeatedIndex(t *testing.T) {
@ -122,18 +123,21 @@ func TestGet(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
v, err := mt.Get(big.NewInt(10))
k, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(10), k)
assert.Equal(t, big.NewInt(20), v) assert.Equal(t, big.NewInt(20), v)
v, err = mt.Get(big.NewInt(15))
k, v, _, err = mt.Get(big.NewInt(15))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(15), k)
assert.Equal(t, big.NewInt(30), v) assert.Equal(t, big.NewInt(30), v)
v, err = mt.Get(big.NewInt(16))
k, v, _, err = mt.Get(big.NewInt(16))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, ErrKeyNotFound, err)
assert.Nil(t, v)
assert.Equal(t, "0", k.String())
assert.Equal(t, "0", v.String())
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
@ -147,17 +151,17 @@ func TestUpdate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
v, err := mt.Get(big.NewInt(10))
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v) assert.Equal(t, big.NewInt(20), v)
err = mt.Update(big.NewInt(10), big.NewInt(1024))
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err) assert.Nil(t, err)
v, err = mt.Get(big.NewInt(10))
_, v, _, err = mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(1024), v) assert.Equal(t, big.NewInt(1024), v)
err = mt.Update(big.NewInt(1000), big.NewInt(1024))
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, ErrKeyNotFound, err)
} }
@ -181,11 +185,11 @@ func TestUpdate2(t *testing.T) {
err = mt2.Add(big.NewInt(9876), big.NewInt(10)) err = mt2.Add(big.NewInt(9876), big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
err = mt1.Update(big.NewInt(1), big.NewInt(11))
_, err = mt1.Update(big.NewInt(1), big.NewInt(11))
assert.Nil(t, err) assert.Nil(t, err)
err = mt1.Update(big.NewInt(2), big.NewInt(22))
_, err = mt1.Update(big.NewInt(2), big.NewInt(22))
assert.Nil(t, err) assert.Nil(t, err)
err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
_, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, mt1.Root(), mt2.Root()) assert.Equal(t, mt1.Root(), mt2.Root())
@ -203,8 +207,9 @@ func TestGenerateAndVerifyProof128(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
proof, err := mt.GenerateProof(big.NewInt(42), nil)
proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "0", v.String())
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0))) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
} }
@ -237,19 +242,19 @@ func TestSiblingsFromProof(t *testing.T) {
} }
} }
proof, err := mt.GenerateProof(big.NewInt(4), nil)
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
siblings := SiblingsFromProof(proof) siblings := SiblingsFromProof(proof)
assert.Equal(t, 6, len(siblings)) assert.Equal(t, 6, len(siblings))
assert.Equal(t, "23db1f6fb07af47d7715f18960548c215fc7a2e6d25cb4a7eb82c9d3cf69bc26", siblings[0].Hex())
assert.Equal(t, "2156e64dedcb76719ec732414dd6a8aa4348dafb24c19351a68fbc4158bb7fba", siblings[1].Hex())
assert.Equal(t, "04a8e9b34d5a8b55268ca96b0b8c7c5aaef4f606ec3437ec67e4152d9b323913", siblings[2].Hex())
assert.Equal(t, "0ff484133e0d25deb4a7c0cb46d90432e00fcc280948c2fab6fed9476f1e26b2", siblings[3].Hex())
assert.Equal(t, "015dff482e87eb2046b8f5323049afd05f8dd8554e2c9aa1ef28991cf205c9b6", siblings[4].Hex())
assert.Equal(t, "1e4da486ad68b07acec1406bed5a60732de5ff72d63910f7afbb491f953a8769", siblings[5].Hex())
assert.Equal(t, "2f59aeef9e5b881609aa56940dba76b5cb1440a794f4eb03ad5e5958dd8b475b", siblings[0].Hex())
assert.Equal(t, "2eb29ffbded0987f36a62aecddf748d2b9bf28326300bfa15e474e0a12abe8c1", siblings[1].Hex())
assert.Equal(t, "0c6ee1298933d073a390cc3d267a8a4d5a7df65a126d3fdc5a16b9c28afddaf4", siblings[2].Hex())
assert.Equal(t, "1575898b0b4e7802a6be130e7b76ede64fe42079b6852eba6af985bd46a34aa9", siblings[3].Hex())
assert.Equal(t, "1d15b701c1fd521841120980c5cbfa86f15b1f22bf1d3079ed0d0314751d7954", siblings[4].Hex())
assert.Equal(t, "1ee00f37756159cfefaa0bce02779460b449a049165f3bb9fef81105bc285d43", siblings[5].Hex())
} }
func TestVerifyProofCases(t *testing.T) { func TestVerifyProofCases(t *testing.T) {
@ -264,40 +269,40 @@ func TestVerifyProofCases(t *testing.T) {
// Existence proof // Existence proof
proof, err := mt.GenerateProof(big.NewInt(4), nil)
proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, proof.Existence, true) assert.Equal(t, proof.Existence, true)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0))) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
assert.Equal(t, "000300000000000000000000000000000000000000000000000000000000000728ea2b267d2a9436657f20b5827285175e030f58c07375535106903b16621630b9104d995843c7cffa685009a1b28dcd371022a3b27b3a4d6987f7c8b39b0f2fffc165330710754ca0fc24451bdd5d5f82a05f42f1427fbdf17879c0b84be60f", hex.EncodeToString(proof.Bytes()))
assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e", hex.EncodeToString(proof.Bytes()))
for i := 8; i < 32; i++ { for i := 8; i < 32; i++ {
proof, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
assert.Nil(t, err) assert.Nil(t, err)
if debug { if debug {
fmt.Println(i, proof) fmt.Println(i, proof)
} }
} }
// Non-existence proof, empty aux // Non-existence proof, empty aux
proof, err = mt.GenerateProof(big.NewInt(12), nil)
proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, proof.Existence, false) assert.Equal(t, proof.Existence, false)
// assert.True(t, proof.nodeAux == nil) // assert.True(t, proof.nodeAux == nil)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0))) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
assert.Equal(t, "030300000000000000000000000000000000000000000000000000000000000728ea2b267d2a9436657f20b5827285175e030f58c07375535106903b16621630b9104d995843c7cffa685009a1b28dcd371022a3b27b3a4d6987f7c8b39b0f2fffc165330710754ca0fc24451bdd5d5f82a05f42f1427fbdf17879c0b84be60f04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
// Non-existence proof, diff. node aux // Non-existence proof, diff. node aux
proof, err = mt.GenerateProof(big.NewInt(10), nil)
proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, proof.Existence, false) assert.Equal(t, proof.Existence, false)
assert.True(t, proof.NodeAux != nil) assert.True(t, proof.NodeAux != nil)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0))) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
assert.Equal(t, "030300000000000000000000000000000000000000000000000000000000000728ea2b267d2a9436657f20b5827285175e030f58c07375535106903b1662163097fcf8f911b271df196e0a75667b8a4f3024ef39f87201ed2b7cda349ba202296b7aeba35dc19ab0d4f65e175536c9952a90b6de18c3205611c3cd4fb408f01602000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730e667e2ca15909c4a23beff18e3cc74348fbd3c1a4c765a5bbbca126c9607a42b77e008a73926f1280f8531b139dc1cacf8d83fcec31d405f5c51b7cbddfe152902000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
} }
func TestVerifyProofFalse(t *testing.T) { func TestVerifyProofFalse(t *testing.T) {
@ -312,7 +317,7 @@ func TestVerifyProofFalse(t *testing.T) {
// Invalid existence proof (node used for verification doesn't // Invalid existence proof (node used for verification doesn't
// correspond to node in the proof) // correspond to node in the proof)
proof, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -320,7 +325,7 @@ func TestVerifyProofFalse(t *testing.T) {
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5)))) assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
// Invalid non-existence proof (Non-existence proof, diff. node aux) // Invalid non-existence proof (Non-existence proof, diff. node aux)
proof, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -347,23 +352,23 @@ func TestGraphViz(t *testing.T) {
expected := `digraph hierarchy { expected := `digraph hierarchy {
node [fontname=Monospace,fontsize=10,shape=box] node [fontname=Monospace,fontsize=10,shape=box]
"60195538..." -> {"19759736..." "18893277..."}
"19759736..." -> {"16152312..." "43945008..."}
"16152312..." -> {"empty0" "13952255..."}
"16053348..." -> {"19137630..." "14119616..."}
"19137630..." -> {"19543983..." "19746229..."}
"19543983..." -> {"empty0" "65773153..."}
"empty0" [style=dashed,label=0]; "empty0" [style=dashed,label=0];
"13952255..." -> {"61769925..." "empty1"}
"65773153..." -> {"73498412..." "empty1"}
"empty1" [style=dashed,label=0]; "empty1" [style=dashed,label=0];
"61769925..." -> {"92723289..." "empty2"}
"73498412..." -> {"53169236..." "empty2"}
"empty2" [style=dashed,label=0]; "empty2" [style=dashed,label=0];
"92723289..." -> {"21082735..." "82784818..."}
"21082735..." [style=filled];
"82784818..." [style=filled];
"43945008..." [style=filled];
"18893277..." -> {"19855703..." "17718670..."}
"19855703..." -> {"11499909..." "15828714..."}
"11499909..." [style=filled];
"15828714..." [style=filled];
"17718670..." [style=filled];
"53169236..." -> {"73522717..." "34811870..."}
"73522717..." [style=filled];
"34811870..." [style=filled];
"19746229..." [style=filled];
"14119616..." -> {"19419204..." "15569531..."}
"19419204..." -> {"78154875..." "34589916..."}
"78154875..." [style=filled];
"34589916..." [style=filled];
"15569531..." [style=filled];
} }
` `
w := bytes.NewBufferString("") w := bytes.NewBufferString("")
@ -379,22 +384,22 @@ func TestDelete(t *testing.T) {
// test vectors generated using https://github.com/iden3/circomlib smt.js // test vectors generated using https://github.com/iden3/circomlib smt.js
err = mt.Add(big.NewInt(1), big.NewInt(2)) err = mt.Add(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(33), big.NewInt(44)) err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
err = mt.Add(big.NewInt(1234), big.NewInt(9876)) err = mt.Add(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
// mt.PrintGraphViz(nil) // mt.PrintGraphViz(nil)
err = mt.Delete(big.NewInt(33)) err = mt.Delete(big.NewInt(33))
// mt.PrintGraphViz(nil) // mt.PrintGraphViz(nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "12820263606494630162816839760750120928463716794691735985748071431547370997091", mt.Root().BigInt().String())
assert.Equal(t, "16195585003843604118922861401064871511855368913846540536604351220077317790615", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1234)) err = mt.Delete(big.NewInt(1234))
assert.Nil(t, err) assert.Nil(t, err)
@ -448,10 +453,10 @@ func TestDelete3(t *testing.T) {
err = mt.Add(big.NewInt(2), big.NewInt(2)) err = mt.Add(big.NewInt(2), big.NewInt(2))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "2427629547967522489273866134471574861207714751136138191708011221765688788661", mt.Root().BigInt().String())
assert.Equal(t, "6701939280963330813043570145125351311131831356446202146710280245621673558344", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1)) err = mt.Delete(big.NewInt(1))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10822920717809411688334493481050035035708810950159417482558569847174767667301", mt.Root().BigInt().String())
assert.Equal(t, "10304354743004778619823249005484018655542356856535590307973732141291410579841", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140) mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close() defer mt2.db.Close()
@ -473,10 +478,10 @@ func TestDelete4(t *testing.T) {
err = mt.Add(big.NewInt(3), big.NewInt(3)) err = mt.Add(big.NewInt(3), big.NewInt(3))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "16614298246517994771186095530428786749320098419259206061045083278756632941513", mt.Root().BigInt().String())
assert.Equal(t, "6989694633650442615746486460134957295274675622748484439660143938730686550248", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1)) err = mt.Delete(big.NewInt(1))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "6117330520107511783353383870014397665359816230889739699667943862706617498952", mt.Root().BigInt().String())
assert.Equal(t, "1192610901536912535888866440319084773171371421781091005185759505381507049136", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140) mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close() defer mt2.db.Close()
@ -495,11 +500,11 @@ func TestDelete5(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
err = mt.Add(big.NewInt(33), big.NewInt(44)) err = mt.Add(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
err = mt.Delete(big.NewInt(1)) err = mt.Delete(big.NewInt(1))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "12075524681474630909546786277734445038384732558409197537058769521806571391765", mt.Root().BigInt().String())
assert.Equal(t, "12802904154263054831102426711825443668153853847661287611768065280921698471037", mt.Root().BigInt().String())
mt2 := newTestingMerkle(t, 140) mt2 := newTestingMerkle(t, 140)
defer mt2.db.Close() defer mt2.db.Close()
@ -554,3 +559,72 @@ func TestDumpLeafsImportLeafs(t *testing.T) {
assert.Equal(t, mt.Root(), mt2.Root()) assert.Equal(t, mt.Root(), mt2.Root())
} }
func TestAddAndGetCircomProof(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "0", cpp.OldRoot.String())
assert.Equal(t, "64497120...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "1", cpp.NewKey.String())
assert.Equal(t, "2", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "64497120...", cpp.OldRoot.String())
assert.Equal(t, "11404118...", cpp.NewRoot.String())
assert.Equal(t, "1", cpp.OldKey.String())
assert.Equal(t, "2", cpp.OldValue.String())
assert.Equal(t, "33", cpp.NewKey.String())
assert.Equal(t, "44", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
assert.Nil(t, err)
assert.Equal(t, "11404118...", cpp.OldRoot.String())
assert.Equal(t, "18284203...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "55", cpp.NewKey.String())
assert.Equal(t, "66", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 42948778... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
// fmt.Println(cpp)
}
func TestUpdateCircomProcessorProof(t *testing.T) {
mt := newTestingMerkle(t, 10)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
assert.Equal(t, "14895645...", cpp.OldRoot.String())
assert.Equal(t, "75223641...", cpp.NewRoot.String())
assert.Equal(t, "10", cpp.OldKey.String())
assert.Equal(t, "20", cpp.OldValue.String())
assert.Equal(t, "10", cpp.NewKey.String())
assert.Equal(t, "1024", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[19625419... 46910949... 18399594... 20473908... 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
}

+ 9
- 33
utils.go

@ -1,58 +1,34 @@
package merkletree package merkletree
import ( import (
"fmt"
"math/big" "math/big"
"github.com/iden3/go-iden3-crypto/poseidon" "github.com/iden3/go-iden3-crypto/poseidon"
) )
// HashElems performs a poseidon hash over the array of ElemBytes.
// Uses poseidon.Hash to be compatible with the circom circuits
// implementations.
// The maxim slice input size is poseidon.T
// HashElems performs a poseidon hash over the array of ElemBytes, currently we
// are using 2 elements. Uses poseidon.Hash to be compatible with the circom
// circuits implementations.
func HashElems(elems ...*big.Int) (*Hash, error) { func HashElems(elems ...*big.Int) (*Hash, error) {
if len(elems) > poseidon.T {
return nil, fmt.Errorf("HashElems input can not be bigger than %v", poseidon.T)
}
bi, err := BigIntsToPoseidonInput(elems...)
if err != nil {
return nil, err
}
poseidonHash, err := poseidon.Hash(bi)
poseidonHash, err := poseidon.Hash(elems)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewHashFromBigInt(poseidonHash), nil return NewHashFromBigInt(poseidonHash), nil
} }
// HashElemsKey performs a poseidon hash over the array of ElemBytes.
// HashElemsKey performs a poseidon hash over the array of ElemBytes, currently
// we are using 2 elements.
func HashElemsKey(key *big.Int, elems ...*big.Int) (*Hash, error) { func HashElemsKey(key *big.Int, elems ...*big.Int) (*Hash, error) {
if len(elems) > poseidon.T-1 {
return nil, fmt.Errorf("HashElemsKey input can not be bigger than %v", poseidon.T-1)
}
if key == nil { if key == nil {
key = new(big.Int).SetInt64(0) key = new(big.Int).SetInt64(0)
} }
bi, err := BigIntsToPoseidonInput(elems...)
if err != nil {
return nil, err
}
copy(bi[len(elems):], []*big.Int{key})
bi := make([]*big.Int, 3)
copy(bi[:], elems)
bi[2] = key
poseidonHash, err := poseidon.Hash(bi) poseidonHash, err := poseidon.Hash(bi)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewHashFromBigInt(poseidonHash), nil return NewHashFromBigInt(poseidonHash), nil
} }
// BigIntsToPoseidonInput takes *big.Ints and returns a fixed-length array of the size `poseidon.T`
func BigIntsToPoseidonInput(bigints ...*big.Int) ([poseidon.T]*big.Int, error) {
z := big.NewInt(0)
b := [poseidon.T]*big.Int{z, z, z, z, z, z}
copy(b[:poseidon.T], bigints[:])
return b, nil
}

Loading…
Cancel
Save