From f1665b1a15ac1df9a5824a44c766eec69af12279 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 2 Apr 2021 15:31:38 +0200 Subject: [PATCH] Add Iterate, Dump, ImportDump methods to Tree --- tree.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++-- tree_test.go | 39 +++++++++++++++++++--- 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/tree.go b/tree.go index 65b511a..c2ad2e1 100644 --- a/tree.go +++ b/tree.go @@ -14,6 +14,7 @@ package arbo import ( "bytes" "fmt" + "io" "math" "sync/atomic" "time" @@ -397,9 +398,11 @@ func (t *Tree) GenProof(k []byte) ([]byte, error) { } // PackSiblings packs the siblings into a byte array. -// [ 1 byte | L bytes | 32 * N bytes ] +// [ 1 byte | L bytes | S * N bytes ] // [ bitmap length (L) | bitmap | N non-zero siblings ] -// Where the bitmap indicates if the sibling is 0 or a value from the siblings array. +// Where the bitmap indicates if the sibling is 0 or a value from the siblings +// array. And S is the size of the output of the hash function used for the +// Tree. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { var b []byte var bitmap []bool @@ -533,3 +536,89 @@ func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) { } return nil, db.ErrNotFound } + +// Iterate iterates through the full Tree, executing the given function on each +// node of the Tree. +func (t *Tree) Iterate(f func([]byte, []byte)) error { + return t.iter(t.root, f) +} + +func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { + v, err := t.dbGet(nil, k) + if err != nil { + return err + } + + switch v[0] { + case PrefixValueEmpty: + f(k, v) + case PrefixValueLeaf: + f(k, v) + case PrefixValueIntermediate: + f(k, v) + l, r := readIntermediateChilds(v) + if err = t.iter(l, f); err != nil { + return err + } + if err = t.iter(r, f); err != nil { + return err + } + default: + return fmt.Errorf("invalid value") + } + return nil +} + +// Dump exports all the Tree leafs in a byte array of length: +// [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v: +// [ 1 byte | 1 byte | S bytes | len(v) bytes ] +// [ len(k) | len(v) | key | value ] +// Where S is the size of the output of the hash function used for the Tree. +func (t *Tree) Dump() ([]byte, error) { + // WARNING current encoding only supports key & values of 255 bytes each + // (due using only 1 byte for the length headers). + var b []byte + err := t.Iterate(func(k, v []byte) { + if v[0] != PrefixValueLeaf { + return + } + leafK, leafV := readLeafValue(v) + kv := make([]byte, 2+len(leafK)+len(leafV)) + kv[0] = byte(len(leafK)) + kv[1] = byte(len(leafV)) + copy(kv[2:2+len(leafK)], leafK) + copy(kv[2+len(leafK):], leafV) + b = append(b, kv...) + }) + return b, err +} + +// ImportDump imports the leafs (that have been exported with the ExportLeafs +// method) in the Tree. +func (t *Tree) ImportDump(b []byte) error { + r := bytes.NewReader(b) + for { + l := make([]byte, 2) + _, err := io.ReadFull(r, l) + if err == io.EOF { + break + } else if err != nil { + return err + } + k := make([]byte, l[0]) + _, err = io.ReadFull(r, k) + if err != nil { + return err + } + v := make([]byte, l[1]) + _, err = io.ReadFull(r, v) + if err != nil { + return err + } + err = t.Add(k, v) + if err != nil { + return err + } + } + return nil +} diff --git a/tree_test.go b/tree_test.go index 869ffc3..8ba7b69 100644 --- a/tree_test.go +++ b/tree_test.go @@ -34,6 +34,7 @@ func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) { tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc) assert.Nil(t, err) defer tree.db.Close() + assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root())) err = tree.Add( @@ -100,8 +101,8 @@ func TestAddBatch(t *testing.T) { func TestAddDifferentOrder(t *testing.T) { tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) - defer tree1.db.Close() + for i := 0; i < 16; i++ { k := SwapEndianness(big.NewInt(int64(i)).Bytes()) v := SwapEndianness(big.NewInt(0).Bytes()) @@ -113,6 +114,7 @@ func TestAddDifferentOrder(t *testing.T) { tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) defer tree2.db.Close() + for i := 16 - 1; i >= 0; i-- { k := big.NewInt(int64(i)).Bytes() v := big.NewInt(0).Bytes() @@ -131,6 +133,7 @@ func TestAddRepeatedIndex(t *testing.T) { tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) defer tree.db.Close() + k := big.NewInt(int64(3)).Bytes() v := big.NewInt(int64(12)).Bytes() if err := tree.Add(k, v); err != nil { @@ -145,6 +148,7 @@ func TestAux(t *testing.T) { tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) defer tree.db.Close() + k := BigIntToBytes(big.NewInt(int64(1))) v := BigIntToBytes(big.NewInt(int64(0))) err = tree.Add(k, v) @@ -168,8 +172,8 @@ func TestAux(t *testing.T) { func TestGet(t *testing.T) { tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) - defer tree.db.Close() + for i := 0; i < 10; i++ { k := BigIntToBytes(big.NewInt(int64(i))) v := BigIntToBytes(big.NewInt(int64(i * 2))) @@ -188,8 +192,8 @@ func TestGet(t *testing.T) { func TestGenProofAndVerify(t *testing.T) { tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) - defer tree.db.Close() + for i := 0; i < 10; i++ { k := BigIntToBytes(big.NewInt(int64(i))) v := BigIntToBytes(big.NewInt(int64(i * 2))) @@ -209,6 +213,33 @@ func TestGenProofAndVerify(t *testing.T) { assert.True(t, verif) } +func TestDumpAndImportDump(t *testing.T) { + tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + require.Nil(t, err) + defer tree1.db.Close() + + for i := 0; i < 16; i++ { + k := BigIntToBytes(big.NewInt(int64(i))) + v := BigIntToBytes(big.NewInt(int64(i * 2))) + if err := tree1.Add(k, v); err != nil { + t.Fatal(err) + } + } + + e, err := tree1.Dump() + require.Nil(t, err) + + tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) + require.Nil(t, err) + defer tree2.db.Close() + err = tree2.ImportDump(e) + require.Nil(t, err) + assert.Equal(t, tree1.Root(), tree2.Root()) + assert.Equal(t, + "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08", + hex.EncodeToString(tree2.Root())) +} + func BenchmarkAdd(b *testing.B) { // prepare inputs var ks, vs [][]byte @@ -230,8 +261,8 @@ func BenchmarkAdd(b *testing.B) { func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) { tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc) require.Nil(b, err) - defer tree.db.Close() + for i := 0; i < len(ks); i++ { if err := tree.Add(ks[i], vs[i]); err != nil { b.Fatal(err)