You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

928 lines
23 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
  7. It also allows choosing which hash function to use. For example, when working
  8. with zkSnarks the Poseidon hash function can be used; otherwise the Blake3
  9. hash function can be used instead, which improves the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "encoding/binary"
  15. "encoding/hex"
  16. "fmt"
  17. "io"
  18. "math"
  19. "sync"
  20. "sync/atomic"
  21. "time"
  22. "github.com/iden3/go-merkletree/db"
  23. )
  24. const (
  25. // PrefixValueLen defines the bytes-prefix length used for the Value
  26. // bytes representation stored in the db
  27. PrefixValueLen = 2
  28. // PrefixValueEmpty is used for the first byte of a Value to indicate
  29. // that is an Empty value
  30. PrefixValueEmpty = 0
  31. // PrefixValueLeaf is used for the first byte of a Value to indicate
  32. // that is a Leaf value
  33. PrefixValueLeaf = 1
  34. // PrefixValueIntermediate is used for the first byte of a Value to
  35. // indicate that is a Intermediate value
  36. PrefixValueIntermediate = 2
  37. )
  38. var (
  39. dbKeyRoot = []byte("root")
  40. dbKeyNLeafs = []byte("nleafs")
  41. emptyValue = []byte{0}
  42. )
  43. // Tree defines the struct that implements the MerkleTree functionalities
  44. type Tree struct {
  45. sync.RWMutex
  46. tx db.Tx
  47. db db.Storage
  48. lastAccess int64 // in unix time // TODO delete, is a feature of a upper abstraction level
  49. maxLevels int
  50. root []byte
  51. hashFunction HashFunction
  52. // TODO in the methods that use it, check if emptyHash param is len>0
  53. // (check if it has been initialized)
  54. emptyHash []byte
  55. dbg *dbgStats
  56. }
  57. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  58. // will load it.
  59. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  60. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  61. t.updateAccessTime()
  62. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  63. root, err := t.dbGet(dbKeyRoot)
  64. if err == db.ErrNotFound {
  65. // store new root 0
  66. t.tx, err = t.db.NewTx()
  67. if err != nil {
  68. return nil, err
  69. }
  70. t.root = t.emptyHash
  71. if err = t.dbPut(dbKeyRoot, t.root); err != nil {
  72. return nil, err
  73. }
  74. if err = t.setNLeafs(0); err != nil {
  75. return nil, err
  76. }
  77. if err = t.tx.Commit(); err != nil {
  78. return nil, err
  79. }
  80. return &t, err
  81. } else if err != nil {
  82. return nil, err
  83. }
  84. t.root = root
  85. return &t, nil
  86. }
  87. func (t *Tree) updateAccessTime() {
  88. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  89. }
  90. // LastAccess returns the last access timestamp in Unixtime
  91. func (t *Tree) LastAccess() int64 {
  92. return atomic.LoadInt64(&t.lastAccess)
  93. }
  94. // Root returns the root of the Tree
  95. func (t *Tree) Root() []byte {
  96. return t.root
  97. }
  98. // HashFunction returns Tree.hashFunction
  99. func (t *Tree) HashFunction() HashFunction {
  100. return t.hashFunction
  101. }
  102. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  103. // the indexes of the keys failed to add.
  104. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  105. t.updateAccessTime()
  106. t.Lock()
  107. defer t.Unlock()
  108. vt, err := t.loadVT()
  109. if err != nil {
  110. return nil, err
  111. }
  112. // TODO check that keys & values is valid for Tree.hashFunction
  113. invalids, err := vt.addBatch(keys, values)
  114. if err != nil {
  115. return nil, err
  116. }
  117. // once the VirtualTree is build, compute the hashes
  118. pairs, err := vt.computeHashes()
  119. if err != nil {
  120. return nil, err
  121. }
  122. t.root = vt.root.h
  123. // store pairs in db
  124. t.tx, err = t.db.NewTx()
  125. if err != nil {
  126. return nil, err
  127. }
  128. for i := 0; i < len(pairs); i++ {
  129. if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil {
  130. return nil, err
  131. }
  132. }
  133. // store root to db
  134. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  135. return nil, err
  136. }
  137. // update nLeafs
  138. if err := t.incNLeafs(len(keys) - len(invalids)); err != nil {
  139. return nil, err
  140. }
  141. // commit db tx
  142. if err := t.tx.Commit(); err != nil {
  143. return nil, err
  144. }
  145. return invalids, nil
  146. }
  147. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  148. // the same leafs.
  149. func (t *Tree) loadVT() (vt, error) {
  150. vt := newVT(t.maxLevels, t.hashFunction)
  151. vt.params.dbg = t.dbg
  152. err := t.Iterate(func(k, v []byte) {
  153. if v[0] != PrefixValueLeaf {
  154. return
  155. }
  156. leafK, leafV := ReadLeafValue(v)
  157. if err := vt.add(0, leafK, leafV); err != nil {
  158. panic(err)
  159. }
  160. })
  161. return vt, err
  162. }
  163. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  164. // is expected that are represented by a Little-Endian byte array (for circom
  165. // compatibility).
  166. func (t *Tree) Add(k, v []byte) error {
  167. t.updateAccessTime()
  168. t.Lock()
  169. defer t.Unlock()
  170. var err error
  171. t.tx, err = t.db.NewTx()
  172. if err != nil {
  173. return err
  174. }
  175. err = t.add(0, k, v) // add from level 0
  176. if err != nil {
  177. return err
  178. }
  179. // store root to db
  180. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  181. return err
  182. }
  183. // update nLeafs
  184. if err = t.incNLeafs(1); err != nil {
  185. return err
  186. }
  187. return t.tx.Commit()
  188. }
  189. func (t *Tree) add(fromLvl int, k, v []byte) error {
  190. // TODO check validity of key & value (for the Tree.HashFunction type)
  191. keyPath := make([]byte, t.hashFunction.Len())
  192. copy(keyPath[:], k)
  193. path := getPath(t.maxLevels, keyPath)
  194. // go down to the leaf
  195. var siblings [][]byte
  196. _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
  197. if err != nil {
  198. return err
  199. }
  200. leafKey, leafValue, err := t.newLeafValue(k, v)
  201. if err != nil {
  202. return err
  203. }
  204. if err := t.dbPut(leafKey, leafValue); err != nil {
  205. return err
  206. }
  207. // go up to the root
  208. if len(siblings) == 0 {
  209. t.root = leafKey
  210. return nil
  211. }
  212. root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
  213. if err != nil {
  214. return err
  215. }
  216. t.root = root
  217. return nil
  218. }
  219. // down goes down to the leaf recursively
  220. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
  221. path []bool, currLvl int, getLeaf bool) (
  222. []byte, []byte, [][]byte, error) {
  223. if currLvl > t.maxLevels-1 {
  224. return nil, nil, nil, fmt.Errorf("max level")
  225. }
  226. var err error
  227. var currValue []byte
  228. if bytes.Equal(currKey, t.emptyHash) {
  229. // empty value
  230. return currKey, emptyValue, siblings, nil
  231. }
  232. currValue, err = t.dbGet(currKey)
  233. if err != nil {
  234. return nil, nil, nil, err
  235. }
  236. switch currValue[0] {
  237. case PrefixValueEmpty: // empty
  238. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  239. // reaching this point
  240. // return currKey, empty, siblings, nil
  241. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  242. case PrefixValueLeaf: // leaf
  243. if bytes.Equal(newKey, currKey) {
  244. // TODO move this error msg to const & add test that
  245. // checks that adding a repeated key this error is
  246. // returned
  247. return nil, nil, nil, fmt.Errorf("key already exists")
  248. }
  249. if !bytes.Equal(currValue, emptyValue) {
  250. if getLeaf {
  251. return currKey, currValue, siblings, nil
  252. }
  253. oldLeafKey, _ := ReadLeafValue(currValue)
  254. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  255. copy(oldLeafKeyFull[:], oldLeafKey)
  256. // if currKey is already used, go down until paths diverge
  257. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  258. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  259. if err != nil {
  260. return nil, nil, nil, err
  261. }
  262. }
  263. return currKey, currValue, siblings, nil
  264. case PrefixValueIntermediate: // intermediate
  265. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  266. return nil, nil, nil,
  267. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  268. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  269. }
  270. // collect siblings while going down
  271. if path[currLvl] {
  272. // right
  273. lChild, rChild := ReadIntermediateChilds(currValue)
  274. siblings = append(siblings, lChild)
  275. return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
  276. }
  277. // left
  278. lChild, rChild := ReadIntermediateChilds(currValue)
  279. siblings = append(siblings, rChild)
  280. return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
  281. default:
  282. return nil, nil, nil, fmt.Errorf("invalid value")
  283. }
  284. }
  285. // downVirtually is used when in a leaf already exists, and a new leaf which
  286. // shares the path until the existing leaf is being added
  287. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  288. newPath []bool, currLvl int) ([][]byte, error) {
  289. var err error
  290. if currLvl > t.maxLevels-1 {
  291. return nil, fmt.Errorf("max virtual level %d", currLvl)
  292. }
  293. if oldPath[currLvl] == newPath[currLvl] {
  294. siblings = append(siblings, t.emptyHash)
  295. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  296. if err != nil {
  297. return nil, err
  298. }
  299. return siblings, nil
  300. }
  301. // reached the divergence
  302. siblings = append(siblings, oldKey)
  303. return siblings, nil
  304. }
  305. // up goes up recursively updating the intermediate nodes
  306. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
  307. var k, v []byte
  308. var err error
  309. if path[currLvl+toLvl] {
  310. k, v, err = t.newIntermediate(siblings[currLvl], key)
  311. if err != nil {
  312. return nil, err
  313. }
  314. } else {
  315. k, v, err = t.newIntermediate(key, siblings[currLvl])
  316. if err != nil {
  317. return nil, err
  318. }
  319. }
  320. // store k-v to db
  321. if err = t.dbPut(k, v); err != nil {
  322. return nil, err
  323. }
  324. if currLvl == 0 {
  325. // reached the root
  326. return k, nil
  327. }
  328. return t.up(k, siblings, path, currLvl-1, toLvl)
  329. }
  330. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  331. t.dbg.incHash()
  332. return newLeafValue(t.hashFunction, k, v)
  333. }
  334. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  335. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  336. if err != nil {
  337. return nil, nil, err
  338. }
  339. var leafValue []byte
  340. leafValue = append(leafValue, byte(1))
  341. leafValue = append(leafValue, byte(len(k)))
  342. leafValue = append(leafValue, k...)
  343. leafValue = append(leafValue, v...)
  344. return leafKey, leafValue, nil
  345. }
  346. // ReadLeafValue reads from a byte array the leaf key & value
  347. func ReadLeafValue(b []byte) ([]byte, []byte) {
  348. if len(b) < PrefixValueLen {
  349. return []byte{}, []byte{}
  350. }
  351. kLen := b[1]
  352. if len(b) < PrefixValueLen+int(kLen) {
  353. return []byte{}, []byte{}
  354. }
  355. k := b[PrefixValueLen : PrefixValueLen+kLen]
  356. v := b[PrefixValueLen+kLen:]
  357. return k, v
  358. }
  359. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  360. t.dbg.incHash()
  361. return newIntermediate(t.hashFunction, l, r)
  362. }
  363. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  364. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  365. b[0] = 2
  366. b[1] = byte(len(l))
  367. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  368. copy(b[PrefixValueLen+hashFunc.Len():], r)
  369. key, err := hashFunc.Hash(l, r)
  370. if err != nil {
  371. return nil, nil, err
  372. }
  373. return key, b, nil
  374. }
  375. // ReadIntermediateChilds reads from a byte array the two childs keys
  376. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  377. if len(b) < PrefixValueLen {
  378. return []byte{}, []byte{}
  379. }
  380. lLen := b[1]
  381. if len(b) < PrefixValueLen+int(lLen) {
  382. return []byte{}, []byte{}
  383. }
  384. l := b[PrefixValueLen : PrefixValueLen+lLen]
  385. r := b[PrefixValueLen+lLen:]
  386. return l, r
  387. }
  388. func getPath(numLevels int, k []byte) []bool {
  389. path := make([]bool, numLevels)
  390. for n := 0; n < numLevels; n++ {
  391. path[n] = k[n/8]&(1<<(n%8)) != 0
  392. }
  393. return path
  394. }
  395. // Update updates the value for a given existing key. If the given key does not
  396. // exist, returns an error.
  397. func (t *Tree) Update(k, v []byte) error {
  398. t.updateAccessTime()
  399. t.Lock()
  400. defer t.Unlock()
  401. var err error
  402. t.tx, err = t.db.NewTx()
  403. if err != nil {
  404. return err
  405. }
  406. keyPath := make([]byte, t.hashFunction.Len())
  407. copy(keyPath[:], k)
  408. path := getPath(t.maxLevels, keyPath)
  409. var siblings [][]byte
  410. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  411. if err != nil {
  412. return err
  413. }
  414. oldKey, _ := ReadLeafValue(valueAtBottom)
  415. if !bytes.Equal(oldKey, k) {
  416. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  417. }
  418. leafKey, leafValue, err := t.newLeafValue(k, v)
  419. if err != nil {
  420. return err
  421. }
  422. if err := t.dbPut(leafKey, leafValue); err != nil {
  423. return err
  424. }
  425. // go up to the root
  426. if len(siblings) == 0 {
  427. t.root = leafKey
  428. return t.tx.Commit()
  429. }
  430. root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
  431. if err != nil {
  432. return err
  433. }
  434. t.root = root
  435. // store root to db
  436. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  437. return err
  438. }
  439. return t.tx.Commit()
  440. }
  441. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  442. // the Tree, the proof will be of existence, if the key does not exist in the
  443. // tree, the proof will be of non-existence.
  444. func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) {
  445. t.updateAccessTime()
  446. keyPath := make([]byte, t.hashFunction.Len())
  447. copy(keyPath[:], k)
  448. path := getPath(t.maxLevels, keyPath)
  449. // go down to the leaf
  450. var siblings [][]byte
  451. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  452. if err != nil {
  453. return nil, nil, err
  454. }
  455. leafK, leafV := ReadLeafValue(value)
  456. if !bytes.Equal(k, leafK) {
  457. fmt.Println("key not in Tree")
  458. fmt.Println(leafK)
  459. fmt.Println(leafV)
  460. // TODO proof of non-existence
  461. panic(fmt.Errorf("unimplemented"))
  462. }
  463. s := PackSiblings(t.hashFunction, siblings)
  464. return leafV, s, nil
  465. }
  466. // PackSiblings packs the siblings into a byte array.
  467. // [ 1 byte | L bytes | S * N bytes ]
  468. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  469. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  470. // array. And S is the size of the output of the hash function used for the
  471. // Tree.
  472. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  473. var b []byte
  474. var bitmap []bool
  475. emptySibling := make([]byte, hashFunc.Len())
  476. for i := 0; i < len(siblings); i++ {
  477. if bytes.Equal(siblings[i], emptySibling) {
  478. bitmap = append(bitmap, false)
  479. } else {
  480. bitmap = append(bitmap, true)
  481. b = append(b, siblings[i]...)
  482. }
  483. }
  484. bitmapBytes := bitmapToBytes(bitmap)
  485. l := len(bitmapBytes)
  486. res := make([]byte, l+1+len(b))
  487. res[0] = byte(l) // set the bitmapBytes length
  488. copy(res[1:1+l], bitmapBytes)
  489. copy(res[1+l:], b)
  490. return res
  491. }
  492. // UnpackSiblings unpacks the siblings from a byte array.
  493. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  494. l := b[0]
  495. bitmapBytes := b[1 : 1+l]
  496. bitmap := bytesToBitmap(bitmapBytes)
  497. siblingsBytes := b[1+l:]
  498. iSibl := 0
  499. emptySibl := make([]byte, hashFunc.Len())
  500. var siblings [][]byte
  501. for i := 0; i < len(bitmap); i++ {
  502. if iSibl >= len(siblingsBytes) {
  503. break
  504. }
  505. if bitmap[i] {
  506. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  507. iSibl += hashFunc.Len()
  508. } else {
  509. siblings = append(siblings, emptySibl)
  510. }
  511. }
  512. return siblings, nil
  513. }
  514. func bitmapToBytes(bitmap []bool) []byte {
  515. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  516. b := make([]byte, bitmapBytesLen)
  517. for i := 0; i < len(bitmap); i++ {
  518. if bitmap[i] {
  519. b[i/8] |= 1 << (i % 8)
  520. }
  521. }
  522. return b
  523. }
  524. func bytesToBitmap(b []byte) []bool {
  525. var bitmap []bool
  526. for i := 0; i < len(b); i++ {
  527. for j := 0; j < 8; j++ {
  528. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  529. }
  530. }
  531. return bitmap
  532. }
  533. // Get returns the value for a given key
  534. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  535. keyPath := make([]byte, t.hashFunction.Len())
  536. copy(keyPath[:], k)
  537. path := getPath(t.maxLevels, keyPath)
  538. // go down to the leaf
  539. var siblings [][]byte
  540. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  541. if err != nil {
  542. return nil, nil, err
  543. }
  544. leafK, leafV := ReadLeafValue(value)
  545. if !bytes.Equal(k, leafK) {
  546. panic(fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  547. BytesToBigInt(k), BytesToBigInt(leafK)))
  548. }
  549. return leafK, leafV, nil
  550. }
  551. // CheckProof verifies the given proof. The proof verification depends on the
  552. // HashFunction passed as parameter.
  553. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  554. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  555. if err != nil {
  556. return false, err
  557. }
  558. keyPath := make([]byte, hashFunc.Len())
  559. copy(keyPath[:], k)
  560. key, _, err := newLeafValue(hashFunc, k, v)
  561. if err != nil {
  562. return false, err
  563. }
  564. path := getPath(len(siblings), keyPath)
  565. for i := len(siblings) - 1; i >= 0; i-- {
  566. if path[i] {
  567. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  568. if err != nil {
  569. return false, err
  570. }
  571. } else {
  572. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  573. if err != nil {
  574. return false, err
  575. }
  576. }
  577. }
  578. if bytes.Equal(key[:], root) {
  579. return true, nil
  580. }
  581. return false, nil
  582. }
  583. func (t *Tree) dbPut(k, v []byte) error {
  584. if t.tx == nil {
  585. return fmt.Errorf("dbPut error: no db Tx")
  586. }
  587. t.dbg.incDbPut()
  588. return t.tx.Put(k, v)
  589. }
  590. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  591. // if key is empty, return empty as value
  592. if bytes.Equal(k, t.emptyHash) {
  593. return t.emptyHash, nil
  594. }
  595. t.dbg.incDbGet()
  596. v, err := t.db.Get(k)
  597. if err == nil {
  598. return v, nil
  599. }
  600. if t.tx != nil {
  601. return t.tx.Get(k)
  602. }
  603. return nil, db.ErrNotFound
  604. }
  605. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  606. // after the setNLeafs call.
  607. func (t *Tree) incNLeafs(nLeafs int) error {
  608. oldNLeafs, err := t.GetNLeafs()
  609. if err != nil {
  610. return err
  611. }
  612. newNLeafs := oldNLeafs + nLeafs
  613. return t.setNLeafs(newNLeafs)
  614. }
  615. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  616. // after the setNLeafs call.
  617. func (t *Tree) setNLeafs(nLeafs int) error {
  618. b := make([]byte, 8)
  619. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  620. if err := t.dbPut(dbKeyNLeafs, b); err != nil {
  621. return err
  622. }
  623. return nil
  624. }
  625. // GetNLeafs returns the number of Leafs of the Tree.
  626. func (t *Tree) GetNLeafs() (int, error) {
  627. b, err := t.dbGet(dbKeyNLeafs)
  628. if err != nil {
  629. return 0, err
  630. }
  631. nLeafs := binary.LittleEndian.Uint64(b)
  632. return int(nLeafs), nil
  633. }
  634. // Iterate iterates through the full Tree, executing the given function on each
  635. // node of the Tree.
  636. func (t *Tree) Iterate(f func([]byte, []byte)) error {
  637. // TODO allow to define which root to use
  638. t.updateAccessTime()
  639. return t.iter(t.root, f)
  640. }
  641. // IterateWithStop does the same than Iterate, but with int for the current
  642. // level, and a boolean parameter used by the passed function, is to indicate to
  643. // stop iterating on the branch when the method returns 'true'.
  644. func (t *Tree) IterateWithStop(f func(int, []byte, []byte) bool) error {
  645. t.updateAccessTime()
  646. return t.iterWithStop(t.root, 0, f)
  647. }
  648. func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
  649. v, err := t.dbGet(k)
  650. if err != nil {
  651. return err
  652. }
  653. currLevel++
  654. switch v[0] {
  655. case PrefixValueEmpty:
  656. f(currLevel, k, v)
  657. case PrefixValueLeaf:
  658. f(currLevel, k, v)
  659. case PrefixValueIntermediate:
  660. stop := f(currLevel, k, v)
  661. if stop {
  662. return nil
  663. }
  664. l, r := ReadIntermediateChilds(v)
  665. if err = t.iterWithStop(l, currLevel, f); err != nil {
  666. return err
  667. }
  668. if err = t.iterWithStop(r, currLevel, f); err != nil {
  669. return err
  670. }
  671. default:
  672. return fmt.Errorf("invalid value")
  673. }
  674. return nil
  675. }
  676. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  677. f2 := func(currLvl int, k, v []byte) bool {
  678. f(k, v)
  679. return false
  680. }
  681. return t.iterWithStop(k, 0, f2)
  682. }
  683. // Dump exports all the Tree leafs in a byte array of length:
  684. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  685. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  686. // [ len(k) | len(v) | key | value ]
  687. // Where S is the size of the output of the hash function used for the Tree.
  688. func (t *Tree) Dump() ([]byte, error) {
  689. t.updateAccessTime()
  690. // TODO allow to define which root to use
  691. // WARNING current encoding only supports key & values of 255 bytes each
  692. // (due using only 1 byte for the length headers).
  693. var b []byte
  694. err := t.Iterate(func(k, v []byte) {
  695. if v[0] != PrefixValueLeaf {
  696. return
  697. }
  698. leafK, leafV := ReadLeafValue(v)
  699. kv := make([]byte, 2+len(leafK)+len(leafV))
  700. kv[0] = byte(len(leafK))
  701. kv[1] = byte(len(leafV))
  702. copy(kv[2:2+len(leafK)], leafK)
  703. copy(kv[2+len(leafK):], leafV)
  704. b = append(b, kv...)
  705. })
  706. return b, err
  707. }
  708. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  709. // method) in the Tree.
  710. func (t *Tree) ImportDump(b []byte) error {
  711. t.updateAccessTime()
  712. r := bytes.NewReader(b)
  713. var err error
  714. var keys, values [][]byte
  715. for {
  716. l := make([]byte, 2)
  717. _, err = io.ReadFull(r, l)
  718. if err == io.EOF {
  719. break
  720. } else if err != nil {
  721. return err
  722. }
  723. k := make([]byte, l[0])
  724. _, err = io.ReadFull(r, k)
  725. if err != nil {
  726. return err
  727. }
  728. v := make([]byte, l[1])
  729. _, err = io.ReadFull(r, v)
  730. if err != nil {
  731. return err
  732. }
  733. keys = append(keys, k)
  734. values = append(values, v)
  735. }
  736. if _, err = t.AddBatch(keys, values); err != nil {
  737. return err
  738. }
  739. return nil
  740. }
  741. // Graphviz iterates across the full tree to generate a string Graphviz
  742. // representation of the tree and writes it to w
  743. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  744. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  745. }
  746. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  747. // generate a string Graphviz representation of the first NLevels of the tree
  748. // and writes it to w
  749. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  750. fmt.Fprintf(w, `digraph hierarchy {
  751. node [fontname=Monospace,fontsize=10,shape=box]
  752. `)
  753. nChars := 4 // TODO move to global constant
  754. nEmpties := 0
  755. err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool {
  756. if currLvl == untilLvl {
  757. return true // to stop the iter from going down
  758. }
  759. switch v[0] {
  760. case PrefixValueEmpty:
  761. case PrefixValueLeaf:
  762. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  763. // key & value from the leaf
  764. kB, vB := ReadLeafValue(v)
  765. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  766. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  767. hex.EncodeToString(vB[:nChars]))
  768. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  769. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  770. case PrefixValueIntermediate:
  771. l, r := ReadIntermediateChilds(v)
  772. lStr := hex.EncodeToString(l[:nChars])
  773. rStr := hex.EncodeToString(r[:nChars])
  774. eStr := ""
  775. if bytes.Equal(l, t.emptyHash) {
  776. lStr = fmt.Sprintf("empty%v", nEmpties)
  777. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  778. lStr)
  779. nEmpties++
  780. }
  781. if bytes.Equal(r, t.emptyHash) {
  782. rStr = fmt.Sprintf("empty%v", nEmpties)
  783. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  784. rStr)
  785. nEmpties++
  786. }
  787. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  788. lStr, rStr)
  789. fmt.Fprint(w, eStr)
  790. default:
  791. }
  792. return false
  793. })
  794. fmt.Fprintf(w, "}\n")
  795. return err
  796. }
  797. // PrintGraphviz prints the output of Tree.Graphviz
  798. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  799. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  800. }
  801. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  802. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  803. if rootKey == nil {
  804. rootKey = t.Root()
  805. }
  806. w := bytes.NewBufferString("")
  807. fmt.Fprintf(w,
  808. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  809. err := t.GraphvizFirstNLevels(w, nil, untilLvl)
  810. if err != nil {
  811. fmt.Println(w)
  812. return err
  813. }
  814. fmt.Fprintf(w,
  815. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  816. fmt.Println(w)
  817. return nil
  818. }
  819. // Purge WIP: unimplemented
  820. func (t *Tree) Purge(keys [][]byte) error {
  821. return nil
  822. }