mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-07 14:31:28 +01:00
Add proof generation
This commit is contained in:
98
tree.go
98
tree.go
@@ -14,6 +14,7 @@ package arbo
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -114,12 +115,12 @@ func (t *Tree) Add(k, v []byte) error {
|
||||
path := getPath(t.maxLevels, keyPath)
|
||||
// go down to the leaf
|
||||
var siblings [][]byte
|
||||
_, _, siblings, err := t.down(k, t.root, siblings, path, 0)
|
||||
_, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
leafKey, leafValue, err := t.newLeafValue(k, v)
|
||||
leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,7 +149,7 @@ func (t *Tree) Add(k, v []byte) error {
|
||||
}
|
||||
|
||||
// down goes down to the leaf recursively
|
||||
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) (
|
||||
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) (
|
||||
[]byte, []byte, [][]byte, error) {
|
||||
if l > t.maxLevels-1 {
|
||||
return nil, nil, nil, fmt.Errorf("max level")
|
||||
@@ -177,6 +178,9 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
|
||||
}
|
||||
|
||||
if !bytes.Equal(currValue, emptyValue) {
|
||||
if getLeaf {
|
||||
return currKey, currValue, siblings, nil
|
||||
}
|
||||
oldLeafKey, _ := readLeafValue(currValue)
|
||||
oldLeafKeyFull := make([]byte, t.hashFunction.Len())
|
||||
copy(oldLeafKeyFull[:], oldLeafKey)
|
||||
@@ -200,12 +204,12 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in
|
||||
// right
|
||||
lChild, rChild := readIntermediateChilds(currValue)
|
||||
siblings = append(siblings, lChild)
|
||||
return t.down(newKey, rChild, siblings, path, l+1)
|
||||
return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
|
||||
}
|
||||
// left
|
||||
lChild, rChild := readIntermediateChilds(currValue)
|
||||
siblings = append(siblings, rChild)
|
||||
return t.down(newKey, lChild, siblings, path, l+1)
|
||||
return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
|
||||
default:
|
||||
return nil, nil, nil, fmt.Errorf("invalid value")
|
||||
}
|
||||
@@ -241,12 +245,12 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
|
||||
var k, v []byte
|
||||
var err error
|
||||
if path[l] {
|
||||
k, v, err = t.newIntermediate(siblings[l], key)
|
||||
k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
k, v, err = t.newIntermediate(key, siblings[l])
|
||||
k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -265,8 +269,8 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) (
|
||||
return t.up(tx, k, siblings, path, l-1)
|
||||
}
|
||||
|
||||
func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
|
||||
leafKey, err := t.hashFunction.Hash(k, v, []byte{1})
|
||||
func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
|
||||
leafKey, err := hashFunc.Hash(k, v, []byte{1})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -292,14 +296,14 @@ func readLeafValue(b []byte) ([]byte, []byte) {
|
||||
return k, v
|
||||
}
|
||||
|
||||
func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
|
||||
b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2)
|
||||
func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
|
||||
b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
|
||||
b[0] = 2
|
||||
b[1] = byte(len(l))
|
||||
copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l)
|
||||
copy(b[PrefixValueLen+t.hashFunction.Len():], r)
|
||||
copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
|
||||
copy(b[PrefixValueLen+hashFunc.Len():], r)
|
||||
|
||||
key, err := t.hashFunction.Hash(l, r)
|
||||
key, err := hashFunc.Hash(l, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -332,9 +336,67 @@ func getPath(numLevels int, k []byte) []bool {
|
||||
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
||||
// the Tree, the proof will be of existence, if the key does not exist in the
|
||||
// tree, the proof will be of non-existence.
|
||||
func (t *Tree) GenProof(k, v []byte) ([]byte, error) {
|
||||
// unimplemented
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
func (t *Tree) GenProof(k []byte) ([]byte, error) {
|
||||
keyPath := make([]byte, t.hashFunction.Len())
|
||||
copy(keyPath[:], k)
|
||||
|
||||
path := getPath(t.maxLevels, keyPath)
|
||||
// go down to the leaf
|
||||
var siblings [][]byte
|
||||
_, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leafK, leafV := readLeafValue(value)
|
||||
if !bytes.Equal(k, leafK) {
|
||||
fmt.Println("key not in Tree")
|
||||
fmt.Println(leafK)
|
||||
fmt.Println(leafV)
|
||||
// TODO proof of non-existence
|
||||
panic(fmt.Errorf("unimplemented"))
|
||||
}
|
||||
|
||||
s := PackSiblings(t.hashFunction, siblings)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// PackSiblings packs the siblings into a byte array.
|
||||
// [ 1 byte | L bytes | 32 * N bytes ]
|
||||
// [ bitmap length (L) | bitmap | N non-zero siblings ]
|
||||
// Where the bitmap indicates if the sibling is 0 or a value from the siblings array.
|
||||
func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
|
||||
var b []byte
|
||||
var bitmap []bool
|
||||
emptySibling := make([]byte, hashFunc.Len())
|
||||
for i := 0; i < len(siblings); i++ {
|
||||
if bytes.Equal(siblings[i], emptySibling) {
|
||||
bitmap = append(bitmap, false)
|
||||
} else {
|
||||
bitmap = append(bitmap, true)
|
||||
b = append(b, siblings[i]...)
|
||||
}
|
||||
}
|
||||
|
||||
bitmapBytes := bitmapToBytes(bitmap)
|
||||
l := len(bitmapBytes)
|
||||
|
||||
res := make([]byte, l+1+len(b))
|
||||
res[0] = byte(l) // set the bitmapBytes length
|
||||
copy(res[1:1+l], bitmapBytes)
|
||||
copy(res[1+l:], b)
|
||||
return res
|
||||
}
|
||||
|
||||
func bitmapToBytes(bitmap []bool) []byte {
|
||||
bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
|
||||
b := make([]byte, bitmapBytesLen)
|
||||
for i := 0; i < len(bitmap); i++ {
|
||||
if bitmap[i] {
|
||||
b[i/8] |= 1 << (i % 8)
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Get returns the value for a given key
|
||||
@@ -348,5 +410,3 @@ func CheckProof(k, v, root, mproof []byte) (bool, error) {
|
||||
// unimplemented
|
||||
return false, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// TODO method to export & import the full Tree without values
|
||||
|
||||
18
tree_test.go
18
tree_test.go
@@ -145,6 +145,24 @@ func TestAux(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestGenProof(t *testing.T) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
|
||||
defer tree.db.Close()
|
||||
for i := 0; i < 10; i++ {
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(int64(i * 2)))
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
k := BigIntToBytes(big.NewInt(int64(7)))
|
||||
_, err = tree.GenProof(k)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func BenchmarkAdd(b *testing.B) {
|
||||
// prepare inputs
|
||||
var ks, vs [][]byte
|
||||
|
||||
Reference in New Issue
Block a user