You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

838 lines
20 KiB

3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
  7. It also allows choosing which hash function to use. For example, when working
  8. with zkSnarks the Poseidon hash function can be used; otherwise the Blake3
  9. hash function can be used instead, which improves the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "encoding/binary"
  15. "encoding/hex"
  16. "fmt"
  17. "io"
  18. "math"
  19. "sync"
  20. "sync/atomic"
  21. "time"
  22. "github.com/iden3/go-merkletree/db"
  23. )
  // PrefixValueLen is the number of prefix bytes that precede the payload in
  // the byte representation of a Value stored in the db.
  const PrefixValueLen = 2

  // Possible contents of the first byte of a stored Value.
  const (
  	// PrefixValueEmpty marks an Empty value.
  	PrefixValueEmpty = iota
  	// PrefixValueLeaf marks a Leaf value.
  	PrefixValueLeaf
  	// PrefixValueIntermediate marks an Intermediate value.
  	PrefixValueIntermediate
  )
  38. var (
  39. dbKeyRoot = []byte("root")
  40. dbKeyNLeafs = []byte("nleafs")
  41. emptyValue = []byte{0}
  42. )
  43. // Tree defines the struct that implements the MerkleTree functionalities
  44. type Tree struct {
  45. sync.RWMutex
  46. tx db.Tx
  47. db db.Storage
  48. lastAccess int64 // in unix time
  49. maxLevels int
  50. root []byte
  51. hashFunction HashFunction
  52. emptyHash []byte
  53. }
  54. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  55. // will load it.
  56. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  57. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  58. t.updateAccessTime()
  59. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  60. root, err := t.dbGet(dbKeyRoot)
  61. if err == db.ErrNotFound {
  62. // store new root 0
  63. t.tx, err = t.db.NewTx()
  64. if err != nil {
  65. return nil, err
  66. }
  67. t.root = t.emptyHash
  68. if err = t.tx.Put(dbKeyRoot, t.root); err != nil {
  69. return nil, err
  70. }
  71. if err = t.setNLeafs(0); err != nil {
  72. return nil, err
  73. }
  74. if err = t.tx.Commit(); err != nil {
  75. return nil, err
  76. }
  77. return &t, err
  78. } else if err != nil {
  79. return nil, err
  80. }
  81. t.root = root
  82. return &t, nil
  83. }
  84. func (t *Tree) updateAccessTime() {
  85. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  86. }
  87. // LastAccess returns the last access timestamp in Unixtime
  88. func (t *Tree) LastAccess() int64 {
  89. return atomic.LoadInt64(&t.lastAccess)
  90. }
  91. // Root returns the root of the Tree
  92. func (t *Tree) Root() []byte {
  93. return t.root
  94. }
  95. // AddBatch adds a batch of key-values to the Tree. This method will be
  96. // optimized to do some internal parallelization. Returns an array containing
  97. // the indexes of the keys failed to add.
  98. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  99. t.updateAccessTime()
  100. if len(keys) != len(values) {
  101. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  102. len(keys), len(values))
  103. }
  104. t.Lock()
  105. defer t.Unlock()
  106. var err error
  107. t.tx, err = t.db.NewTx()
  108. if err != nil {
  109. return nil, err
  110. }
  111. var indexes []int
  112. for i := 0; i < len(keys); i++ {
  113. err = t.add(keys[i], values[i])
  114. if err != nil {
  115. indexes = append(indexes, i)
  116. }
  117. }
  118. // store root to db
  119. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  120. return indexes, err
  121. }
  122. // update nLeafs
  123. if err = t.incNLeafs(uint64(len(keys) - len(indexes))); err != nil {
  124. return indexes, err
  125. }
  126. if err := t.tx.Commit(); err != nil {
  127. return nil, err
  128. }
  129. return indexes, nil
  130. }
  131. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  132. // is expected that are represented by a Little-Endian byte array (for circom
  133. // compatibility).
  134. func (t *Tree) Add(k, v []byte) error {
  135. t.updateAccessTime()
  136. t.Lock()
  137. defer t.Unlock()
  138. var err error
  139. t.tx, err = t.db.NewTx()
  140. if err != nil {
  141. return err
  142. }
  143. err = t.add(k, v)
  144. if err != nil {
  145. return err
  146. }
  147. // store root to db
  148. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  149. return err
  150. }
  151. // update nLeafs
  152. if err = t.incNLeafs(1); err != nil {
  153. return err
  154. }
  155. return t.tx.Commit()
  156. }
  157. func (t *Tree) add(k, v []byte) error {
  158. // TODO check validity of key & value (for the Tree.HashFunction type)
  159. keyPath := make([]byte, t.hashFunction.Len())
  160. copy(keyPath[:], k)
  161. path := getPath(t.maxLevels, keyPath)
  162. // go down to the leaf
  163. var siblings [][]byte
  164. _, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
  165. if err != nil {
  166. return err
  167. }
  168. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  169. if err != nil {
  170. return err
  171. }
  172. if err := t.tx.Put(leafKey, leafValue); err != nil {
  173. return err
  174. }
  175. // go up to the root
  176. if len(siblings) == 0 {
  177. t.root = leafKey
  178. return nil
  179. }
  180. root, err := t.up(leafKey, siblings, path, len(siblings)-1)
  181. if err != nil {
  182. return err
  183. }
  184. t.root = root
  185. return nil
  186. }
  187. // down goes down to the leaf recursively
  188. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
  189. path []bool, l int, getLeaf bool) (
  190. []byte, []byte, [][]byte, error) {
  191. if l > t.maxLevels-1 {
  192. return nil, nil, nil, fmt.Errorf("max level")
  193. }
  194. var err error
  195. var currValue []byte
  196. if bytes.Equal(currKey, t.emptyHash) {
  197. // empty value
  198. return currKey, emptyValue, siblings, nil
  199. }
  200. currValue, err = t.dbGet(currKey)
  201. if err != nil {
  202. return nil, nil, nil, err
  203. }
  204. switch currValue[0] {
  205. case PrefixValueEmpty: // empty
  206. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  207. // reaching this point
  208. // return currKey, empty, siblings, nil
  209. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  210. case PrefixValueLeaf: // leaf
  211. if bytes.Equal(newKey, currKey) {
  212. return nil, nil, nil, fmt.Errorf("key already exists")
  213. }
  214. if !bytes.Equal(currValue, emptyValue) {
  215. if getLeaf {
  216. return currKey, currValue, siblings, nil
  217. }
  218. oldLeafKey, _ := readLeafValue(currValue)
  219. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  220. copy(oldLeafKeyFull[:], oldLeafKey)
  221. // if currKey is already used, go down until paths diverge
  222. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  223. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
  224. if err != nil {
  225. return nil, nil, nil, err
  226. }
  227. }
  228. return currKey, currValue, siblings, nil
  229. case PrefixValueIntermediate: // intermediate
  230. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  231. return nil, nil, nil,
  232. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  233. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  234. }
  235. // collect siblings while going down
  236. if path[l] {
  237. // right
  238. lChild, rChild := readIntermediateChilds(currValue)
  239. siblings = append(siblings, lChild)
  240. return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
  241. }
  242. // left
  243. lChild, rChild := readIntermediateChilds(currValue)
  244. siblings = append(siblings, rChild)
  245. return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
  246. default:
  247. return nil, nil, nil, fmt.Errorf("invalid value")
  248. }
  249. }
  250. // downVirtually is used when in a leaf already exists, and a new leaf which
  251. // shares the path until the existing leaf is being added
  252. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  253. newPath []bool, l int) ([][]byte, error) {
  254. var err error
  255. if l > t.maxLevels-1 {
  256. return nil, fmt.Errorf("max virtual level %d", l)
  257. }
  258. if oldPath[l] == newPath[l] {
  259. siblings = append(siblings, t.emptyHash)
  260. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
  261. if err != nil {
  262. return nil, err
  263. }
  264. return siblings, nil
  265. }
  266. // reached the divergence
  267. siblings = append(siblings, oldKey)
  268. return siblings, nil
  269. }
  270. // up goes up recursively updating the intermediate nodes
  271. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
  272. var k, v []byte
  273. var err error
  274. if path[l] {
  275. k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
  276. if err != nil {
  277. return nil, err
  278. }
  279. } else {
  280. k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
  281. if err != nil {
  282. return nil, err
  283. }
  284. }
  285. // store k-v to db
  286. if err = t.tx.Put(k, v); err != nil {
  287. return nil, err
  288. }
  289. if l == 0 {
  290. // reached the root
  291. return k, nil
  292. }
  293. return t.up(k, siblings, path, l-1)
  294. }
  295. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  296. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  297. if err != nil {
  298. return nil, nil, err
  299. }
  300. var leafValue []byte
  301. leafValue = append(leafValue, byte(1))
  302. leafValue = append(leafValue, byte(len(k)))
  303. leafValue = append(leafValue, k...)
  304. leafValue = append(leafValue, v...)
  305. return leafKey, leafValue, nil
  306. }
  307. func readLeafValue(b []byte) ([]byte, []byte) {
  308. if len(b) < PrefixValueLen {
  309. return []byte{}, []byte{}
  310. }
  311. kLen := b[1]
  312. if len(b) < PrefixValueLen+int(kLen) {
  313. return []byte{}, []byte{}
  314. }
  315. k := b[PrefixValueLen : PrefixValueLen+kLen]
  316. v := b[PrefixValueLen+kLen:]
  317. return k, v
  318. }
  319. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  320. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  321. b[0] = 2
  322. b[1] = byte(len(l))
  323. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  324. copy(b[PrefixValueLen+hashFunc.Len():], r)
  325. key, err := hashFunc.Hash(l, r)
  326. if err != nil {
  327. return nil, nil, err
  328. }
  329. return key, b, nil
  330. }
  331. func readIntermediateChilds(b []byte) ([]byte, []byte) {
  332. if len(b) < PrefixValueLen {
  333. return []byte{}, []byte{}
  334. }
  335. lLen := b[1]
  336. if len(b) < PrefixValueLen+int(lLen) {
  337. return []byte{}, []byte{}
  338. }
  339. l := b[PrefixValueLen : PrefixValueLen+lLen]
  340. r := b[PrefixValueLen+lLen:]
  341. return l, r
  342. }
  // getPath returns the first numLevels bits of k as booleans, taking the
  // bits of each byte from the least significant one. k must hold at least
  // numLevels bits.
  func getPath(numLevels int, k []byte) []bool {
  	path := make([]bool, numLevels)
  	for level := range path {
  		byteIdx, bitIdx := level/8, uint(level%8)
  		path[level] = (k[byteIdx]>>bitIdx)&1 == 1
  	}
  	return path
  }
  350. // Update updates the value for a given existing key. If the given key does not
  351. // exist, returns an error.
  352. func (t *Tree) Update(k, v []byte) error {
  353. t.updateAccessTime()
  354. t.Lock()
  355. defer t.Unlock()
  356. var err error
  357. t.tx, err = t.db.NewTx()
  358. if err != nil {
  359. return err
  360. }
  361. keyPath := make([]byte, t.hashFunction.Len())
  362. copy(keyPath[:], k)
  363. path := getPath(t.maxLevels, keyPath)
  364. var siblings [][]byte
  365. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  366. if err != nil {
  367. return err
  368. }
  369. oldKey, _ := readLeafValue(valueAtBottom)
  370. if !bytes.Equal(oldKey, k) {
  371. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  372. }
  373. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  374. if err != nil {
  375. return err
  376. }
  377. if err := t.tx.Put(leafKey, leafValue); err != nil {
  378. return err
  379. }
  380. // go up to the root
  381. if len(siblings) == 0 {
  382. t.root = leafKey
  383. return t.tx.Commit()
  384. }
  385. root, err := t.up(leafKey, siblings, path, len(siblings)-1)
  386. if err != nil {
  387. return err
  388. }
  389. t.root = root
  390. // store root to db
  391. if err := t.tx.Put(dbKeyRoot, t.root); err != nil {
  392. return err
  393. }
  394. return t.tx.Commit()
  395. }
  396. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  397. // the Tree, the proof will be of existence, if the key does not exist in the
  398. // tree, the proof will be of non-existence.
  399. func (t *Tree) GenProof(k []byte) ([]byte, error) {
  400. t.updateAccessTime()
  401. keyPath := make([]byte, t.hashFunction.Len())
  402. copy(keyPath[:], k)
  403. path := getPath(t.maxLevels, keyPath)
  404. // go down to the leaf
  405. var siblings [][]byte
  406. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  407. if err != nil {
  408. return nil, err
  409. }
  410. leafK, leafV := readLeafValue(value)
  411. if !bytes.Equal(k, leafK) {
  412. fmt.Println("key not in Tree")
  413. fmt.Println(leafK)
  414. fmt.Println(leafV)
  415. // TODO proof of non-existence
  416. panic(fmt.Errorf("unimplemented"))
  417. }
  418. s := PackSiblings(t.hashFunction, siblings)
  419. return s, nil
  420. }
  421. // PackSiblings packs the siblings into a byte array.
  422. // [ 1 byte | L bytes | S * N bytes ]
  423. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  424. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  425. // array. And S is the size of the output of the hash function used for the
  426. // Tree.
  427. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  428. var b []byte
  429. var bitmap []bool
  430. emptySibling := make([]byte, hashFunc.Len())
  431. for i := 0; i < len(siblings); i++ {
  432. if bytes.Equal(siblings[i], emptySibling) {
  433. bitmap = append(bitmap, false)
  434. } else {
  435. bitmap = append(bitmap, true)
  436. b = append(b, siblings[i]...)
  437. }
  438. }
  439. bitmapBytes := bitmapToBytes(bitmap)
  440. l := len(bitmapBytes)
  441. res := make([]byte, l+1+len(b))
  442. res[0] = byte(l) // set the bitmapBytes length
  443. copy(res[1:1+l], bitmapBytes)
  444. copy(res[1+l:], b)
  445. return res
  446. }
  447. // UnpackSiblings unpacks the siblings from a byte array.
  448. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  449. l := b[0]
  450. bitmapBytes := b[1 : 1+l]
  451. bitmap := bytesToBitmap(bitmapBytes)
  452. siblingsBytes := b[1+l:]
  453. iSibl := 0
  454. emptySibl := make([]byte, hashFunc.Len())
  455. var siblings [][]byte
  456. for i := 0; i < len(bitmap); i++ {
  457. if iSibl >= len(siblingsBytes) {
  458. break
  459. }
  460. if bitmap[i] {
  461. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  462. iSibl += hashFunc.Len()
  463. } else {
  464. siblings = append(siblings, emptySibl)
  465. }
  466. }
  467. return siblings, nil
  468. }
  // bitmapToBytes packs the bitmap into ceil(len(bitmap)/8) bytes, storing
  // bit i of the bitmap in bit (i%8), counted from the least significant one,
  // of byte (i/8).
  func bitmapToBytes(bitmap []bool) []byte {
  	nBytes := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  	packed := make([]byte, nBytes)
  	for pos, set := range bitmap {
  		if !set {
  			continue
  		}
  		packed[pos/8] |= 1 << (pos % 8)
  	}
  	return packed
  }
  // bytesToBitmap expands b into a bitmap of 8 entries per byte, taking the
  // bits of each byte from the least significant one.
  func bytesToBitmap(b []byte) []bool {
  	var bitmap []bool
  	for _, octet := range b {
  		for bit := uint(0); bit < 8; bit++ {
  			bitmap = append(bitmap, (octet>>bit)&1 == 1)
  		}
  	}
  	return bitmap
  }
  488. // Get returns the value for a given key
  489. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  490. keyPath := make([]byte, t.hashFunction.Len())
  491. copy(keyPath[:], k)
  492. path := getPath(t.maxLevels, keyPath)
  493. // go down to the leaf
  494. var siblings [][]byte
  495. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  496. if err != nil {
  497. return nil, nil, err
  498. }
  499. leafK, leafV := readLeafValue(value)
  500. if !bytes.Equal(k, leafK) {
  501. panic(fmt.Errorf("%s != %s", BytesToBigInt(k), BytesToBigInt(leafK)))
  502. }
  503. return leafK, leafV, nil
  504. }
  505. // CheckProof verifies the given proof. The proof verification depends on the
  506. // HashFunction passed as parameter.
  507. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  508. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  509. if err != nil {
  510. return false, err
  511. }
  512. keyPath := make([]byte, hashFunc.Len())
  513. copy(keyPath[:], k)
  514. key, _, err := newLeafValue(hashFunc, k, v)
  515. if err != nil {
  516. return false, err
  517. }
  518. path := getPath(len(siblings), keyPath)
  519. for i := len(siblings) - 1; i >= 0; i-- {
  520. if path[i] {
  521. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  522. if err != nil {
  523. return false, err
  524. }
  525. } else {
  526. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  527. if err != nil {
  528. return false, err
  529. }
  530. }
  531. }
  532. if bytes.Equal(key[:], root) {
  533. return true, nil
  534. }
  535. return false, nil
  536. }
  537. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  538. // if key is empty, return empty as value
  539. if bytes.Equal(k, t.emptyHash) {
  540. return t.emptyHash, nil
  541. }
  542. v, err := t.db.Get(k)
  543. if err == nil {
  544. return v, nil
  545. }
  546. if t.tx != nil {
  547. return t.tx.Get(k)
  548. }
  549. return nil, db.ErrNotFound
  550. }
  551. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  552. // after the setNLeafs call.
  553. func (t *Tree) incNLeafs(nLeafs uint64) error {
  554. oldNLeafs, err := t.GetNLeafs()
  555. if err != nil {
  556. return err
  557. }
  558. newNLeafs := oldNLeafs + nLeafs
  559. return t.setNLeafs(newNLeafs)
  560. }
  561. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  562. // after the setNLeafs call.
  563. func (t *Tree) setNLeafs(nLeafs uint64) error {
  564. b := make([]byte, 8)
  565. binary.LittleEndian.PutUint64(b, nLeafs)
  566. if err := t.tx.Put(dbKeyNLeafs, b); err != nil {
  567. return err
  568. }
  569. return nil
  570. }
  571. // GetNLeafs returns the number of Leafs of the Tree.
  572. func (t *Tree) GetNLeafs() (uint64, error) {
  573. b, err := t.dbGet(dbKeyNLeafs)
  574. if err != nil {
  575. return 0, err
  576. }
  577. nLeafs := binary.LittleEndian.Uint64(b)
  578. return nLeafs, nil
  579. }
  580. // Iterate iterates through the full Tree, executing the given function on each
  581. // node of the Tree.
  582. func (t *Tree) Iterate(f func([]byte, []byte)) error {
  583. t.updateAccessTime()
  584. return t.iter(t.root, f)
  585. }
  586. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  587. v, err := t.dbGet(k)
  588. if err != nil {
  589. return err
  590. }
  591. switch v[0] {
  592. case PrefixValueEmpty:
  593. f(k, v)
  594. case PrefixValueLeaf:
  595. f(k, v)
  596. case PrefixValueIntermediate:
  597. f(k, v)
  598. l, r := readIntermediateChilds(v)
  599. if err = t.iter(l, f); err != nil {
  600. return err
  601. }
  602. if err = t.iter(r, f); err != nil {
  603. return err
  604. }
  605. default:
  606. return fmt.Errorf("invalid value")
  607. }
  608. return nil
  609. }
  610. // Dump exports all the Tree leafs in a byte array of length:
  611. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  612. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  613. // [ len(k) | len(v) | key | value ]
  614. // Where S is the size of the output of the hash function used for the Tree.
  615. func (t *Tree) Dump() ([]byte, error) {
  616. t.updateAccessTime()
  617. // WARNING current encoding only supports key & values of 255 bytes each
  618. // (due using only 1 byte for the length headers).
  619. var b []byte
  620. err := t.Iterate(func(k, v []byte) {
  621. if v[0] != PrefixValueLeaf {
  622. return
  623. }
  624. leafK, leafV := readLeafValue(v)
  625. kv := make([]byte, 2+len(leafK)+len(leafV))
  626. kv[0] = byte(len(leafK))
  627. kv[1] = byte(len(leafV))
  628. copy(kv[2:2+len(leafK)], leafK)
  629. copy(kv[2+len(leafK):], leafV)
  630. b = append(b, kv...)
  631. })
  632. return b, err
  633. }
  634. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  635. // method) in the Tree.
  636. func (t *Tree) ImportDump(b []byte) error {
  637. t.updateAccessTime()
  638. r := bytes.NewReader(b)
  639. count := 0
  640. // TODO instead of adding one by one, use AddBatch (once AddBatch is
  641. // optimized)
  642. var err error
  643. for {
  644. l := make([]byte, 2)
  645. _, err = io.ReadFull(r, l)
  646. if err == io.EOF {
  647. break
  648. } else if err != nil {
  649. return err
  650. }
  651. k := make([]byte, l[0])
  652. _, err = io.ReadFull(r, k)
  653. if err != nil {
  654. return err
  655. }
  656. v := make([]byte, l[1])
  657. _, err = io.ReadFull(r, v)
  658. if err != nil {
  659. return err
  660. }
  661. err = t.Add(k, v)
  662. if err != nil {
  663. return err
  664. }
  665. count++
  666. }
  667. // update nLeafs (once ImportDump uses AddBatch method, this will not be
  668. // needed)
  669. t.tx, err = t.db.NewTx()
  670. if err != nil {
  671. return err
  672. }
  673. if err := t.incNLeafs(uint64(count)); err != nil {
  674. return err
  675. }
  676. if err = t.tx.Commit(); err != nil {
  677. return err
  678. }
  679. return nil
  680. }
  681. // Graphviz iterates across the full tree to generate a string Graphviz
  682. // representation of the tree and writes it to w
  683. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  684. fmt.Fprintf(w, `digraph hierarchy {
  685. node [fontname=Monospace,fontsize=10,shape=box]
  686. `)
  687. nChars := 4
  688. nEmpties := 0
  689. err := t.Iterate(func(k, v []byte) {
  690. switch v[0] {
  691. case PrefixValueEmpty:
  692. case PrefixValueLeaf:
  693. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  694. // key & value from the leaf
  695. kB, vB := readLeafValue(v)
  696. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  697. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  698. hex.EncodeToString(vB[:nChars]))
  699. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  700. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  701. case PrefixValueIntermediate:
  702. l, r := readIntermediateChilds(v)
  703. lStr := hex.EncodeToString(l[:nChars])
  704. rStr := hex.EncodeToString(r[:nChars])
  705. eStr := ""
  706. if bytes.Equal(l, t.emptyHash) {
  707. lStr = fmt.Sprintf("empty%v", nEmpties)
  708. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  709. lStr)
  710. nEmpties++
  711. }
  712. if bytes.Equal(r, t.emptyHash) {
  713. rStr = fmt.Sprintf("empty%v", nEmpties)
  714. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  715. rStr)
  716. nEmpties++
  717. }
  718. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  719. lStr, rStr)
  720. fmt.Fprint(w, eStr)
  721. default:
  722. }
  723. })
  724. fmt.Fprintf(w, "}\n")
  725. return err
  726. }
  727. // PrintGraphviz prints the output of Tree.Graphviz
  728. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  729. if rootKey == nil {
  730. rootKey = t.Root()
  731. }
  732. w := bytes.NewBufferString("")
  733. fmt.Fprintf(w,
  734. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  735. err := t.Graphviz(w, nil)
  736. if err != nil {
  737. fmt.Println(w)
  738. return err
  739. }
  740. fmt.Fprintf(w,
  741. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  742. fmt.Println(w)
  743. return nil
  744. }
  745. // Purge WIP: unimplemented
  746. func (t *Tree) Purge(keys [][]byte) error {
  747. return nil
  748. }