diff --git a/tree.go b/tree.go index 7df1f91..bf59e49 100644 --- a/tree.go +++ b/tree.go @@ -14,6 +14,7 @@ package arbo import ( "bytes" "fmt" + "math" "sync/atomic" "time" @@ -114,12 +115,12 @@ func (t *Tree) Add(k, v []byte) error { path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(k, t.root, siblings, path, 0) + _, _, siblings, err := t.down(k, t.root, siblings, path, 0, false) if err != nil { return err } - leafKey, leafValue, err := t.newLeafValue(k, v) + leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v) if err != nil { return err } @@ -148,7 +149,7 @@ func (t *Tree) Add(k, v []byte) error { } // down goes down to the leaf recursively -func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) ( +func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) ( []byte, []byte, [][]byte, error) { if l > t.maxLevels-1 { return nil, nil, nil, fmt.Errorf("max level") @@ -177,6 +178,9 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in } if !bytes.Equal(currValue, emptyValue) { + if getLeaf { + return currKey, currValue, siblings, nil + } oldLeafKey, _ := readLeafValue(currValue) oldLeafKeyFull := make([]byte, t.hashFunction.Len()) copy(oldLeafKeyFull[:], oldLeafKey) @@ -200,12 +204,12 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l in // right lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, lChild) - return t.down(newKey, rChild, siblings, path, l+1) + return t.down(newKey, rChild, siblings, path, l+1, getLeaf) } // left lChild, rChild := readIntermediateChilds(currValue) siblings = append(siblings, rChild) - return t.down(newKey, lChild, siblings, path, l+1) + return t.down(newKey, lChild, siblings, path, l+1, getLeaf) default: return nil, nil, nil, fmt.Errorf("invalid value") } @@ -241,12 +245,12 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ( var k, v []byte var err error if path[l] { - k, v, err = t.newIntermediate(siblings[l], key) + k, v, err = newIntermediate(t.hashFunction, siblings[l], key) if err != nil { return nil, err } } else { - k, v, err = t.newIntermediate(key, siblings[l]) + k, v, err = newIntermediate(t.hashFunction, key, siblings[l]) if err != nil { return nil, err } @@ -265,8 +269,8 @@ func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ( return t.up(tx, k, siblings, path, l-1) } -func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) { - leafKey, err := t.hashFunction.Hash(k, v, []byte{1}) +func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { + leafKey, err := hashFunc.Hash(k, v, []byte{1}) if err != nil { return nil, nil, err } @@ -292,14 +296,14 @@ func readLeafValue(b []byte) ([]byte, []byte) { return k, v } -func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) { - b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2) +func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) { + b := make([]byte, PrefixValueLen+hashFunc.Len()*2) b[0] = 2 b[1] = byte(len(l)) - copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l) - copy(b[PrefixValueLen+t.hashFunction.Len():], r) + copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l) + copy(b[PrefixValueLen+hashFunc.Len():], r) - key, err := t.hashFunction.Hash(l, r) + key, err := hashFunc.Hash(l, r) if err != nil { return nil, nil, err } @@ -332,9 +336,67 @@ func getPath(numLevels int, k []byte) []bool { // GenProof generates a MerkleTree proof for the given key. If the key exists in // the Tree, the proof will be of existence, if the key does not exist in the // tree, the proof will be of non-existence. -func (t *Tree) GenProof(k, v []byte) ([]byte, error) { - // unimplemented - return nil, fmt.Errorf("unimplemented") +func (t *Tree) GenProof(k []byte) ([]byte, error) { + keyPath := make([]byte, t.hashFunction.Len()) + copy(keyPath[:], k) + + path := getPath(t.maxLevels, keyPath) + // go down to the leaf + var siblings [][]byte + _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true) + if err != nil { + return nil, err + } + + leafK, leafV := readLeafValue(value) + if !bytes.Equal(k, leafK) { + fmt.Println("key not in Tree") + fmt.Println(leafK) + fmt.Println(leafV) + // TODO proof of non-existence + panic(fmt.Errorf("unimplemented")) + } + + s := PackSiblings(t.hashFunction, siblings) + return s, nil +} + +// PackSiblings packs the siblings into a byte array. +// [ 1 byte | L bytes | 32 * N bytes ] +// [ bitmap length (L) | bitmap | N non-zero siblings ] +// Where the bitmap indicates if the sibling is 0 or a value from the siblings array. +func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { + var b []byte + var bitmap []bool + emptySibling := make([]byte, hashFunc.Len()) + for i := 0; i < len(siblings); i++ { + if bytes.Equal(siblings[i], emptySibling) { + bitmap = append(bitmap, false) + } else { + bitmap = append(bitmap, true) + b = append(b, siblings[i]...) + } + } + + bitmapBytes := bitmapToBytes(bitmap) + l := len(bitmapBytes) + + res := make([]byte, l+1+len(b)) + res[0] = byte(l) // set the bitmapBytes length + copy(res[1:1+l], bitmapBytes) + copy(res[1+l:], b) + return res +} + +func bitmapToBytes(bitmap []bool) []byte { + bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd + b := make([]byte, bitmapBytesLen) + for i := 0; i < len(bitmap); i++ { + if bitmap[i] { + b[i/8] |= 1 << (i % 8) + } + } + return b } // Get returns the value for a given key @@ -348,5 +410,3 @@ func CheckProof(k, v, root, mproof []byte) (bool, error) { // unimplemented return false, fmt.Errorf("unimplemented") } - -// TODO method to export & import the full Tree without values diff --git a/tree_test.go b/tree_test.go index 672a674..db0ac56 100644 --- a/tree_test.go +++ b/tree_test.go @@ -145,6 +145,24 @@ func TestAux(t *testing.T) { assert.Nil(t, err) } +func TestGenProof(t *testing.T) { + tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + require.Nil(t, err) + + defer tree.db.Close() + for i := 0; i < 10; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree.Add(k, v); err != nil { + t.Fatal(err) + } + } + + k := BigIntToBytes(big.NewInt(int64(7))) + _, err = tree.GenProof(k) + assert.Nil(t, err) +} + func BenchmarkAdd(b *testing.B) { // prepare inputs var ks, vs [][]byte