diff --git a/tree.go b/tree.go index a594b30..a5f3f57 100644 --- a/tree.go +++ b/tree.go @@ -39,6 +39,9 @@ const ( // nChars is used to crop the Graphviz nodes labels nChars = 4 + + maxUint8 = int(^uint8(0)) // 2**8 -1 + maxUint16 = int(^uint16(0)) // 2**16 -1 ) var ( @@ -243,16 +246,21 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err func (t *Tree) loadVT(rTx db.ReadTx) (vt, error) { vt := newVT(t.maxLevels, t.hashFunction) vt.params.dbg = t.dbg - err := t.IterateWithTx(rTx, nil, func(k, v []byte) { + var callbackErr error + err := t.IterateWithStopWithTx(rTx, nil, func(_ int, k, v []byte) bool { if v[0] != PrefixValueLeaf { - return + return false } leafK, leafV := ReadLeafValue(v) if err := vt.add(0, leafK, leafV); err != nil { - // TODO instead of panic, return this error - panic(err) + callbackErr = err + return true } + return false }) + if callbackErr != nil { + return vt, callbackErr + } return vt, err } @@ -304,6 +312,9 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { keyPath := make([]byte, t.hashFunction.Len()) + // if len(k) > t.hashFunction.Len() { // WIP + // return nil, fmt.Errorf("len(k) > hashFunction.Len()") + // } copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) @@ -360,10 +371,10 @@ func (t *Tree) down(rTx db.ReadTx, newKey, currKey []byte, siblings [][]byte, fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n", hex.EncodeToString(newKey), hex.EncodeToString(currKey), currLvl, hex.EncodeToString(currValue)) - panic("This point should not be reached, as the 'if' above" + - " should avoid reaching this point. This panic is temporary" + + panic("This point should not be reached, as the 'if currKey==t.emptyHash'" + + " above should avoid reaching this point. This panic is temporary" + " for reporting purposes, will be deleted in future versions." + - " Please paste this log (including the previous lines) in a" + + " Please paste this log (including the previous log lines) in a" + " new issue: https://github.com/vocdoni/arbo/issues/new") // TMP case PrefixValueLeaf: // leaf if !bytes.Equal(currValue, emptyValue) { @@ -376,6 +387,10 @@ func (t *Tree) down(rTx db.ReadTx, newKey, currKey []byte, siblings [][]byte, } oldLeafKeyFull := make([]byte, t.hashFunction.Len()) + // if len(oldLeafKey) > t.hashFunction.Len() { // WIP + // return nil, nil, nil, + // fmt.Errorf("len(oldLeafKey) > hashFunction.Len()") + // } copy(oldLeafKeyFull[:], oldLeafKey) // if currKey is already used, go down until paths diverge @@ -479,6 +494,9 @@ func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) { } var leafValue []byte leafValue = append(leafValue, byte(PrefixValueLeaf)) + if len(k) > maxUint8 { + return nil, nil, fmt.Errorf("newLeafValue: len(k) > %v", maxUint8) + } leafValue = append(leafValue, byte(len(k))) leafValue = append(leafValue, k...) leafValue = append(leafValue, v...) @@ -514,6 +532,9 @@ func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) { func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) { b := make([]byte, PrefixValueLen+hashFunc.Len()*2) b[0] = PrefixValueIntermediate + if len(l) > maxUint8 { + return nil, nil, fmt.Errorf("newIntermediate: len(l) > %v", maxUint8) + } b[1] = byte(len(l)) copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l) copy(b[PrefixValueLen+hashFunc.Len():], r) @@ -575,6 +596,9 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { var err error keyPath := make([]byte, t.hashFunction.Len()) + // if len(k) > t.hashFunction.Len() { // WIP + // return fmt.Errorf("len(k) > hashFunction.Len()") + // } copy(keyPath[:], k) path := getPath(t.maxLevels, keyPath) @@ -632,6 +656,9 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) { // the db.ReadTx that is used. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) { keyPath := make([]byte, t.hashFunction.Len()) + // if len(k) > t.hashFunction.Len() { // WIP + // return nil, nil, nil, false, fmt.Errorf("len(k) > hashFunction.Len()") + // } copy(keyPath[:], k) root, err := t.RootWithTx(rTx) @@ -647,7 +674,10 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, return nil, nil, nil, false, err } - s := PackSiblings(t.hashFunction, siblings) + s, err := PackSiblings(t.hashFunction, siblings) + if err != nil { + return nil, nil, nil, false, err + } leafK, leafV := ReadLeafValue(value) if !bytes.Equal(k, leafK) { @@ -665,7 +695,7 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, // array. And S is the size of the output of the hash function used for the // Tree. The 2 2-byte that define the full length and bitmap length, are // encoded in little-endian. -func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { +func PackSiblings(hashFunc HashFunction, siblings [][]byte) ([]byte, error) { var b []byte var bitmap []bool emptySibling := make([]byte, hashFunc.Len()) @@ -680,14 +710,20 @@ func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { bitmapBytes := bitmapToBytes(bitmap) l := len(bitmapBytes) + if l > maxUint16 { + return nil, fmt.Errorf("PackSiblings: bitmapBytes length > %v", maxUint16) + } fullLen := 4 + l + len(b) //nolint:gomnd + if fullLen > maxUint16 { + return nil, fmt.Errorf("PackSiblings: fullLen > %v", maxUint16) + } res := make([]byte, fullLen) binary.LittleEndian.PutUint16(res[0:2], uint16(fullLen)) // set full length binary.LittleEndian.PutUint16(res[2:4], uint16(l)) // set the bitmapBytes length copy(res[4:4+l], bitmapBytes) copy(res[4+l:], b) - return res + return res, nil } // UnpackSiblings unpacks the siblings from a byte array. @@ -696,7 +732,7 @@ func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) { l := binary.LittleEndian.Uint16(b[2:4]) // bitmap bytes length if len(b) != int(fullLen) { return nil, - fmt.Errorf("error unpacking siblings. Expected len: %d, current len: %d", + fmt.Errorf("expected len: %d, current len: %d", fullLen, len(b)) } @@ -758,6 +794,9 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { // found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { keyPath := make([]byte, t.hashFunction.Len()) + // if len(k) > t.hashFunction.Len() { // WIP + // return nil, nil, fmt.Errorf("len(k) > hashFunction.Len()") + // } copy(keyPath[:], k) root, err := t.RootWithTx(rTx) @@ -1038,18 +1077,31 @@ func (t *Tree) Dump(fromRoot []byte) ([]byte, error) { // WARNING current encoding only supports key & values of 255 bytes each // (due using only 1 byte for the length headers). var b []byte - err := t.Iterate(fromRoot, func(k, v []byte) { + var callbackErr error + err := t.IterateWithStop(fromRoot, func(_ int, k, v []byte) bool { if v[0] != PrefixValueLeaf { - return + return false } leafK, leafV := ReadLeafValue(v) kv := make([]byte, 2+len(leafK)+len(leafV)) + if len(leafK) > maxUint8 { + callbackErr = fmt.Errorf("len(leafK) > %v", maxUint8) + return true + } kv[0] = byte(len(leafK)) + if len(leafV) > maxUint8 { + callbackErr = fmt.Errorf("len(leafV) > %v", maxUint8) + return true + } kv[1] = byte(len(leafV)) copy(kv[2:2+len(leafK)], leafK) copy(kv[2+len(leafK):], leafV) b = append(b, kv...) + return false }) + if callbackErr != nil { + return nil, callbackErr + } return b, err }