diff --git a/go.sum b/go.sum index f36ef43..4e72d12 100644 --- a/go.sum +++ b/go.sum @@ -110,7 +110,6 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2-0.20190517061210-b285ee9cfc6c/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf h1:gFVkHXmVAhEbxZVDln5V9GKrLaluNoFHDbrZwAWZgws= github.com/golang/snappy v0.0.2-0.20190904063534-ff6b7dc882cf/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -132,7 +131,6 @@ github.com/iden3/go-circom-prover-verifier v0.0.1/go.mod h1:1FkpX4nUXxYcY2fpzqd2 github.com/iden3/go-circom-witnesscalc v0.0.1/go.mod h1:xjT1BlFZDBioHOlbD75SmZZLC1d1AfOycqbSa/1QRJU= github.com/iden3/go-iden3-core v0.0.8 h1:PLw7iCiX7Pw1dqBkR+JaLQWqB5RKd+vgu25UBdvFXGQ= github.com/iden3/go-iden3-core v0.0.8/go.mod h1:URNjIhMql6sEbWubIGrjJdw5wHCE1Pk1XghxjBOtA3s= -github.com/iden3/go-iden3-crypto v0.0.5 h1:inCSm5a+ry+nbpVTL/9+m6UcIwSv6nhUm0tnIxEbcps= github.com/iden3/go-iden3-crypto v0.0.5/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8= github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf h1:/7L5dEqctuzJY2g8OEQct+1Y+n2sMKyd4JoYhw2jy1s= github.com/iden3/go-iden3-crypto v0.0.6-0.20200723082457-29a66457f0bf/go.mod h1:XKw1oDwYn2CIxKOtr7m/mL5jMn4mLOxAxtZBRxQBev8= @@ -232,7 +230,6 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -321,7 +318,6 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/merkletree.go b/merkletree.go index 6167ff7..4d34775 100644 --- a/merkletree.go +++ b/merkletree.go @@ -139,6 +139,11 @@ func (mt *MerkleTree) Root() *Hash { return mt.rootKey } +// MaxLevels returns the MT maximum level +func (mt *MerkleTree) MaxLevels() int { + return mt.maxLevels +} + // Snapshot returns a read-only copy of the MerkleTree func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) { mt.RLock() @@ -190,6 +195,39 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return nil } +func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { + + var cp CircomProcessorProof + cp.OldRoot = mt.rootKey + gettedV, siblings, err := mt.Get(k) + if err != nil && err != ErrKeyNotFound { + return nil, err + } + if err == ErrKeyNotFound { + cp.OldKey = &HashZero + cp.OldValue = &HashZero + } else { + cp.OldKey = NewHashFromBigInt(k) + cp.OldValue = NewHashFromBigInt(gettedV) + } + + err = mt.Add(k, v) + if err != nil { + return nil, err + } + + cp.NewKey = NewHashFromBigInt(k) + cp.NewValue = NewHashFromBigInt(v) + cp.NewRoot = mt.rootKey + + _, siblings, err = mt.Get(k) + if err != nil { + return nil, err + } + cp.Siblings = siblings + + return &cp, nil +} // pushLeaf recursively pushes an existing oldLeaf down until its path diverges // from newLeaf, at which point both leafs are stored, all while updating the @@ -305,62 +343,65 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { } // Get returns the value of the leaf for the given key -func (mt *MerkleTree) Get(k *big.Int) (*big.Int, error) { +func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) { // verfy that k is valid and fit inside the Finite Field. if !cryptoUtils.CheckBigIntInField(k) { - return nil, errors.New("Key not inside the Finite Field") + return nil, nil, errors.New("Key not inside the Finite Field") } kHash := NewHashFromBigInt(k) path := getPath(mt.maxLevels, kHash[:]) nextKey := mt.rootKey + var siblings []*Hash for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { - return nil, err + return nil, nil, err } switch n.Type { case NodeTypeEmpty: - return nil, ErrKeyNotFound + return nil, nil, ErrKeyNotFound case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { - return n.Entry[1].BigInt(), nil + return n.Entry[1].BigInt(), siblings, nil } else { - return nil, ErrKeyNotFound + return nil, nil, ErrKeyNotFound } case NodeTypeMiddle: if path[i] { nextKey = n.ChildR + siblings = append(siblings, n.ChildL) } else { nextKey = n.ChildL + siblings = append(siblings, n.ChildR) } default: - return nil, ErrInvalidNodeFound + return nil, nil, ErrInvalidNodeFound } } - return nil, ErrKeyNotFound + return nil, nil, ErrKeyNotFound } // Update updates the value of a specified key in the MerkleTree, and updates // the path from the leaf to the Root with the new values. -func (mt *MerkleTree) Update(k, v *big.Int) error { +func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) { // verify that the MerkleTree is writable if !mt.writable { - return ErrNotWritable + return nil, ErrNotWritable } // verfy that k & are valid and fit inside the Finite Field. if !cryptoUtils.CheckBigIntInField(k) { - return errors.New("Key not inside the Finite Field") + return nil, errors.New("Key not inside the Finite Field") } if !cryptoUtils.CheckBigIntInField(v) { - return errors.New("Key not inside the Finite Field") + return nil, errors.New("Key not inside the Finite Field") } tx, err := mt.db.NewTx() if err != nil { - return err + return nil, err } mt.Lock() defer mt.Unlock() @@ -369,33 +410,45 @@ func (mt *MerkleTree) Update(k, v *big.Int) error { vHash := NewHashFromBigInt(v) path := getPath(mt.maxLevels, kHash[:]) + var cp CircomProcessorProof + cp.OldRoot = mt.rootKey + cp.OldKey = kHash + cp.NewKey = kHash + cp.NewValue = vHash + nextKey := mt.rootKey var siblings []*Hash for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { - return err + return nil, err } switch n.Type { case NodeTypeEmpty: - return ErrKeyNotFound + return nil, ErrKeyNotFound case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { + cp.OldValue = n.Entry[1] + cp.Siblings = siblings // update leaf and upload to the root newNodeLeaf := NewNodeLeaf(kHash, vHash) _, err := mt.addNode(tx, newNodeLeaf) if err != nil { - return err + return nil, err } newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings) if err != nil { - return err + return nil, err } mt.rootKey = newRootKey mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) - return tx.Commit() + cp.NewRoot = newRootKey + if err := tx.Commit(); err != nil { + return nil, err + } + return &cp, nil } else { - return ErrKeyNotFound + return nil, ErrKeyNotFound } case NodeTypeMiddle: if path[i] { @@ -406,11 +459,11 @@ func (mt *MerkleTree) Update(k, v *big.Int) error { siblings = append(siblings, n.ChildR) } default: - return ErrInvalidNodeFound + return nil, ErrInvalidNodeFound } } - return ErrKeyNotFound + return nil, ErrKeyNotFound } // Delete removes the specified Key from the MerkleTree and updates the path @@ -715,10 +768,54 @@ func (p *Proof) AllSiblingsCircom(levels int) []*big.Int { return siblingsBigInt } +type CircomProcessorProof struct { + OldRoot *Hash + NewRoot *Hash + Siblings []*Hash + OldKey *Hash + OldValue *Hash + IsOld0 bool + NewKey *Hash + NewValue *Hash + // Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete +} +type CircomVerifierProof struct { + Root *Hash + Siblings []*big.Int + OldKey *Hash + OldValue *Hash + IsOld0 bool + Key *Hash + Value *Hash + Fnc int +} + +func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int, rootKey *Hash) (*CircomVerifierProof, error) { + p, v, err := mt.GenerateProof(k, rootKey) + if err != nil || err != ErrKeyNotFound { + return nil, err + } + var cp CircomVerifierProof + cp.Root = mt.rootKey + cp.Siblings = p.AllSiblingsCircom(mt.maxLevels) + cp.OldKey = &HashZero + cp.OldValue = &HashZero + // cp.IsOld + cp.Key = NewHashFromBigInt(k) + cp.Value = NewHashFromBigInt(v) + if p.Existence { + cp.Fnc = 0 // inclusion + } else { + cp.Fnc = 1 // not inclusion + } + + return &cp, nil +} + // GenerateProof generates the proof of existence (or non-existence) of an // Entry's hash Index for a Merkle Tree given the root. // If the rootKey is nil, the current merkletree root is used -func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) { +func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, *big.Int, error) { p := &Proof{} var siblingKey *Hash @@ -731,19 +828,19 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) { for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ { n, err := mt.GetNode(nextKey) if err != nil { - return nil, err + return nil, nil, err } switch n.Type { case NodeTypeEmpty: - return p, nil + return p, big.NewInt(0), nil case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { p.Existence = true - return p, nil + return p, n.Entry[1].BigInt(), nil } else { // We found a leaf whose entry didn't match hIndex p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]} - return p, nil + return p, n.Entry[1].BigInt(), nil } case NodeTypeMiddle: if path[p.depth] { @@ -754,14 +851,14 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) { siblingKey = n.ChildR } default: - return nil, ErrInvalidNodeFound + return nil, nil, ErrInvalidNodeFound } if !bytes.Equal(siblingKey[:], HashZero[:]) { common.SetBitBigEndian(p.notempties[:], uint(p.depth)) p.Siblings = append(p.Siblings, siblingKey) } } - return nil, ErrKeyNotFound + return nil, nil, ErrKeyNotFound } // VerifyProof verifies the Merkle Proof for the entry and root. diff --git a/merkletree_test.go b/merkletree_test.go index 7e15f8a..d63500b 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -66,8 +66,9 @@ func TestNewTree(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String()) - proof, err := mt.GenerateProof(big.NewInt(33), nil) + proof, v, err := mt.GenerateProof(big.NewInt(33), nil) assert.Nil(t, err) + assert.Equal(t, big.NewInt(44), v) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) @@ -122,15 +123,15 @@ func TestGet(t *testing.T) { t.Fatal(err) } } - v, err := mt.Get(big.NewInt(10)) + v, _, err := mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(20), v) - v, err = mt.Get(big.NewInt(15)) + v, _, err = mt.Get(big.NewInt(15)) assert.Nil(t, err) assert.Equal(t, big.NewInt(30), v) - v, err = mt.Get(big.NewInt(16)) + v, _, err = mt.Get(big.NewInt(16)) assert.NotNil(t, err) assert.Equal(t, ErrKeyNotFound, err) assert.Nil(t, v) @@ -147,17 +148,17 @@ func TestUpdate(t *testing.T) { t.Fatal(err) } } - v, err := mt.Get(big.NewInt(10)) + v, _, err := mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(20), v) - err = mt.Update(big.NewInt(10), big.NewInt(1024)) + _, err = mt.Update(big.NewInt(10), big.NewInt(1024)) assert.Nil(t, err) - v, err = mt.Get(big.NewInt(10)) + v, _, err = mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(1024), v) - err = mt.Update(big.NewInt(1000), big.NewInt(1024)) + _, err = mt.Update(big.NewInt(1000), big.NewInt(1024)) assert.Equal(t, ErrKeyNotFound, err) } @@ -181,11 +182,11 @@ func TestUpdate2(t *testing.T) { err = mt2.Add(big.NewInt(9876), big.NewInt(10)) assert.Nil(t, err) - err = mt1.Update(big.NewInt(1), big.NewInt(11)) + _, err = mt1.Update(big.NewInt(1), big.NewInt(11)) assert.Nil(t, err) - err = mt1.Update(big.NewInt(2), big.NewInt(22)) + _, err = mt1.Update(big.NewInt(2), big.NewInt(22)) assert.Nil(t, err) - err = mt2.Update(big.NewInt(9876), big.NewInt(6789)) + _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789)) assert.Nil(t, err) assert.Equal(t, mt1.Root(), mt2.Root()) @@ -203,8 +204,9 @@ func TestGenerateAndVerifyProof128(t *testing.T) { t.Fatal(err) } } - proof, err := mt.GenerateProof(big.NewInt(42), nil) + proof, v, err := mt.GenerateProof(big.NewInt(42), nil) assert.Nil(t, err) + assert.Equal(t, "0", v.String()) assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0))) } @@ -237,7 +239,7 @@ func TestSiblingsFromProof(t *testing.T) { } } - proof, err := mt.GenerateProof(big.NewInt(4), nil) + proof, _, err := mt.GenerateProof(big.NewInt(4), nil) if err != nil { t.Fatal(err) } @@ -264,7 +266,7 @@ func TestVerifyProofCases(t *testing.T) { // Existence proof - proof, err := mt.GenerateProof(big.NewInt(4), nil) + proof, _, err := mt.GenerateProof(big.NewInt(4), nil) if err != nil { t.Fatal(err) } @@ -273,14 +275,14 @@ func TestVerifyProofCases(t *testing.T) { assert.Equal(t, "000300000000000000000000000000000000000000000000000000000000000728ea2b267d2a9436657f20b5827285175e030f58c07375535106903b16621630b9104d995843c7cffa685009a1b28dcd371022a3b27b3a4d6987f7c8b39b0f2fffc165330710754ca0fc24451bdd5d5f82a05f42f1427fbdf17879c0b84be60f", hex.EncodeToString(proof.Bytes())) for i := 8; i < 32; i++ { - proof, err = mt.GenerateProof(big.NewInt(int64(i)), nil) + proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil) assert.Nil(t, err) if debug { fmt.Println(i, proof) } } // Non-existence proof, empty aux - proof, err = mt.GenerateProof(big.NewInt(12), nil) + proof, _, err = mt.GenerateProof(big.NewInt(12), nil) if err != nil { t.Fatal(err) } @@ -290,7 +292,7 @@ func TestVerifyProofCases(t *testing.T) { assert.Equal(t, "030300000000000000000000000000000000000000000000000000000000000728ea2b267d2a9436657f20b5827285175e030f58c07375535106903b16621630b9104d995843c7cffa685009a1b28dcd371022a3b27b3a4d6987f7c8b39b0f2fffc165330710754ca0fc24451bdd5d5f82a05f42f1427fbdf17879c0b84be60f04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) // Non-existence proof, diff. node aux - proof, err = mt.GenerateProof(big.NewInt(10), nil) + proof, _, err = mt.GenerateProof(big.NewInt(10), nil) if err != nil { t.Fatal(err) } @@ -312,7 +314,7 @@ func TestVerifyProofFalse(t *testing.T) { // Invalid existence proof (node used for verification doesn't // correspond to node in the proof) - proof, err := mt.GenerateProof(big.NewInt(int64(4)), nil) + proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil) if err != nil { t.Fatal(err) } @@ -320,7 +322,7 @@ func TestVerifyProofFalse(t *testing.T) { assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5)))) // Invalid non-existence proof (Non-existence proof, diff. node aux) - proof, err = mt.GenerateProof(big.NewInt(int64(4)), nil) + proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil) if err != nil { t.Fatal(err) } @@ -554,3 +556,29 @@ func TestDumpLeafsImportLeafs(t *testing.T) { assert.Equal(t, mt.Root(), mt2.Root()) } + +func TestAddAndGetCircomProof(t *testing.T) { + mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10) + assert.Nil(t, err) + assert.Equal(t, "0", mt.Root().String()) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + _, err = mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String()) + + _, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String()) + + _, err = mt.AddAndGetCircomProof(big.NewInt(1234), big.NewInt(9876)) + assert.Nil(t, err) + assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String()) + + proof, v, err := mt.GenerateProof(big.NewInt(33), nil) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(44), v) + + assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) + assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) +}