diff --git a/merkletree.go b/merkletree.go index 4d34775..b619e17 100644 --- a/merkletree.go +++ b/merkletree.go @@ -71,6 +71,9 @@ func (h Hash) Hex() string { // BigInt returns the *big.Int representation of the *Hash func (h *Hash) BigInt() *big.Int { + if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil { + return big.NewInt(0) + } return new(big.Int).SetBytes(common.SwapEndianness(h[:])) } @@ -196,20 +199,22 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return nil } func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { - var cp CircomProcessorProof cp.OldRoot = mt.rootKey - gettedV, siblings, err := mt.Get(k) + gettedK, gettedV, siblings, err := mt.Get(k) if err != nil && err != ErrKeyNotFound { return nil, err } - if err == ErrKeyNotFound { - cp.OldKey = &HashZero - cp.OldValue = &HashZero - } else { - cp.OldKey = NewHashFromBigInt(k) - cp.OldValue = NewHashFromBigInt(gettedV) + cp.OldKey = NewHashFromBigInt(gettedK) + cp.OldValue = NewHashFromBigInt(gettedV) + if bytes.Equal(cp.OldKey[:], HashZero[:]) { + cp.IsOld0 = true } + _, _, siblings, err = mt.Get(k) + if err != nil && err != ErrKeyNotFound { + return nil, err + } + cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels) err = mt.Add(k, v) if err != nil { @@ -220,12 +225,6 @@ func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof cp.NewValue = NewHashFromBigInt(v) cp.NewRoot = mt.rootKey - _, siblings, err = mt.Get(k) - if err != nil { - return nil, err - } - cp.Siblings = siblings - return &cp, nil } @@ -343,10 +342,10 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { } // Get returns the value of the leaf for the given key -func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) { +func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) { // verfy that k is valid and fit inside the Finite Field. if !cryptoUtils.CheckBigIntInField(k) { - return nil, nil, errors.New("Key not inside the Finite Field") + return nil, nil, nil, errors.New("Key not inside the Finite Field") } kHash := NewHashFromBigInt(k) @@ -357,16 +356,16 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) { for i := 0; i < mt.maxLevels; i++ { n, err := mt.GetNode(nextKey) if err != nil { - return nil, nil, err + return nil, nil, nil, err } switch n.Type { case NodeTypeEmpty: - return nil, nil, ErrKeyNotFound + return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound case NodeTypeLeaf: if bytes.Equal(kHash[:], n.Entry[0][:]) { - return n.Entry[1].BigInt(), siblings, nil + return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil } else { - return nil, nil, ErrKeyNotFound + return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound } case NodeTypeMiddle: if path[i] { @@ -377,11 +376,11 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) { siblings = append(siblings, n.ChildR) } default: - return nil, nil, ErrInvalidNodeFound + return nil, nil, nil, ErrInvalidNodeFound } } - return nil, nil, ErrKeyNotFound + return nil, nil, nil, ErrKeyNotFound } // Update updates the value of a specified key in the MerkleTree, and updates @@ -753,6 +752,19 @@ func (p *Proof) AllSiblings() []*Hash { return SiblingsFromProof(p) } +func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash { + // Add the rest of empty levels to the siblings + for i := len(siblings); i < levels; i++ { + siblings = append(siblings, &HashZero) + } + siblings = append(siblings, &HashZero) // add extra level for circom compatibility + return siblings + // siblingsBigInt := make([]*big.Int, len(siblings)) + // for i, sibling := range siblings { + // siblingsBigInt[i] = sibling.BigInt() + // } +} + // AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits. func (p *Proof) AllSiblingsCircom(levels int) []*big.Int { siblings := p.AllSiblings() @@ -774,11 +786,31 @@ type CircomProcessorProof struct { Siblings []*Hash OldKey *Hash OldValue *Hash - IsOld0 bool NewKey *Hash NewValue *Hash + IsOld0 bool // Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete } + +func (p CircomProcessorProof) String() string { + buf := bytes.NewBufferString("{") + fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot) + fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot) + fmt.Fprintf(buf, " Siblings: [\n ") + for _, s := range p.Siblings { + fmt.Fprintf(buf, "%v, ", s) + } + fmt.Fprintf(buf, "\n ],\n") + fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey) + fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue) + fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey) + fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue) + fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0) + fmt.Fprintf(buf, "}\n") + + return buf.String() +} + type CircomVerifierProof struct { Root *Hash Siblings []*big.Int diff --git a/merkletree_test.go b/merkletree_test.go index d63500b..1411dbd 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -123,18 +123,21 @@ func TestGet(t *testing.T) { t.Fatal(err) } } - v, _, err := mt.Get(big.NewInt(10)) + k, v, _, err := mt.Get(big.NewInt(10)) assert.Nil(t, err) + assert.Equal(t, big.NewInt(10), k) assert.Equal(t, big.NewInt(20), v) - v, _, err = mt.Get(big.NewInt(15)) + k, v, _, err = mt.Get(big.NewInt(15)) assert.Nil(t, err) + assert.Equal(t, big.NewInt(15), k) assert.Equal(t, big.NewInt(30), v) - v, _, err = mt.Get(big.NewInt(16)) + k, v, _, err = mt.Get(big.NewInt(16)) assert.NotNil(t, err) assert.Equal(t, ErrKeyNotFound, err) - assert.Nil(t, v) + assert.Equal(t, "0", k.String()) + assert.Equal(t, "0", v.String()) } func TestUpdate(t *testing.T) { @@ -148,13 +151,13 @@ func TestUpdate(t *testing.T) { t.Fatal(err) } } - v, _, err := mt.Get(big.NewInt(10)) + _, v, _, err := mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(20), v) _, err = mt.Update(big.NewInt(10), big.NewInt(1024)) assert.Nil(t, err) - v, _, err = mt.Get(big.NewInt(10)) + _, v, _, err = mt.Get(big.NewInt(10)) assert.Nil(t, err) assert.Equal(t, big.NewInt(1024), v) @@ -563,22 +566,37 @@ func TestAddAndGetCircomProof(t *testing.T) { assert.Equal(t, "0", mt.Root().String()) // test vectors generated using https://github.com/iden3/circomlib smt.js - _, err = mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2)) - assert.Nil(t, err) - assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String()) - - _, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44)) - assert.Nil(t, err) - assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String()) - - _, err = mt.AddAndGetCircomProof(big.NewInt(1234), big.NewInt(9876)) - assert.Nil(t, err) - assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String()) - - proof, v, err := mt.GenerateProof(big.NewInt(33), nil) - assert.Nil(t, err) - assert.Equal(t, big.NewInt(44), v) - - assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44))) - assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45))) + cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2)) + assert.Nil(t, err) + assert.Equal(t, "0", cpp.OldRoot.String()) + assert.Equal(t, "49322979...", cpp.NewRoot.String()) + assert.Equal(t, "0", cpp.OldKey.String()) + assert.Equal(t, "0", cpp.OldValue.String()) + assert.Equal(t, "1", cpp.NewKey.String()) + assert.Equal(t, "2", cpp.NewValue.String()) + assert.Equal(t, true, cpp.IsOld0) + assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + + cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44)) + assert.Nil(t, err) + assert.Equal(t, "49322979...", cpp.OldRoot.String()) + assert.Equal(t, "13563340...", cpp.NewRoot.String()) + assert.Equal(t, "1", cpp.OldKey.String()) + assert.Equal(t, "2", cpp.OldValue.String()) + assert.Equal(t, "33", cpp.NewKey.String()) + assert.Equal(t, "44", cpp.NewValue.String()) + assert.Equal(t, false, cpp.IsOld0) + assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + + cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66)) + assert.Nil(t, err) + assert.Equal(t, "13563340...", cpp.OldRoot.String()) + assert.Equal(t, "21716426...", cpp.NewRoot.String()) + assert.Equal(t, "0", cpp.OldKey.String()) + assert.Equal(t, "0", cpp.OldValue.String()) + assert.Equal(t, "55", cpp.NewKey.String()) + assert.Equal(t, "66", cpp.NewValue.String()) + assert.Equal(t, true, cpp.IsOld0) + assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings)) + // fmt.Println(cpp) }