diff --git a/merkletree.go b/merkletree.go index d90a537..29aaf67 100644 --- a/merkletree.go +++ b/merkletree.go @@ -459,7 +459,7 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) { path := getPath(mt.maxLevels, kHash[:]) nextKey := mt.rootKey - var siblings []*Hash + siblings := []*Hash{} for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { @@ -524,7 +524,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) { cp.NewValue = vHash nextKey := mt.rootKey - var siblings []*Hash + siblings := []*Hash{} for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { @@ -607,7 +607,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error { path := getPath(mt.maxLevels, kHash[:]) nextKey := mt.rootKey - var siblings []*Hash + siblings := []*Hash{} for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { @@ -842,7 +842,7 @@ func (p *Proof) Bytes() []byte { // SiblingsFromProof returns all the siblings of the proof. func SiblingsFromProof(proof *Proof) []*Hash { sibIdx := 0 - var siblings []*Hash + siblings := []*Hash{} for lvl := 0; lvl < int(proof.depth); lvl++ { if TestBitBigEndian(proof.notempties[:], uint(lvl)) { siblings = append(siblings, proof.Siblings[sibIdx]) diff --git a/merkletree_test.go b/merkletree_test.go index 2fa504f..fc481d0 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -712,6 +712,15 @@ func TestSmtVerifier(t *testing.T) { err = mt.Add(big.NewInt(1), big.NewInt(11)) assert.Nil(t, err) + + cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil) + assert.Nil(t, err) + jCvp, err := json.Marshal(cvp) + assert.Nil(t, err) + // expect siblings to be '[]', instead of 'null' + expected := `{"root":"14137057030252181222327992235694793580963111268072013054745223667806564674729","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll + + assert.Equal(t, expected, string(jCvp)) err = mt.Add(big.NewInt(2), big.NewInt(22)) assert.Nil(t, err) err = mt.Add(big.NewInt(3), big.NewInt(33)) @@ -719,14 +728,14 @@ func TestSmtVerifier(t *testing.T) { err = mt.Add(big.NewInt(4), big.NewInt(44)) assert.Nil(t, err) - cvp, err := mt.GenerateCircomVerifierProof(big.NewInt(2), nil) + cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil) assert.Nil(t, err) - jCvp, err := json.Marshal(cvp) + jCvp, err = json.Marshal(cvp) assert.Nil(t, err) // Test vectors generated using https://github.com/iden3/circomlib smt.js // Expect siblings with the extra 0 that the circom circuits need - expected := `{"root":"10171140035965439966839815283432442651152991056297946102647688349369299124493","siblings":["12422661758472400223401299094238820777063458096110016599986781158438915645129","4330149052063565277182642012557086942088176847773467265587998154672740895682","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll + expected = `{"root":"10171140035965439966839815283432442651152991056297946102647688349369299124493","siblings":["12422661758472400223401299094238820777063458096110016599986781158438915645129","4330149052063565277182642012557086942088176847773467265587998154672740895682","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll assert.Equal(t, expected, string(jCvp)) cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)