Browse Source

Add proof generation

master
arnaucube 3 years ago
parent
commit
8c63b5d192
2 changed files with 97 additions and 19 deletions
  1. +79
    -19
      tree.go
  2. +18
    -0
      tree_test.go

+ 79
- 19
tree.go

@ -14,6 +14,7 @@ package arbo
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math"
"sync/atomic" "sync/atomic"
"time" "time"
@ -114,12 +115,12 @@ func (t *Tree) Add(k, v []byte) error {
path := getPath(t.maxLevels, keyPath) path := getPath(t.maxLevels, keyPath)
// go down to the leaf // go down to the leaf
var siblings [][]byte var siblings [][]byte
_, _, siblings, err := t.down(k, t.root, siblings, path, 0)
_, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
if err != nil { if err != nil {
return err return err
} }
leafKey, leafValue, err := t.newLeafValue(k, v)
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
if err != nil { if err != nil {
return err return err
} }
@ -148,7 +149,7 @@ func (t *Tree) Add(k, v []byte) error {
} }
// down goes down to the leaf recursively // down goes down to the leaf recursively
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) (
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) (
[]byte, []byte, [][]byte, error) { []byte, []byte, [][]byte, error) {
if l > t.maxLevels-1 { if l > t.maxLevels-1 {
return nil, nil, nil, fmt.Errorf("max level") return nil, nil, nil, fmt.Errorf("max level")
@ -177,6 +178,9 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
} }
if !bytes.Equal(currValue, emptyValue) { if !bytes.Equal(currValue, emptyValue) {
if getLeaf {
return currKey, currValue, siblings, nil
}
oldLeafKey, _ := readLeafValue(currValue) oldLeafKey, _ := readLeafValue(currValue)
oldLeafKeyFull := make([]byte, t.hashFunction.Len()) oldLeafKeyFull := make([]byte, t.hashFunction.Len())
copy(oldLeafKeyFull[:], oldLeafKey) copy(oldLeafKeyFull[:], oldLeafKey)
@ -200,12 +204,12 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
// right // right
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, lChild) siblings = append(siblings, lChild)
return t.down(newKey, rChild, siblings, path, l+1)
return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
} }
// left // left
lChild, rChild := readIntermediateChilds(currValue) lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, rChild) siblings = append(siblings, rChild)
return t.down(newKey, lChild, siblings, path, l+1)
return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
default: default:
return nil, nil, nil, fmt.Errorf("invalid value") return nil, nil, nil, fmt.Errorf("invalid value")
} }
@ -241,12 +245,12 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
var k, v []byte var k, v []byte
var err error var err error
if path[l] { if path[l] {
k, v, err = t.newIntermediate(siblings[l], key)
k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
k, v, err = t.newIntermediate(key, siblings[l])
k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,8 +269,8 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
return t.up(tx, k, siblings, path, l-1) return t.up(tx, k, siblings, path, l-1)
} }
func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
leafKey, err := t.hashFunction.Hash(k, v, []byte{1})
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
leafKey, err := hashFunc.Hash(k, v, []byte{1})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -292,14 +296,14 @@ func readLeafValue(b []byte) ([]byte, []byte) {
return k, v return k, v
} }
func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2)
func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
b[0] = 2 b[0] = 2
b[1] = byte(len(l)) b[1] = byte(len(l))
copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l)
copy(b[PrefixValueLen+t.hashFunction.Len():], r)
copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
copy(b[PrefixValueLen+hashFunc.Len():], r)
key, err := t.hashFunction.Hash(l, r)
key, err := hashFunc.Hash(l, r)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -332,9 +336,67 @@ func getPath(numLevels int, k []byte) []bool {
// GenProof generates a MerkleTree proof for the given key. If the key exists in // GenProof generates a MerkleTree proof for the given key. If the key exists in
// the Tree, the proof will be of existence, if the key does not exist in the // the Tree, the proof will be of existence, if the key does not exist in the
// tree, the proof will be of non-existence. // tree, the proof will be of non-existence.
func (t *Tree) GenProof(k, v []byte) ([]byte, error) {
// unimplemented
return nil, fmt.Errorf("unimplemented")
func (t *Tree) GenProof(k []byte) ([]byte, error) {
keyPath := make([]byte, t.hashFunction.Len())
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
if err != nil {
return nil, err
}
leafK, leafV := readLeafValue(value)
if !bytes.Equal(k, leafK) {
fmt.Println("key not in Tree")
fmt.Println(leafK)
fmt.Println(leafV)
// TODO proof of non-existence
panic(fmt.Errorf("unimplemented"))
}
s := PackSiblings(t.hashFunction, siblings)
return s, nil
}
// PackSiblings packs the siblings into a byte array.
// [ 1 byte | L bytes | 32 * N bytes ]
// [ bitmap length (L) | bitmap | N non-zero siblings ]
// Where the bitmap indicates if the sibling is 0 or a value from the siblings array.
func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
var b []byte
var bitmap []bool
emptySibling := make([]byte, hashFunc.Len())
for i := 0; i < len(siblings); i++ {
if bytes.Equal(siblings[i], emptySibling) {
bitmap = append(bitmap, false)
} else {
bitmap = append(bitmap, true)
b = append(b, siblings[i]...)
}
}
bitmapBytes := bitmapToBytes(bitmap)
l := len(bitmapBytes)
res := make([]byte, l+1+len(b))
res[0] = byte(l) // set the bitmapBytes length
copy(res[1:1+l], bitmapBytes)
copy(res[1+l:], b)
return res
}
func bitmapToBytes(bitmap []bool) []byte {
bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
b := make([]byte, bitmapBytesLen)
for i := 0; i < len(bitmap); i++ {
if bitmap[i] {
b[i/8] |= 1 << (i % 8)
}
}
return b
} }
// Get returns the value for a given key // Get returns the value for a given key
@ -348,5 +410,3 @@ func CheckProof(k, v, root, mproof []byte) (bool, error) {
// unimplemented // unimplemented
return false, fmt.Errorf("unimplemented") return false, fmt.Errorf("unimplemented")
} }
// TODO method to export & import the full Tree without values

+ 18
- 0
tree_test.go

@ -145,6 +145,24 @@ func TestAux(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestGenProof(t *testing.T) {
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree.db.Close()
for i := 0; i < 10; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i * 2)))
if err := tree.Add(k, v); err != nil {
t.Fatal(err)
}
}
k := BigIntToBytes(big.NewInt(int64(7)))
_, err = tree.GenProof(k)
assert.Nil(t, err)
}
func BenchmarkAdd(b *testing.B) { func BenchmarkAdd(b *testing.B) {
// prepare inputs // prepare inputs
var ks, vs [][]byte var ks, vs [][]byte

Loading…
Cancel
Save