Browse Source

AddAndGetCircomProof generating CircomProcessorProof

circomproofs
arnaucube 4 years ago
parent
commit
7b7b9a12fc
2 changed files with 97 additions and 47 deletions
  1. +55
    -23
      merkletree.go
  2. +42
    -24
      merkletree_test.go

+ 55
- 23
merkletree.go

@ -71,6 +71,9 @@ func (h Hash) Hex() string {
// BigInt returns the *big.Int representation of the *Hash // BigInt returns the *big.Int representation of the *Hash
func (h *Hash) BigInt() *big.Int { func (h *Hash) BigInt() *big.Int {
if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
return big.NewInt(0)
}
return new(big.Int).SetBytes(common.SwapEndianness(h[:])) return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
} }
@ -196,20 +199,22 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return nil return nil
} }
func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) { func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
var cp CircomProcessorProof var cp CircomProcessorProof
cp.OldRoot = mt.rootKey cp.OldRoot = mt.rootKey
gettedV, siblings, err := mt.Get(k)
gettedK, gettedV, siblings, err := mt.Get(k)
if err != nil && err != ErrKeyNotFound { if err != nil && err != ErrKeyNotFound {
return nil, err return nil, err
} }
if err == ErrKeyNotFound {
cp.OldKey = &HashZero
cp.OldValue = &HashZero
} else {
cp.OldKey = NewHashFromBigInt(k)
cp.OldValue = NewHashFromBigInt(gettedV)
cp.OldKey = NewHashFromBigInt(gettedK)
cp.OldValue = NewHashFromBigInt(gettedV)
if bytes.Equal(cp.OldKey[:], HashZero[:]) {
cp.IsOld0 = true
} }
_, _, siblings, err = mt.Get(k)
if err != nil && err != ErrKeyNotFound {
return nil, err
}
cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
err = mt.Add(k, v) err = mt.Add(k, v)
if err != nil { if err != nil {
@ -220,12 +225,6 @@ func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof
cp.NewValue = NewHashFromBigInt(v) cp.NewValue = NewHashFromBigInt(v)
cp.NewRoot = mt.rootKey cp.NewRoot = mt.rootKey
_, siblings, err = mt.Get(k)
if err != nil {
return nil, err
}
cp.Siblings = siblings
return &cp, nil return &cp, nil
} }
@ -343,10 +342,10 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
} }
// Get returns the value of the leaf for the given key // Get returns the value of the leaf for the given key
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
// verfy that k is valid and fit inside the Finite Field. // verfy that k is valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) { if !cryptoUtils.CheckBigIntInField(k) {
return nil, nil, errors.New("Key not inside the Finite Field")
return nil, nil, nil, errors.New("Key not inside the Finite Field")
} }
kHash := NewHashFromBigInt(k) kHash := NewHashFromBigInt(k)
@ -357,16 +356,16 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
for i := 0; i < mt.maxLevels; i++ { for i := 0; i < mt.maxLevels; i++ {
n, err := mt.GetNode(nextKey) n, err := mt.GetNode(nextKey)
if err != nil { if err != nil {
return nil, nil, err
return nil, nil, nil, err
} }
switch n.Type { switch n.Type {
case NodeTypeEmpty: case NodeTypeEmpty:
return nil, nil, ErrKeyNotFound
return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
case NodeTypeLeaf: case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) { if bytes.Equal(kHash[:], n.Entry[0][:]) {
return n.Entry[1].BigInt(), siblings, nil
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
} else { } else {
return nil, nil, ErrKeyNotFound
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
} }
case NodeTypeMiddle: case NodeTypeMiddle:
if path[i] { if path[i] {
@ -377,11 +376,11 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
siblings = append(siblings, n.ChildR) siblings = append(siblings, n.ChildR)
} }
default: default:
return nil, nil, ErrInvalidNodeFound
return nil, nil, nil, ErrInvalidNodeFound
} }
} }
return nil, nil, ErrKeyNotFound
return nil, nil, nil, ErrKeyNotFound
} }
// Update updates the value of a specified key in the MerkleTree, and updates // Update updates the value of a specified key in the MerkleTree, and updates
@ -753,6 +752,19 @@ func (p *Proof) AllSiblings() []*Hash {
return SiblingsFromProof(p) return SiblingsFromProof(p)
} }
func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
// Add the rest of empty levels to the siblings
for i := len(siblings); i < levels; i++ {
siblings = append(siblings, &HashZero)
}
siblings = append(siblings, &HashZero) // add extra level for circom compatibility
return siblings
// siblingsBigInt := make([]*big.Int, len(siblings))
// for i, sibling := range siblings {
// siblingsBigInt[i] = sibling.BigInt()
// }
}
// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits. // AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int { func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
siblings := p.AllSiblings() siblings := p.AllSiblings()
@ -774,11 +786,31 @@ type CircomProcessorProof struct {
Siblings []*Hash Siblings []*Hash
OldKey *Hash OldKey *Hash
OldValue *Hash OldValue *Hash
IsOld0 bool
NewKey *Hash NewKey *Hash
NewValue *Hash NewValue *Hash
IsOld0 bool
// Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete // Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete
} }
func (p CircomProcessorProof) String() string {
buf := bytes.NewBufferString("{")
fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
fmt.Fprintf(buf, " Siblings: [\n ")
for _, s := range p.Siblings {
fmt.Fprintf(buf, "%v, ", s)
}
fmt.Fprintf(buf, "\n ],\n")
fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
fmt.Fprintf(buf, "}\n")
return buf.String()
}
type CircomVerifierProof struct { type CircomVerifierProof struct {
Root *Hash Root *Hash
Siblings []*big.Int Siblings []*big.Int

+ 42
- 24
merkletree_test.go

@ -123,18 +123,21 @@ func TestGet(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
v, _, err := mt.Get(big.NewInt(10))
k, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(10), k)
assert.Equal(t, big.NewInt(20), v) assert.Equal(t, big.NewInt(20), v)
v, _, err = mt.Get(big.NewInt(15))
k, v, _, err = mt.Get(big.NewInt(15))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(15), k)
assert.Equal(t, big.NewInt(30), v) assert.Equal(t, big.NewInt(30), v)
v, _, err = mt.Get(big.NewInt(16))
k, v, _, err = mt.Get(big.NewInt(16))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, ErrKeyNotFound, err)
assert.Nil(t, v)
assert.Equal(t, "0", k.String())
assert.Equal(t, "0", v.String())
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
@ -148,13 +151,13 @@ func TestUpdate(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
v, _, err := mt.Get(big.NewInt(10))
_, v, _, err := mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v) assert.Equal(t, big.NewInt(20), v)
_, err = mt.Update(big.NewInt(10), big.NewInt(1024)) _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
assert.Nil(t, err) assert.Nil(t, err)
v, _, err = mt.Get(big.NewInt(10))
_, v, _, err = mt.Get(big.NewInt(10))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(1024), v) assert.Equal(t, big.NewInt(1024), v)
@ -563,22 +566,37 @@ func TestAddAndGetCircomProof(t *testing.T) {
assert.Equal(t, "0", mt.Root().String()) assert.Equal(t, "0", mt.Root().String())
// test vectors generated using https://github.com/iden3/circomlib smt.js // test vectors generated using https://github.com/iden3/circomlib smt.js
_, err = mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
_, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
_, err = mt.AddAndGetCircomProof(big.NewInt(1234), big.NewInt(9876))
assert.Nil(t, err)
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
assert.Nil(t, err)
assert.Equal(t, "0", cpp.OldRoot.String())
assert.Equal(t, "49322979...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "1", cpp.NewKey.String())
assert.Equal(t, "2", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
assert.Nil(t, err)
assert.Equal(t, "49322979...", cpp.OldRoot.String())
assert.Equal(t, "13563340...", cpp.NewRoot.String())
assert.Equal(t, "1", cpp.OldKey.String())
assert.Equal(t, "2", cpp.OldValue.String())
assert.Equal(t, "33", cpp.NewKey.String())
assert.Equal(t, "44", cpp.NewValue.String())
assert.Equal(t, false, cpp.IsOld0)
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
assert.Nil(t, err)
assert.Equal(t, "13563340...", cpp.OldRoot.String())
assert.Equal(t, "21716426...", cpp.NewRoot.String())
assert.Equal(t, "0", cpp.OldKey.String())
assert.Equal(t, "0", cpp.OldValue.String())
assert.Equal(t, "55", cpp.NewKey.String())
assert.Equal(t, "66", cpp.NewValue.String())
assert.Equal(t, true, cpp.IsOld0)
assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
// fmt.Println(cpp)
} }

Loading…
Cancel
Save