diff --git a/addbatch_test.go b/addbatch_test.go index f5f55fb..1e89a57 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -75,11 +75,17 @@ func TestAddBatchTreeEmpty(t *testing.T) { defer tree.db.Close() //nolint:errcheck bLen := tree.HashFunction().Len() - start := time.Now() + var keys, values [][]byte for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) - if err := tree.Add(k, v); err != nil { + keys = append(keys, k) + values = append(values, v) + } + + start := time.Now() + for i := 0; i < nLeafs; i++ { + if err := tree.Add(keys[i], values[i]); err != nil { t.Fatal(err) } } @@ -92,13 +98,6 @@ func TestAddBatchTreeEmpty(t *testing.T) { defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() - var keys, values [][]byte - for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(bLen, big.NewInt(int64(i))) - v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) - keys = append(keys, k) - values = append(values, v) - } start = time.Now() indexes, err := tree2.AddBatch(keys, values) c.Assert(err, qt.IsNil) @@ -916,6 +915,77 @@ func TestLoadVT(t *testing.T) { c.Check(tree.Root(), qt.DeepEquals, vt.root.h) } +// TestAddKeysWithEmptyValues calls AddBatch giving an array of empty values +func TestAddKeysWithEmptyValues(t *testing.T) { + c := qt.New(t) + + nLeafs := 1024 + + database, err := db.NewBadgerDB(c.TempDir()) + c.Assert(err, qt.IsNil) + tree, err := NewTree(database, 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree.db.Close() //nolint:errcheck + + bLen := tree.HashFunction().Len() + var keys, values [][]byte + for i := 0; i < nLeafs; i++ { + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := []byte{} + keys = append(keys, k) + values = append(values, v) + } + + for i := 0; i < nLeafs; i++ { + if err := tree.Add(keys[i], values[i]); err != nil { + t.Fatal(err) + } + } + + database2, err := db.NewBadgerDB(c.TempDir()) + c.Assert(err, qt.IsNil) + tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + defer tree2.db.Close() //nolint:errcheck + tree2.dbgInit() + + indexes, err := tree2.AddBatch(keys, values) + c.Assert(err, qt.IsNil) + c.Check(len(indexes), qt.Equals, 0) + + // check that both trees roots are equal + checkRoots(c, tree, tree2) + + kAux, proofV, siblings, existence, err := tree2.GenProof(keys[9]) + c.Assert(err, qt.IsNil) + c.Assert(proofV, qt.DeepEquals, values[9]) + c.Assert(keys[9], qt.DeepEquals, kAux) + c.Assert(existence, qt.IsTrue) + + // check with empty array + verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, tree.Root(), siblings) + c.Assert(err, qt.IsNil) + c.Check(verif, qt.IsTrue) + + // check with array with only 1 zero + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, tree.Root(), siblings) + c.Assert(err, qt.IsNil) + c.Check(verif, qt.IsTrue) + + // check with array with 32 zeroes + e32 := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + c.Assert(len(e32), qt.Equals, 32) + verif, err = CheckProof(tree.hashFunction, keys[9], e32, tree.Root(), siblings) + c.Assert(err, qt.IsNil) + c.Check(verif, qt.IsTrue) + + // check with array with value!=0 returns false at verification + verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, tree.Root(), siblings) + c.Assert(err, qt.IsNil) + c.Check(verif, qt.IsFalse) +} + // TODO test adding batch with multiple invalid keys // TODO for tests of AddBatch, if the root does not match the Add root, bulk // all the leafs of both trees into a log file to later be able to debug and diff --git a/tree.go b/tree.go index 9ab2b05..81c0dd9 100644 --- a/tree.go +++ b/tree.go @@ -83,10 +83,11 @@ type Tree struct { dbg *dbgStats } -const bmSize = sha256.Size +// bmKeySize stands for batchMemoryKeySize +const bmKeySize = sha256.Size // TMP -type kvMap map[[bmSize]byte]kv +type kvMap map[[bmKeySize]byte]kv // Get retreives the value respective to a key from the KvMap func (m kvMap) Get(k []byte) ([]byte, bool) { @@ -109,7 +110,7 @@ func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, err if err == ErrKeyNotFound { // store new root 0 t.dbBatch = t.db.NewBatch() - t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP + t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP t.root = t.emptyHash if err = t.dbPut(dbKeyRoot, t.root); err != nil { return nil, err @@ -139,7 +140,8 @@ func (t *Tree) HashFunction() HashFunction { } // AddBatch adds a batch of key-values to the Tree. Returns an array containing -// the indexes of the keys failed to add. +// the indexes of the keys failed to add. Supports empty values as input +// parameters, which is equivalent to 0 valued byte array. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { t.Lock() defer t.Unlock() @@ -151,6 +153,17 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { // TODO check validity of keys & values for Tree.hashFunction + // equal the number of keys & values + if len(keys) > len(values) { + // add missing values + for i := len(values); i < len(keys); i++ { + values = append(values, emptyValue) + } + } else if len(keys) < len(values) { + // crop extra values + values = values[:len(keys)] + } + invalids, err := vt.addBatch(keys, values) if err != nil { return nil, err @@ -166,7 +179,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { // store pairs in db t.dbBatch = t.db.NewBatch() - t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP + t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP for i := 0; i < len(pairs); i++ { if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil { return nil, err @@ -217,7 +230,7 @@ func (t *Tree) Add(k, v []byte) error { var err error t.dbBatch = t.db.NewBatch() - t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP + t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP // TODO check validity of key & value for Tree.hashFunction @@ -491,7 +504,7 @@ func (t *Tree) Update(k, v []byte) error { var err error t.dbBatch = t.db.NewBatch() - t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP + t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) diff --git a/vt.go b/vt.go index 083ca71..a2152fd 100644 --- a/vt.go +++ b/vt.go @@ -319,7 +319,7 @@ func (t *vt) computeHashes() ([][2][]byte, error) { t.params.maxLevels, bucketVT.params, bucketPairs[cpu]) if err != nil { // TODO WIP - panic("TODO" + err.Error()) + panic("TODO: " + err.Error()) } subRoots[cpu] = bucketVT.root