You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

941 lines
24 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
The hash function is configurable: for example, the Poseidon hash function
can be used when working with zkSNARKs, while the Blake3 hash function can be
used otherwise to reduce the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "encoding/binary"
  15. "encoding/hex"
  16. "fmt"
  17. "io"
  18. "math"
  19. "sync"
  20. "github.com/iden3/go-merkletree/db"
  21. )
  22. const (
  23. // PrefixValueLen defines the bytes-prefix length used for the Value
  24. // bytes representation stored in the db
  25. PrefixValueLen = 2
  26. // PrefixValueEmpty is used for the first byte of a Value to indicate
  27. // that is an Empty value
  28. PrefixValueEmpty = 0
  29. // PrefixValueLeaf is used for the first byte of a Value to indicate
  30. // that is a Leaf value
  31. PrefixValueLeaf = 1
  32. // PrefixValueIntermediate is used for the first byte of a Value to
  33. // indicate that is a Intermediate value
  34. PrefixValueIntermediate = 2
  35. // nChars is used to crop the Graphviz nodes labels
  36. nChars = 4
  37. )
  38. var (
  39. dbKeyRoot = []byte("root")
  40. dbKeyNLeafs = []byte("nleafs")
  41. emptyValue = []byte{0}
  42. // ErrKeyAlreadyExists is used when trying to add a key as leaf to the
  43. // tree that already exists.
  44. ErrKeyAlreadyExists = fmt.Errorf("key already exists")
  45. // ErrInvalidValuePrefix is used when going down into the tree, a value
  46. // is read from the db and has an unrecognized prefix.
  47. ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
  48. // ErrDBNoTx is used when trying to use Tree.dbPut but Tree.tx==nil
  49. ErrDBNoTx = fmt.Errorf("dbPut error: no db Tx")
  50. // ErrMaxLevel indicates when going down into the tree, the max level is
  51. // reached
  52. ErrMaxLevel = fmt.Errorf("max level reached")
  53. // ErrMaxVirtualLevel indicates when going down into the tree, the max
  54. // virtual level is reached
  55. ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
  56. )
  57. // Tree defines the struct that implements the MerkleTree functionalities
  58. type Tree struct {
  59. sync.RWMutex
  60. tx db.Tx
  61. db db.Storage
  62. maxLevels int
  63. root []byte
  64. hashFunction HashFunction
  65. // TODO in the methods that use it, check if emptyHash param is len>0
  66. // (check if it has been initialized)
  67. emptyHash []byte
  68. dbg *dbgStats
  69. }
  70. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  71. // will load it.
  72. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  73. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  74. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  75. root, err := t.dbGet(dbKeyRoot)
  76. if err == db.ErrNotFound {
  77. // store new root 0
  78. t.tx, err = t.db.NewTx()
  79. if err != nil {
  80. return nil, err
  81. }
  82. t.root = t.emptyHash
  83. if err = t.dbPut(dbKeyRoot, t.root); err != nil {
  84. return nil, err
  85. }
  86. if err = t.setNLeafs(0); err != nil {
  87. return nil, err
  88. }
  89. if err = t.tx.Commit(); err != nil {
  90. return nil, err
  91. }
  92. return &t, err
  93. } else if err != nil {
  94. return nil, err
  95. }
  96. t.root = root
  97. return &t, nil
  98. }
  99. // Root returns the root of the Tree
  100. func (t *Tree) Root() []byte {
  101. return t.root
  102. }
  103. // HashFunction returns Tree.hashFunction
  104. func (t *Tree) HashFunction() HashFunction {
  105. return t.hashFunction
  106. }
  107. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  108. // the indexes of the keys failed to add.
  109. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  110. t.Lock()
  111. defer t.Unlock()
  112. vt, err := t.loadVT()
  113. if err != nil {
  114. return nil, err
  115. }
  116. // TODO check validity of keys & values for Tree.hashFunction
  117. invalids, err := vt.addBatch(keys, values)
  118. if err != nil {
  119. return nil, err
  120. }
  121. // once the VirtualTree is build, compute the hashes
  122. pairs, err := vt.computeHashes()
  123. if err != nil {
  124. // TODO currently invalids in computeHashes are not counted
  125. return nil, err
  126. }
  127. t.root = vt.root.h
  128. // store pairs in db
  129. t.tx, err = t.db.NewTx()
  130. if err != nil {
  131. return nil, err
  132. }
  133. for i := 0; i < len(pairs); i++ {
  134. if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil {
  135. return nil, err
  136. }
  137. }
  138. // store root to db
  139. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  140. return nil, err
  141. }
  142. // update nLeafs
  143. if err := t.incNLeafs(len(keys) - len(invalids)); err != nil {
  144. return nil, err
  145. }
  146. // commit db tx
  147. if err := t.tx.Commit(); err != nil {
  148. return nil, err
  149. }
  150. return invalids, nil
  151. }
  152. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  153. // the same leafs.
  154. func (t *Tree) loadVT() (vt, error) {
  155. vt := newVT(t.maxLevels, t.hashFunction)
  156. vt.params.dbg = t.dbg
  157. err := t.Iterate(nil, func(k, v []byte) {
  158. if v[0] != PrefixValueLeaf {
  159. return
  160. }
  161. leafK, leafV := ReadLeafValue(v)
  162. if err := vt.add(0, leafK, leafV); err != nil {
  163. panic(err)
  164. }
  165. })
  166. return vt, err
  167. }
  168. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  169. // is expected that are represented by a Little-Endian byte array (for circom
  170. // compatibility).
  171. func (t *Tree) Add(k, v []byte) error {
  172. t.Lock()
  173. defer t.Unlock()
  174. var err error
  175. t.tx, err = t.db.NewTx()
  176. if err != nil {
  177. return err
  178. }
  179. // TODO check validity of key & value for Tree.hashFunction
  180. err = t.add(0, k, v) // add from level 0
  181. if err != nil {
  182. return err
  183. }
  184. // store root to db
  185. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  186. return err
  187. }
  188. // update nLeafs
  189. if err = t.incNLeafs(1); err != nil {
  190. return err
  191. }
  192. return t.tx.Commit()
  193. }
  194. func (t *Tree) add(fromLvl int, k, v []byte) error {
  195. keyPath := make([]byte, t.hashFunction.Len())
  196. copy(keyPath[:], k)
  197. path := getPath(t.maxLevels, keyPath)
  198. // go down to the leaf
  199. var siblings [][]byte
  200. _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
  201. if err != nil {
  202. return err
  203. }
  204. leafKey, leafValue, err := t.newLeafValue(k, v)
  205. if err != nil {
  206. return err
  207. }
  208. if err := t.dbPut(leafKey, leafValue); err != nil {
  209. return err
  210. }
  211. // go up to the root
  212. if len(siblings) == 0 {
  213. t.root = leafKey
  214. return nil
  215. }
  216. root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
  217. if err != nil {
  218. return err
  219. }
  220. t.root = root
  221. return nil
  222. }
  223. // down goes down to the leaf recursively
  224. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
  225. path []bool, currLvl int, getLeaf bool) (
  226. []byte, []byte, [][]byte, error) {
  227. if currLvl > t.maxLevels-1 {
  228. return nil, nil, nil, ErrMaxLevel
  229. }
  230. var err error
  231. var currValue []byte
  232. if bytes.Equal(currKey, t.emptyHash) {
  233. // empty value
  234. return currKey, emptyValue, siblings, nil
  235. }
  236. currValue, err = t.dbGet(currKey)
  237. if err != nil {
  238. return nil, nil, nil, err
  239. }
  240. switch currValue[0] {
  241. case PrefixValueEmpty: // empty
  242. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  243. // reaching this point
  244. // return currKey, empty, siblings, nil
  245. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  246. case PrefixValueLeaf: // leaf
  247. if bytes.Equal(newKey, currKey) {
  248. // TODO move this error msg to const & add test that
  249. // checks that adding a repeated key this error is
  250. // returned
  251. return nil, nil, nil, ErrKeyAlreadyExists
  252. }
  253. if !bytes.Equal(currValue, emptyValue) {
  254. if getLeaf {
  255. return currKey, currValue, siblings, nil
  256. }
  257. oldLeafKey, _ := ReadLeafValue(currValue)
  258. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  259. copy(oldLeafKeyFull[:], oldLeafKey)
  260. // if currKey is already used, go down until paths diverge
  261. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  262. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  263. if err != nil {
  264. return nil, nil, nil, err
  265. }
  266. }
  267. return currKey, currValue, siblings, nil
  268. case PrefixValueIntermediate: // intermediate
  269. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  270. return nil, nil, nil,
  271. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  272. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  273. }
  274. // collect siblings while going down
  275. if path[currLvl] {
  276. // right
  277. lChild, rChild := ReadIntermediateChilds(currValue)
  278. siblings = append(siblings, lChild)
  279. return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
  280. }
  281. // left
  282. lChild, rChild := ReadIntermediateChilds(currValue)
  283. siblings = append(siblings, rChild)
  284. return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
  285. default:
  286. return nil, nil, nil, ErrInvalidValuePrefix
  287. }
  288. }
  289. // downVirtually is used when in a leaf already exists, and a new leaf which
  290. // shares the path until the existing leaf is being added
  291. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  292. newPath []bool, currLvl int) ([][]byte, error) {
  293. var err error
  294. if currLvl > t.maxLevels-1 {
  295. return nil, ErrMaxVirtualLevel
  296. }
  297. if oldPath[currLvl] == newPath[currLvl] {
  298. siblings = append(siblings, t.emptyHash)
  299. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  300. if err != nil {
  301. return nil, err
  302. }
  303. return siblings, nil
  304. }
  305. // reached the divergence
  306. siblings = append(siblings, oldKey)
  307. return siblings, nil
  308. }
  309. // up goes up recursively updating the intermediate nodes
  310. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
  311. var k, v []byte
  312. var err error
  313. if path[currLvl+toLvl] {
  314. k, v, err = t.newIntermediate(siblings[currLvl], key)
  315. if err != nil {
  316. return nil, err
  317. }
  318. } else {
  319. k, v, err = t.newIntermediate(key, siblings[currLvl])
  320. if err != nil {
  321. return nil, err
  322. }
  323. }
  324. // store k-v to db
  325. if err = t.dbPut(k, v); err != nil {
  326. return nil, err
  327. }
  328. if currLvl == 0 {
  329. // reached the root
  330. return k, nil
  331. }
  332. return t.up(k, siblings, path, currLvl-1, toLvl)
  333. }
  334. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  335. t.dbg.incHash()
  336. return newLeafValue(t.hashFunction, k, v)
  337. }
  338. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  339. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  340. if err != nil {
  341. return nil, nil, err
  342. }
  343. var leafValue []byte
  344. leafValue = append(leafValue, byte(1))
  345. leafValue = append(leafValue, byte(len(k)))
  346. leafValue = append(leafValue, k...)
  347. leafValue = append(leafValue, v...)
  348. return leafKey, leafValue, nil
  349. }
  350. // ReadLeafValue reads from a byte array the leaf key & value
  351. func ReadLeafValue(b []byte) ([]byte, []byte) {
  352. if len(b) < PrefixValueLen {
  353. return []byte{}, []byte{}
  354. }
  355. kLen := b[1]
  356. if len(b) < PrefixValueLen+int(kLen) {
  357. return []byte{}, []byte{}
  358. }
  359. k := b[PrefixValueLen : PrefixValueLen+kLen]
  360. v := b[PrefixValueLen+kLen:]
  361. return k, v
  362. }
  363. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  364. t.dbg.incHash()
  365. return newIntermediate(t.hashFunction, l, r)
  366. }
  367. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  368. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  369. b[0] = 2
  370. b[1] = byte(len(l))
  371. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  372. copy(b[PrefixValueLen+hashFunc.Len():], r)
  373. key, err := hashFunc.Hash(l, r)
  374. if err != nil {
  375. return nil, nil, err
  376. }
  377. return key, b, nil
  378. }
  379. // ReadIntermediateChilds reads from a byte array the two childs keys
  380. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  381. if len(b) < PrefixValueLen {
  382. return []byte{}, []byte{}
  383. }
  384. lLen := b[1]
  385. if len(b) < PrefixValueLen+int(lLen) {
  386. return []byte{}, []byte{}
  387. }
  388. l := b[PrefixValueLen : PrefixValueLen+lLen]
  389. r := b[PrefixValueLen+lLen:]
  390. return l, r
  391. }
  392. func getPath(numLevels int, k []byte) []bool {
  393. path := make([]bool, numLevels)
  394. for n := 0; n < numLevels; n++ {
  395. path[n] = k[n/8]&(1<<(n%8)) != 0
  396. }
  397. return path
  398. }
  399. // Update updates the value for a given existing key. If the given key does not
  400. // exist, returns an error.
  401. func (t *Tree) Update(k, v []byte) error {
  402. t.Lock()
  403. defer t.Unlock()
  404. var err error
  405. t.tx, err = t.db.NewTx()
  406. if err != nil {
  407. return err
  408. }
  409. keyPath := make([]byte, t.hashFunction.Len())
  410. copy(keyPath[:], k)
  411. path := getPath(t.maxLevels, keyPath)
  412. var siblings [][]byte
  413. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  414. if err != nil {
  415. return err
  416. }
  417. oldKey, _ := ReadLeafValue(valueAtBottom)
  418. if !bytes.Equal(oldKey, k) {
  419. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  420. }
  421. leafKey, leafValue, err := t.newLeafValue(k, v)
  422. if err != nil {
  423. return err
  424. }
  425. if err := t.dbPut(leafKey, leafValue); err != nil {
  426. return err
  427. }
  428. // go up to the root
  429. if len(siblings) == 0 {
  430. t.root = leafKey
  431. return t.tx.Commit()
  432. }
  433. root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
  434. if err != nil {
  435. return err
  436. }
  437. t.root = root
  438. // store root to db
  439. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  440. return err
  441. }
  442. return t.tx.Commit()
  443. }
  444. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  445. // the Tree, the proof will be of existence, if the key does not exist in the
  446. // tree, the proof will be of non-existence.
  447. func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) {
  448. keyPath := make([]byte, t.hashFunction.Len())
  449. copy(keyPath[:], k)
  450. path := getPath(t.maxLevels, keyPath)
  451. // go down to the leaf
  452. var siblings [][]byte
  453. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  454. if err != nil {
  455. return nil, nil, err
  456. }
  457. leafK, leafV := ReadLeafValue(value)
  458. if !bytes.Equal(k, leafK) {
  459. fmt.Println("key not in Tree")
  460. fmt.Println(leafK)
  461. fmt.Println(leafV)
  462. // TODO proof of non-existence
  463. panic("unimplemented")
  464. }
  465. s := PackSiblings(t.hashFunction, siblings)
  466. return leafV, s, nil
  467. }
  468. // PackSiblings packs the siblings into a byte array.
  469. // [ 1 byte | L bytes | S * N bytes ]
  470. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  471. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  472. // array. And S is the size of the output of the hash function used for the
  473. // Tree.
  474. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  475. var b []byte
  476. var bitmap []bool
  477. emptySibling := make([]byte, hashFunc.Len())
  478. for i := 0; i < len(siblings); i++ {
  479. if bytes.Equal(siblings[i], emptySibling) {
  480. bitmap = append(bitmap, false)
  481. } else {
  482. bitmap = append(bitmap, true)
  483. b = append(b, siblings[i]...)
  484. }
  485. }
  486. bitmapBytes := bitmapToBytes(bitmap)
  487. l := len(bitmapBytes)
  488. res := make([]byte, l+1+len(b))
  489. res[0] = byte(l) // set the bitmapBytes length
  490. copy(res[1:1+l], bitmapBytes)
  491. copy(res[1+l:], b)
  492. return res
  493. }
  494. // UnpackSiblings unpacks the siblings from a byte array.
  495. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  496. l := b[0]
  497. bitmapBytes := b[1 : 1+l]
  498. bitmap := bytesToBitmap(bitmapBytes)
  499. siblingsBytes := b[1+l:]
  500. iSibl := 0
  501. emptySibl := make([]byte, hashFunc.Len())
  502. var siblings [][]byte
  503. for i := 0; i < len(bitmap); i++ {
  504. if iSibl >= len(siblingsBytes) {
  505. break
  506. }
  507. if bitmap[i] {
  508. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  509. iSibl += hashFunc.Len()
  510. } else {
  511. siblings = append(siblings, emptySibl)
  512. }
  513. }
  514. return siblings, nil
  515. }
  516. func bitmapToBytes(bitmap []bool) []byte {
  517. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  518. b := make([]byte, bitmapBytesLen)
  519. for i := 0; i < len(bitmap); i++ {
  520. if bitmap[i] {
  521. b[i/8] |= 1 << (i % 8)
  522. }
  523. }
  524. return b
  525. }
  526. func bytesToBitmap(b []byte) []bool {
  527. var bitmap []bool
  528. for i := 0; i < len(b); i++ {
  529. for j := 0; j < 8; j++ {
  530. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  531. }
  532. }
  533. return bitmap
  534. }
  535. // Get returns the value for a given key
  536. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  537. keyPath := make([]byte, t.hashFunction.Len())
  538. copy(keyPath[:], k)
  539. path := getPath(t.maxLevels, keyPath)
  540. // go down to the leaf
  541. var siblings [][]byte
  542. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  543. if err != nil {
  544. return nil, nil, err
  545. }
  546. leafK, leafV := ReadLeafValue(value)
  547. if !bytes.Equal(k, leafK) {
  548. return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  549. BytesToBigInt(k), BytesToBigInt(leafK))
  550. }
  551. return leafK, leafV, nil
  552. }
  553. // CheckProof verifies the given proof. The proof verification depends on the
  554. // HashFunction passed as parameter.
  555. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  556. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  557. if err != nil {
  558. return false, err
  559. }
  560. keyPath := make([]byte, hashFunc.Len())
  561. copy(keyPath[:], k)
  562. key, _, err := newLeafValue(hashFunc, k, v)
  563. if err != nil {
  564. return false, err
  565. }
  566. path := getPath(len(siblings), keyPath)
  567. for i := len(siblings) - 1; i >= 0; i-- {
  568. if path[i] {
  569. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  570. if err != nil {
  571. return false, err
  572. }
  573. } else {
  574. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  575. if err != nil {
  576. return false, err
  577. }
  578. }
  579. }
  580. if bytes.Equal(key[:], root) {
  581. return true, nil
  582. }
  583. return false, nil
  584. }
  585. func (t *Tree) dbPut(k, v []byte) error {
  586. if t.tx == nil {
  587. return ErrDBNoTx
  588. }
  589. t.dbg.incDbPut()
  590. return t.tx.Put(k, v)
  591. }
  592. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  593. // if key is empty, return empty as value
  594. if bytes.Equal(k, t.emptyHash) {
  595. return t.emptyHash, nil
  596. }
  597. t.dbg.incDbGet()
  598. v, err := t.db.Get(k)
  599. if err == nil {
  600. return v, nil
  601. }
  602. if t.tx != nil {
  603. return t.tx.Get(k)
  604. }
  605. return nil, db.ErrNotFound
  606. }
  607. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  608. // after the setNLeafs call.
  609. func (t *Tree) incNLeafs(nLeafs int) error {
  610. oldNLeafs, err := t.GetNLeafs()
  611. if err != nil {
  612. return err
  613. }
  614. newNLeafs := oldNLeafs + nLeafs
  615. return t.setNLeafs(newNLeafs)
  616. }
  617. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  618. // after the setNLeafs call.
  619. func (t *Tree) setNLeafs(nLeafs int) error {
  620. b := make([]byte, 8)
  621. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  622. if err := t.dbPut(dbKeyNLeafs, b); err != nil {
  623. return err
  624. }
  625. return nil
  626. }
  627. // GetNLeafs returns the number of Leafs of the Tree.
  628. func (t *Tree) GetNLeafs() (int, error) {
  629. b, err := t.dbGet(dbKeyNLeafs)
  630. if err != nil {
  631. return 0, err
  632. }
  633. nLeafs := binary.LittleEndian.Uint64(b)
  634. return int(nLeafs), nil
  635. }
  636. // Iterate iterates through the full Tree, executing the given function on each
  637. // node of the Tree.
  638. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error {
  639. // allow to define which root to use
  640. if rootKey == nil {
  641. rootKey = t.Root()
  642. }
  643. return t.iter(rootKey, f)
  644. }
  645. // IterateWithStop does the same than Iterate, but with int for the current
  646. // level, and a boolean parameter used by the passed function, is to indicate to
  647. // stop iterating on the branch when the method returns 'true'.
  648. func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error {
  649. // allow to define which root to use
  650. if rootKey == nil {
  651. rootKey = t.Root()
  652. }
  653. return t.iterWithStop(rootKey, 0, f)
  654. }
  655. func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
  656. v, err := t.dbGet(k)
  657. if err != nil {
  658. return err
  659. }
  660. currLevel++
  661. switch v[0] {
  662. case PrefixValueEmpty:
  663. f(currLevel, k, v)
  664. case PrefixValueLeaf:
  665. f(currLevel, k, v)
  666. case PrefixValueIntermediate:
  667. stop := f(currLevel, k, v)
  668. if stop {
  669. return nil
  670. }
  671. l, r := ReadIntermediateChilds(v)
  672. if err = t.iterWithStop(l, currLevel, f); err != nil {
  673. return err
  674. }
  675. if err = t.iterWithStop(r, currLevel, f); err != nil {
  676. return err
  677. }
  678. default:
  679. return ErrInvalidValuePrefix
  680. }
  681. return nil
  682. }
  683. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  684. f2 := func(currLvl int, k, v []byte) bool {
  685. f(k, v)
  686. return false
  687. }
  688. return t.iterWithStop(k, 0, f2)
  689. }
  690. // Dump exports all the Tree leafs in a byte array of length:
  691. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  692. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  693. // [ len(k) | len(v) | key | value ]
  694. // Where S is the size of the output of the hash function used for the Tree.
  695. func (t *Tree) Dump(rootKey []byte) ([]byte, error) {
  696. // allow to define which root to use
  697. if rootKey == nil {
  698. rootKey = t.Root()
  699. }
  700. // WARNING current encoding only supports key & values of 255 bytes each
  701. // (due using only 1 byte for the length headers).
  702. var b []byte
  703. err := t.Iterate(rootKey, func(k, v []byte) {
  704. if v[0] != PrefixValueLeaf {
  705. return
  706. }
  707. leafK, leafV := ReadLeafValue(v)
  708. kv := make([]byte, 2+len(leafK)+len(leafV))
  709. kv[0] = byte(len(leafK))
  710. kv[1] = byte(len(leafV))
  711. copy(kv[2:2+len(leafK)], leafK)
  712. copy(kv[2+len(leafK):], leafV)
  713. b = append(b, kv...)
  714. })
  715. return b, err
  716. }
  717. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  718. // method) in the Tree.
  719. func (t *Tree) ImportDump(b []byte) error {
  720. r := bytes.NewReader(b)
  721. var err error
  722. var keys, values [][]byte
  723. for {
  724. l := make([]byte, 2)
  725. _, err = io.ReadFull(r, l)
  726. if err == io.EOF {
  727. break
  728. } else if err != nil {
  729. return err
  730. }
  731. k := make([]byte, l[0])
  732. _, err = io.ReadFull(r, k)
  733. if err != nil {
  734. return err
  735. }
  736. v := make([]byte, l[1])
  737. _, err = io.ReadFull(r, v)
  738. if err != nil {
  739. return err
  740. }
  741. keys = append(keys, k)
  742. values = append(values, v)
  743. }
  744. if _, err = t.AddBatch(keys, values); err != nil {
  745. return err
  746. }
  747. return nil
  748. }
  749. // Graphviz iterates across the full tree to generate a string Graphviz
  750. // representation of the tree and writes it to w
  751. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  752. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  753. }
  754. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  755. // generate a string Graphviz representation of the first NLevels of the tree
  756. // and writes it to w
  757. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  758. fmt.Fprintf(w, `digraph hierarchy {
  759. node [fontname=Monospace,fontsize=10,shape=box]
  760. `)
  761. if rootKey == nil {
  762. rootKey = t.Root()
  763. }
  764. nEmpties := 0
  765. err := t.iterWithStop(rootKey, 0, func(currLvl int, k, v []byte) bool {
  766. if currLvl == untilLvl {
  767. return true // to stop the iter from going down
  768. }
  769. switch v[0] {
  770. case PrefixValueEmpty:
  771. case PrefixValueLeaf:
  772. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  773. // key & value from the leaf
  774. kB, vB := ReadLeafValue(v)
  775. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  776. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  777. hex.EncodeToString(vB[:nChars]))
  778. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  779. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  780. case PrefixValueIntermediate:
  781. l, r := ReadIntermediateChilds(v)
  782. lStr := hex.EncodeToString(l[:nChars])
  783. rStr := hex.EncodeToString(r[:nChars])
  784. eStr := ""
  785. if bytes.Equal(l, t.emptyHash) {
  786. lStr = fmt.Sprintf("empty%v", nEmpties)
  787. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  788. lStr)
  789. nEmpties++
  790. }
  791. if bytes.Equal(r, t.emptyHash) {
  792. rStr = fmt.Sprintf("empty%v", nEmpties)
  793. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  794. rStr)
  795. nEmpties++
  796. }
  797. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  798. lStr, rStr)
  799. fmt.Fprint(w, eStr)
  800. default:
  801. }
  802. return false
  803. })
  804. fmt.Fprintf(w, "}\n")
  805. return err
  806. }
  807. // PrintGraphviz prints the output of Tree.Graphviz
  808. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  809. if rootKey == nil {
  810. rootKey = t.Root()
  811. }
  812. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  813. }
  814. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  815. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  816. if rootKey == nil {
  817. rootKey = t.Root()
  818. }
  819. w := bytes.NewBufferString("")
  820. fmt.Fprintf(w,
  821. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  822. err := t.GraphvizFirstNLevels(w, rootKey, untilLvl)
  823. if err != nil {
  824. fmt.Println(w)
  825. return err
  826. }
  827. fmt.Fprintf(w,
  828. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  829. fmt.Println(w)
  830. return nil
  831. }
  832. // Purge WIP: unimplemented TODO
  833. func (t *Tree) Purge(keys [][]byte) error {
  834. return nil
  835. }
  836. // TODO circom proofs