mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 03:26:46 +01:00
AddAndGetCircomProof generating CircomProcessorProof
This commit is contained in:
@@ -71,6 +71,9 @@ func (h Hash) Hex() string {
|
|||||||
|
|
||||||
// BigInt returns the *big.Int representation of the *Hash
|
// BigInt returns the *big.Int representation of the *Hash
|
||||||
func (h *Hash) BigInt() *big.Int {
|
func (h *Hash) BigInt() *big.Int {
|
||||||
|
if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
|
||||||
|
return big.NewInt(0)
|
||||||
|
}
|
||||||
return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
|
return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,20 +199,22 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
|
func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
|
||||||
|
|
||||||
var cp CircomProcessorProof
|
var cp CircomProcessorProof
|
||||||
cp.OldRoot = mt.rootKey
|
cp.OldRoot = mt.rootKey
|
||||||
gettedV, siblings, err := mt.Get(k)
|
gettedK, gettedV, siblings, err := mt.Get(k)
|
||||||
if err != nil && err != ErrKeyNotFound {
|
if err != nil && err != ErrKeyNotFound {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err == ErrKeyNotFound {
|
cp.OldKey = NewHashFromBigInt(gettedK)
|
||||||
cp.OldKey = &HashZero
|
|
||||||
cp.OldValue = &HashZero
|
|
||||||
} else {
|
|
||||||
cp.OldKey = NewHashFromBigInt(k)
|
|
||||||
cp.OldValue = NewHashFromBigInt(gettedV)
|
cp.OldValue = NewHashFromBigInt(gettedV)
|
||||||
|
if bytes.Equal(cp.OldKey[:], HashZero[:]) {
|
||||||
|
cp.IsOld0 = true
|
||||||
}
|
}
|
||||||
|
_, _, siblings, err = mt.Get(k)
|
||||||
|
if err != nil && err != ErrKeyNotFound {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
|
||||||
|
|
||||||
err = mt.Add(k, v)
|
err = mt.Add(k, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -220,12 +225,6 @@ func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof
|
|||||||
cp.NewValue = NewHashFromBigInt(v)
|
cp.NewValue = NewHashFromBigInt(v)
|
||||||
cp.NewRoot = mt.rootKey
|
cp.NewRoot = mt.rootKey
|
||||||
|
|
||||||
_, siblings, err = mt.Get(k)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cp.Siblings = siblings
|
|
||||||
|
|
||||||
return &cp, nil
|
return &cp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,10 +342,10 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the value of the leaf for the given key
|
// Get returns the value of the leaf for the given key
|
||||||
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
|
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
|
||||||
// verfy that k is valid and fit inside the Finite Field.
|
// verfy that k is valid and fit inside the Finite Field.
|
||||||
if !cryptoUtils.CheckBigIntInField(k) {
|
if !cryptoUtils.CheckBigIntInField(k) {
|
||||||
return nil, nil, errors.New("Key not inside the Finite Field")
|
return nil, nil, nil, errors.New("Key not inside the Finite Field")
|
||||||
}
|
}
|
||||||
|
|
||||||
kHash := NewHashFromBigInt(k)
|
kHash := NewHashFromBigInt(k)
|
||||||
@@ -357,16 +356,16 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
|
|||||||
for i := 0; i < mt.maxLevels; i++ {
|
for i := 0; i < mt.maxLevels; i++ {
|
||||||
n, err := mt.GetNode(nextKey)
|
n, err := mt.GetNode(nextKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case NodeTypeEmpty:
|
case NodeTypeEmpty:
|
||||||
return nil, nil, ErrKeyNotFound
|
return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
|
||||||
case NodeTypeLeaf:
|
case NodeTypeLeaf:
|
||||||
if bytes.Equal(kHash[:], n.Entry[0][:]) {
|
if bytes.Equal(kHash[:], n.Entry[0][:]) {
|
||||||
return n.Entry[1].BigInt(), siblings, nil
|
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, ErrKeyNotFound
|
return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
|
||||||
}
|
}
|
||||||
case NodeTypeMiddle:
|
case NodeTypeMiddle:
|
||||||
if path[i] {
|
if path[i] {
|
||||||
@@ -377,11 +376,11 @@ func (mt *MerkleTree) Get(k *big.Int) (*big.Int, []*Hash, error) {
|
|||||||
siblings = append(siblings, n.ChildR)
|
siblings = append(siblings, n.ChildR)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, nil, ErrInvalidNodeFound
|
return nil, nil, nil, ErrInvalidNodeFound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, ErrKeyNotFound
|
return nil, nil, nil, ErrKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates the value of a specified key in the MerkleTree, and updates
|
// Update updates the value of a specified key in the MerkleTree, and updates
|
||||||
@@ -753,6 +752,19 @@ func (p *Proof) AllSiblings() []*Hash {
|
|||||||
return SiblingsFromProof(p)
|
return SiblingsFromProof(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
|
||||||
|
// Add the rest of empty levels to the siblings
|
||||||
|
for i := len(siblings); i < levels; i++ {
|
||||||
|
siblings = append(siblings, &HashZero)
|
||||||
|
}
|
||||||
|
siblings = append(siblings, &HashZero) // add extra level for circom compatibility
|
||||||
|
return siblings
|
||||||
|
// siblingsBigInt := make([]*big.Int, len(siblings))
|
||||||
|
// for i, sibling := range siblings {
|
||||||
|
// siblingsBigInt[i] = sibling.BigInt()
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
|
// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
|
||||||
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
|
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
|
||||||
siblings := p.AllSiblings()
|
siblings := p.AllSiblings()
|
||||||
@@ -774,11 +786,31 @@ type CircomProcessorProof struct {
|
|||||||
Siblings []*Hash
|
Siblings []*Hash
|
||||||
OldKey *Hash
|
OldKey *Hash
|
||||||
OldValue *Hash
|
OldValue *Hash
|
||||||
IsOld0 bool
|
|
||||||
NewKey *Hash
|
NewKey *Hash
|
||||||
NewValue *Hash
|
NewValue *Hash
|
||||||
|
IsOld0 bool
|
||||||
// Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete
|
// Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p CircomProcessorProof) String() string {
|
||||||
|
buf := bytes.NewBufferString("{")
|
||||||
|
fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
|
||||||
|
fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
|
||||||
|
fmt.Fprintf(buf, " Siblings: [\n ")
|
||||||
|
for _, s := range p.Siblings {
|
||||||
|
fmt.Fprintf(buf, "%v, ", s)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(buf, "\n ],\n")
|
||||||
|
fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
|
||||||
|
fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
|
||||||
|
fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
|
||||||
|
fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
|
||||||
|
fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
|
||||||
|
fmt.Fprintf(buf, "}\n")
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
type CircomVerifierProof struct {
|
type CircomVerifierProof struct {
|
||||||
Root *Hash
|
Root *Hash
|
||||||
Siblings []*big.Int
|
Siblings []*big.Int
|
||||||
|
|||||||
@@ -123,18 +123,21 @@ func TestGet(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v, _, err := mt.Get(big.NewInt(10))
|
k, v, _, err := mt.Get(big.NewInt(10))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, big.NewInt(10), k)
|
||||||
assert.Equal(t, big.NewInt(20), v)
|
assert.Equal(t, big.NewInt(20), v)
|
||||||
|
|
||||||
v, _, err = mt.Get(big.NewInt(15))
|
k, v, _, err = mt.Get(big.NewInt(15))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, big.NewInt(15), k)
|
||||||
assert.Equal(t, big.NewInt(30), v)
|
assert.Equal(t, big.NewInt(30), v)
|
||||||
|
|
||||||
v, _, err = mt.Get(big.NewInt(16))
|
k, v, _, err = mt.Get(big.NewInt(16))
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
assert.Equal(t, ErrKeyNotFound, err)
|
assert.Equal(t, ErrKeyNotFound, err)
|
||||||
assert.Nil(t, v)
|
assert.Equal(t, "0", k.String())
|
||||||
|
assert.Equal(t, "0", v.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
@@ -148,13 +151,13 @@ func TestUpdate(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v, _, err := mt.Get(big.NewInt(10))
|
_, v, _, err := mt.Get(big.NewInt(10))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, big.NewInt(20), v)
|
assert.Equal(t, big.NewInt(20), v)
|
||||||
|
|
||||||
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
|
_, err = mt.Update(big.NewInt(10), big.NewInt(1024))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
v, _, err = mt.Get(big.NewInt(10))
|
_, v, _, err = mt.Get(big.NewInt(10))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, big.NewInt(1024), v)
|
assert.Equal(t, big.NewInt(1024), v)
|
||||||
|
|
||||||
@@ -563,22 +566,37 @@ func TestAddAndGetCircomProof(t *testing.T) {
|
|||||||
assert.Equal(t, "0", mt.Root().String())
|
assert.Equal(t, "0", mt.Root().String())
|
||||||
|
|
||||||
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
// test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||||
_, err = mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
|
cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", mt.Root().BigInt().String())
|
assert.Equal(t, "0", cpp.OldRoot.String())
|
||||||
|
assert.Equal(t, "49322979...", cpp.NewRoot.String())
|
||||||
|
assert.Equal(t, "0", cpp.OldKey.String())
|
||||||
|
assert.Equal(t, "0", cpp.OldValue.String())
|
||||||
|
assert.Equal(t, "1", cpp.NewKey.String())
|
||||||
|
assert.Equal(t, "2", cpp.NewValue.String())
|
||||||
|
assert.Equal(t, true, cpp.IsOld0)
|
||||||
|
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||||
|
|
||||||
_, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
|
cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "13563340744765267202993741297198970774200042973817962221376874695587906013050", mt.Root().BigInt().String())
|
assert.Equal(t, "49322979...", cpp.OldRoot.String())
|
||||||
|
assert.Equal(t, "13563340...", cpp.NewRoot.String())
|
||||||
|
assert.Equal(t, "1", cpp.OldKey.String())
|
||||||
|
assert.Equal(t, "2", cpp.OldValue.String())
|
||||||
|
assert.Equal(t, "33", cpp.NewKey.String())
|
||||||
|
assert.Equal(t, "44", cpp.NewValue.String())
|
||||||
|
assert.Equal(t, false, cpp.IsOld0)
|
||||||
|
assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||||
|
|
||||||
_, err = mt.AddAndGetCircomProof(big.NewInt(1234), big.NewInt(9876))
|
cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "16970503620176669663662021947486532860010370357132361783766545149750777353066", mt.Root().BigInt().String())
|
assert.Equal(t, "13563340...", cpp.OldRoot.String())
|
||||||
|
assert.Equal(t, "21716426...", cpp.NewRoot.String())
|
||||||
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
|
assert.Equal(t, "0", cpp.OldKey.String())
|
||||||
assert.Nil(t, err)
|
assert.Equal(t, "0", cpp.OldValue.String())
|
||||||
assert.Equal(t, big.NewInt(44), v)
|
assert.Equal(t, "55", cpp.NewKey.String())
|
||||||
|
assert.Equal(t, "66", cpp.NewValue.String())
|
||||||
assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
|
assert.Equal(t, true, cpp.IsOld0)
|
||||||
assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
|
assert.Equal(t, "[0 34319575... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
|
||||||
|
// fmt.Println(cpp)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user