mt.Update CircomProof generator w/ test

This commit is contained in:
arnaucube
2020-08-09 09:32:20 +02:00
parent 7b7b9a12fc
commit 1f1bd54b93
2 changed files with 30 additions and 1 deletions

View File

@@ -198,6 +198,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return nil return nil
} }
func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
var cp CircomProcessorProof var cp CircomProcessorProof
cp.OldRoot = mt.rootKey cp.OldRoot = mt.rootKey
@@ -428,7 +429,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
case NodeTypeLeaf: case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) { if bytes.Equal(kHash[:], n.Entry[0][:]) {
cp.OldValue = n.Entry[1] cp.OldValue = n.Entry[1]
cp.Siblings = siblings cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
// update leaf and upload to the root // update leaf and upload to the root
newNodeLeaf := NewNodeLeaf(kHash, vHash) newNodeLeaf := NewNodeLeaf(kHash, vHash)
_, err := mt.addNode(tx, newNodeLeaf) _, err := mt.addNode(tx, newNodeLeaf)

View File

@@ -600,3 +600,31 @@ func TestAddAndGetCircomProof(t *testing.T) {
assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
// fmt.Println(cpp) // fmt.Println(cpp)
} }
func TestUpdateCircomProcessorProof(t *testing.T) {
mt := newTestingMerkle(t, 10)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
// test vectors generated using https://github.com/iden3/circomlib smt.js
cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err)
assert.Equal(t, "57072083...", cpp.OldRoot.String())
assert.Equal(t, "11191558...", cpp.NewRoot.String())
assert.Equal(t, "10", cpp.OldKey.String())
assert.Equal(t, "20", cpp.OldValue.String())
assert.Equal(t, "10", cpp.NewKey.String())
assert.Equal(t, "1024", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[12331503... 70994311... 88639181... 20174344... 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
}