From 1f1bd54b93ae90eab266ceea14a368de6ba79bb0 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sun, 9 Aug 2020 09:32:20 +0200 Subject: [PATCH] mt.Update CircomProof generator w/ test --- merkletree.go | 3 ++- merkletree_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/merkletree.go b/merkletree.go index b619e17..dfd927b 100644 --- a/merkletree.go +++ b/merkletree.go @@ -198,6 +198,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return nil } + func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { var cp CircomProcessorProof cp.OldRoot = mt.rootKey @@ -428,7 +429,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) { case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { cp.OldValue = n.Entry[1] - cp.Siblings = siblings + cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels) // update leaf and upload to the root newNodeLeaf := NewNodeLeaf(kHash, vHash) _, err := mt.addNode(tx, newNodeLeaf) diff --git a/merkletree_test.go b/merkletree_test.go index 1411dbd..ecabfce 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -600,3 +600,31 @@ func TestAddAndGetCircomProof(t *testing.T) { assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) // fmt.Println(cpp) } + +func TestUpdateCircomProcessorProof(t *testing.T) { + mt := newTestingMerkle(t, 10) + defer mt.db.Close() + + for i := 0; i < 16; i++ { + k := big.NewInt(int64(i)) + v := big.NewInt(int64(i * 2)) + if err := mt.Add(k, v); err != nil { + t.Fatal(err) + } + } + _, v, _, err := mt.Get(big.NewInt(10)) + assert.Nil(t, err) + assert.Equal(t, big.NewInt(20), v) + + // test vectors generated using https://github.com/iden3/circomlib smt.js + cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024)) + assert.Nil(t, err) + assert.Equal(t, "57072083...", cpp.OldRoot.String()) + assert.Equal(t, "11191558...", cpp.NewRoot.String()) + assert.Equal(t, "10", cpp.OldKey.String()) + assert.Equal(t, "20", cpp.OldValue.String()) + assert.Equal(t, "10", cpp.NewKey.String()) + assert.Equal(t, "1024", cpp.NewValue.String()) + assert.Equal(t, false, cpp.IsOld0) + assert.Equal(t, "[12331503... 70994311... 88639181... 20174344... 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) +}