diff --git a/tree.go b/tree.go index a5f3f57..6f856d0 100644 --- a/tree.go +++ b/tree.go @@ -311,13 +311,10 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { } func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { - keyPath := make([]byte, t.hashFunction.Len()) - // if len(k) > t.hashFunction.Len() { // WIP - // return nil, fmt.Errorf("len(k) > hashFunction.Len()") - // } + keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd copy(keyPath[:], k) - path := getPath(t.maxLevels, keyPath) + // go down to the leaf var siblings [][]byte _, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false) @@ -593,12 +590,7 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { return ErrSnapshotNotEditable } - var err error - - keyPath := make([]byte, t.hashFunction.Len()) - // if len(k) > t.hashFunction.Len() { // WIP - // return fmt.Errorf("len(k) > hashFunction.Len()") - // } + keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) @@ -655,18 +647,15 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) { // GenProofWithTx does the same than the GenProof method, but allowing to pass // the db.ReadTx that is used. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) { - keyPath := make([]byte, t.hashFunction.Len()) - // if len(k) > t.hashFunction.Len() { // WIP - // return nil, nil, nil, false, fmt.Errorf("len(k) > hashFunction.Len()") - // } + keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd copy(keyPath[:], k) + path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) if err != nil { return nil, nil, nil, false, err } - path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte _, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true) @@ -793,18 +782,15 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { // ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data // found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { - keyPath := make([]byte, t.hashFunction.Len()) - // if len(k) > t.hashFunction.Len() { // WIP - // return nil, nil, fmt.Errorf("len(k) > hashFunction.Len()") - // } + keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd copy(keyPath[:], k) + path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) if err != nil { return nil, nil, err } - path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte _, value, _, err := t.down(rTx, k, root, siblings, path, 0, true) @@ -827,7 +813,7 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, return false, err } - keyPath := make([]byte, hashFunc.Len()) + keyPath := make([]byte, len(siblings)) copy(keyPath[:], k) key, _, err := newLeafValue(hashFunc, k, v) diff --git a/tree_test.go b/tree_test.go index c794377..f6592c4 100644 --- a/tree_test.go +++ b/tree_test.go @@ -640,6 +640,40 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) { c.Assert(err, qt.Equals, ErrKeyNotFound) // and not equal to db.ErrKeyNotFound } +func TestKeyLen(t *testing.T) { + c := qt.New(t) + database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + // maxLevels is 100, keyPath length = ceil(maxLevels/8) = 13 + maxLevels := 100 + tree, err := NewTree(database, maxLevels, HashFunctionPoseidon) + c.Assert(err, qt.IsNil) + + // expect no errors when adding a key of only 4 bytes (when the + // required length of keyPath for 100 levels would be 13 bytes) + bLen := 4 + k := BigIntToBytes(bLen, big.NewInt(1)) + v := BigIntToBytes(bLen, big.NewInt(1)) + + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + + err = tree.Update(k, v) + c.Assert(err, qt.IsNil) + + _, _, _, _, err = tree.GenProof(k) + c.Assert(err, qt.IsNil) + + _, _, err = tree.Get(k) + c.Assert(err, qt.IsNil) + + k = BigIntToBytes(bLen, big.NewInt(2)) + v = BigIntToBytes(bLen, big.NewInt(2)) + invalids, err := tree.AddBatch([][]byte{k}, [][]byte{v}) + c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, 0) +} + func BenchmarkAdd(b *testing.B) { bLen := 32 // for both Poseidon & Sha256 // prepare inputs diff --git a/vt.go b/vt.go index 9232ec6..939c6b6 100644 --- a/vt.go +++ b/vt.go @@ -44,7 +44,7 @@ func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { } kvs := make([]kv, len(ks)) for i := 0; i < len(ks); i++ { - keyPath := make([]byte, p.hashFunction.Len()) + keyPath := make([]byte, int(math.Ceil(float64(p.maxLevels)/float64(8)))) //nolint:gomnd copy(keyPath[:], ks[i]) kvs[i].pos = i kvs[i].keyPath = keyPath