You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

954 lines
25 KiB

3 years ago
3 years ago
3 years ago
3 years ago
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree (when using the Poseidon hash function),
following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.
It also allows choosing which hash function to use. For example, when working
with zkSNARKs the Poseidon hash function can be used; otherwise, the Blake3
hash function can be used, which improves the computation time.
*/
  11. package arbo
  12. import (
  13. "bytes"
  14. "encoding/binary"
  15. "encoding/hex"
  16. "fmt"
  17. "io"
  18. "math"
  19. "sync"
  20. "github.com/iden3/go-merkletree/db"
  21. )
  22. const (
  23. // PrefixValueLen defines the bytes-prefix length used for the Value
  24. // bytes representation stored in the db
  25. PrefixValueLen = 2
  26. // PrefixValueEmpty is used for the first byte of a Value to indicate
  27. // that is an Empty value
  28. PrefixValueEmpty = 0
  29. // PrefixValueLeaf is used for the first byte of a Value to indicate
  30. // that is a Leaf value
  31. PrefixValueLeaf = 1
  32. // PrefixValueIntermediate is used for the first byte of a Value to
  33. // indicate that is a Intermediate value
  34. PrefixValueIntermediate = 2
  35. // nChars is used to crop the Graphviz nodes labels
  36. nChars = 4
  37. )
  38. var (
  39. dbKeyRoot = []byte("root")
  40. dbKeyNLeafs = []byte("nleafs")
  41. emptyValue = []byte{0}
  42. // ErrKeyAlreadyExists is used when trying to add a key as leaf to the
  43. // tree that already exists.
  44. ErrKeyAlreadyExists = fmt.Errorf("key already exists")
  45. // ErrInvalidValuePrefix is used when going down into the tree, a value
  46. // is read from the db and has an unrecognized prefix.
  47. ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
  48. // ErrDBNoTx is used when trying to use Tree.dbPut but Tree.tx==nil
  49. ErrDBNoTx = fmt.Errorf("dbPut error: no db Tx")
  50. // ErrMaxLevel indicates when going down into the tree, the max level is
  51. // reached
  52. ErrMaxLevel = fmt.Errorf("max level reached")
  53. // ErrMaxVirtualLevel indicates when going down into the tree, the max
  54. // virtual level is reached
  55. ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
  56. )
// Tree defines the struct that implements the MerkleTree functionalities.
type Tree struct {
	// locked for writing by Add, AddBatch and Update for their whole duration
	sync.RWMutex
	// tx is the db transaction that dbPut writes into; write operations
	// create a new one with db.NewTx, and it is not cleared after Commit
	tx db.Tx
	// db is the underlying storage
	db db.Storage
	// maxLevels is the maximum depth of the Tree (also the length of the key
	// paths computed with getPath)
	maxLevels int
	// root is the key of the current root node
	root []byte
	// hashFunction computes the keys of leaf and intermediate nodes
	hashFunction HashFunction
	// emptyHash is the all-zero key (hashFunction.Len() bytes) that
	// represents an empty node
	// TODO in the methods that use it, check if emptyHash param is len>0
	// (check if it has been initialized)
	emptyHash []byte
	// dbg receives the hash / db get / db put counters
	dbg *dbgStats
}
  70. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  71. // will load it.
  72. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  73. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  74. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  75. root, err := t.dbGet(dbKeyRoot)
  76. if err == db.ErrNotFound {
  77. // store new root 0
  78. t.tx, err = t.db.NewTx()
  79. if err != nil {
  80. return nil, err
  81. }
  82. t.root = t.emptyHash
  83. if err = t.dbPut(dbKeyRoot, t.root); err != nil {
  84. return nil, err
  85. }
  86. if err = t.setNLeafs(0); err != nil {
  87. return nil, err
  88. }
  89. if err = t.tx.Commit(); err != nil {
  90. return nil, err
  91. }
  92. return &t, err
  93. } else if err != nil {
  94. return nil, err
  95. }
  96. t.root = root
  97. return &t, nil
  98. }
  99. // Root returns the root of the Tree
  100. func (t *Tree) Root() []byte {
  101. return t.root
  102. }
  103. // HashFunction returns Tree.hashFunction
  104. func (t *Tree) HashFunction() HashFunction {
  105. return t.hashFunction
  106. }
  107. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  108. // the indexes of the keys failed to add.
  109. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  110. t.Lock()
  111. defer t.Unlock()
  112. vt, err := t.loadVT()
  113. if err != nil {
  114. return nil, err
  115. }
  116. // TODO check validity of keys & values for Tree.hashFunction
  117. invalids, err := vt.addBatch(keys, values)
  118. if err != nil {
  119. return nil, err
  120. }
  121. // once the VirtualTree is build, compute the hashes
  122. pairs, err := vt.computeHashes()
  123. if err != nil {
  124. // TODO currently invalids in computeHashes are not counted
  125. return nil, err
  126. }
  127. t.root = vt.root.h
  128. // store pairs in db
  129. t.tx, err = t.db.NewTx()
  130. if err != nil {
  131. return nil, err
  132. }
  133. for i := 0; i < len(pairs); i++ {
  134. if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil {
  135. return nil, err
  136. }
  137. }
  138. // store root to db
  139. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  140. return nil, err
  141. }
  142. // update nLeafs
  143. if err := t.incNLeafs(len(keys) - len(invalids)); err != nil {
  144. return nil, err
  145. }
  146. // commit db tx
  147. if err := t.tx.Commit(); err != nil {
  148. return nil, err
  149. }
  150. return invalids, nil
  151. }
  152. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  153. // the same leafs.
  154. func (t *Tree) loadVT() (vt, error) {
  155. vt := newVT(t.maxLevels, t.hashFunction)
  156. vt.params.dbg = t.dbg
  157. err := t.Iterate(nil, func(k, v []byte) {
  158. if v[0] != PrefixValueLeaf {
  159. return
  160. }
  161. leafK, leafV := ReadLeafValue(v)
  162. if err := vt.add(0, leafK, leafV); err != nil {
  163. panic(err)
  164. }
  165. })
  166. return vt, err
  167. }
  168. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  169. // is expected that are represented by a Little-Endian byte array (for circom
  170. // compatibility).
  171. func (t *Tree) Add(k, v []byte) error {
  172. t.Lock()
  173. defer t.Unlock()
  174. var err error
  175. t.tx, err = t.db.NewTx()
  176. if err != nil {
  177. return err
  178. }
  179. // TODO check validity of key & value for Tree.hashFunction
  180. err = t.add(0, k, v) // add from level 0
  181. if err != nil {
  182. return err
  183. }
  184. // store root to db
  185. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  186. return err
  187. }
  188. // update nLeafs
  189. if err = t.incNLeafs(1); err != nil {
  190. return err
  191. }
  192. return t.tx.Commit()
  193. }
  194. func (t *Tree) add(fromLvl int, k, v []byte) error {
  195. keyPath := make([]byte, t.hashFunction.Len())
  196. copy(keyPath[:], k)
  197. path := getPath(t.maxLevels, keyPath)
  198. // go down to the leaf
  199. var siblings [][]byte
  200. _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
  201. if err != nil {
  202. return err
  203. }
  204. leafKey, leafValue, err := t.newLeafValue(k, v)
  205. if err != nil {
  206. return err
  207. }
  208. if err := t.dbPut(leafKey, leafValue); err != nil {
  209. return err
  210. }
  211. // go up to the root
  212. if len(siblings) == 0 {
  213. t.root = leafKey
  214. return nil
  215. }
  216. root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
  217. if err != nil {
  218. return err
  219. }
  220. t.root = root
  221. return nil
  222. }
// down goes down to the leaf recursively.
//
// Starting at the node currKey, located at level currLvl, it follows path
// (true = right child, false = left child) and, at every intermediate node,
// appends to siblings the key of the child that is NOT taken. It returns the
// key and value of the node where the descent stops, plus the siblings:
//   - empty node (currKey == t.emptyHash): currKey, emptyValue, siblings.
//   - leaf, with getLeaf == true (lookups): the leaf key and value.
//   - leaf, with getLeaf == false (inserting newKey): ErrKeyAlreadyExists if
//     the leaf already holds newKey; otherwise downVirtually appends the
//     extra siblings needed to place newKey next to the existing leaf.
//
// Errors: ErrMaxLevel when currLvl goes past t.maxLevels-1, errors from
// dbGet, an error for intermediate values of unexpected length, and
// ErrInvalidValuePrefix for an unknown type byte.
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
	path []bool, currLvl int, getLeaf bool) (
	[]byte, []byte, [][]byte, error) {
	if currLvl > t.maxLevels-1 {
		return nil, nil, nil, ErrMaxLevel
	}
	var err error
	var currValue []byte
	if bytes.Equal(currKey, t.emptyHash) {
		// empty value
		return currKey, emptyValue, siblings, nil
	}
	currValue, err = t.dbGet(currKey)
	if err != nil {
		return nil, nil, nil, err
	}
	switch currValue[0] {
	case PrefixValueEmpty: // empty
		// expected to be unreachable (empty keys return early above); this
		// temporary diagnostic is kept on purpose to collect reports
		fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
			hex.EncodeToString(newKey), hex.EncodeToString(currKey),
			currLvl, hex.EncodeToString(currValue))
		panic("This point should not be reached, as the 'if' above" +
			" should avoid reaching this point. This panic is temporary" +
			" for reporting purposes, will be deleted in future versions." +
			" Please paste this log (including the previous lines) in a" +
			" new issue: https://github.com/arnaucube/arbo/issues/new") // TMP
	case PrefixValueLeaf: // leaf
		// NOTE(review): currValue starts with PrefixValueLeaf here while
		// emptyValue is {0}, so this condition looks always true
		if !bytes.Equal(currValue, emptyValue) {
			if getLeaf {
				return currKey, currValue, siblings, nil
			}
			oldLeafKey, _ := ReadLeafValue(currValue)
			if bytes.Equal(newKey, oldLeafKey) {
				return nil, nil, nil, ErrKeyAlreadyExists
			}
			// zero-pad the existing leaf key to the hash length (as the
			// callers do for newKey) to derive its path
			oldLeafKeyFull := make([]byte, t.hashFunction.Len())
			copy(oldLeafKeyFull[:], oldLeafKey)
			// if currKey is already used, go down until paths diverge
			oldPath := getPath(t.maxLevels, oldLeafKeyFull)
			siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
			if err != nil {
				return nil, nil, nil, err
			}
		}
		return currKey, currValue, siblings, nil
	case PrefixValueIntermediate: // intermediate
		// an intermediate value is the 2-byte prefix plus two child keys
		if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
			return nil, nil, nil,
				fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
					PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
		}
		// collect siblings while going down
		if path[currLvl] {
			// right
			lChild, rChild := ReadIntermediateChilds(currValue)
			siblings = append(siblings, lChild)
			return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
		}
		// left
		lChild, rChild := ReadIntermediateChilds(currValue)
		siblings = append(siblings, rChild)
		return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
	default:
		return nil, nil, nil, ErrInvalidValuePrefix
	}
}
  290. // downVirtually is used when in a leaf already exists, and a new leaf which
  291. // shares the path until the existing leaf is being added
  292. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  293. newPath []bool, currLvl int) ([][]byte, error) {
  294. var err error
  295. if currLvl > t.maxLevels-1 {
  296. return nil, ErrMaxVirtualLevel
  297. }
  298. if oldPath[currLvl] == newPath[currLvl] {
  299. siblings = append(siblings, t.emptyHash)
  300. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  301. if err != nil {
  302. return nil, err
  303. }
  304. return siblings, nil
  305. }
  306. // reached the divergence
  307. siblings = append(siblings, oldKey)
  308. return siblings, nil
  309. }
  310. // up goes up recursively updating the intermediate nodes
  311. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
  312. var k, v []byte
  313. var err error
  314. if path[currLvl+toLvl] {
  315. k, v, err = t.newIntermediate(siblings[currLvl], key)
  316. if err != nil {
  317. return nil, err
  318. }
  319. } else {
  320. k, v, err = t.newIntermediate(key, siblings[currLvl])
  321. if err != nil {
  322. return nil, err
  323. }
  324. }
  325. // store k-v to db
  326. if err = t.dbPut(k, v); err != nil {
  327. return nil, err
  328. }
  329. if currLvl == 0 {
  330. // reached the root
  331. return k, nil
  332. }
  333. return t.up(k, siblings, path, currLvl-1, toLvl)
  334. }
  335. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  336. t.dbg.incHash()
  337. return newLeafValue(t.hashFunction, k, v)
  338. }
  339. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  340. // which is used as the leaf key. And the value is the concatenation of the
  341. // inputed key & value. The output of this function is used as key-value to
  342. // store the leaf in the DB.
  343. // [ 1 byte | 1 byte | N bytes | M bytes ]
  344. // [ type of node | length of key | key | value ]
  345. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  346. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  347. if err != nil {
  348. return nil, nil, err
  349. }
  350. var leafValue []byte
  351. leafValue = append(leafValue, byte(1))
  352. leafValue = append(leafValue, byte(len(k)))
  353. leafValue = append(leafValue, k...)
  354. leafValue = append(leafValue, v...)
  355. return leafKey, leafValue, nil
  356. }
  357. // ReadLeafValue reads from a byte array the leaf key & value
  358. func ReadLeafValue(b []byte) ([]byte, []byte) {
  359. if len(b) < PrefixValueLen {
  360. return []byte{}, []byte{}
  361. }
  362. kLen := b[1]
  363. if len(b) < PrefixValueLen+int(kLen) {
  364. return []byte{}, []byte{}
  365. }
  366. k := b[PrefixValueLen : PrefixValueLen+kLen]
  367. v := b[PrefixValueLen+kLen:]
  368. return k, v
  369. }
  370. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  371. t.dbg.incHash()
  372. return newIntermediate(t.hashFunction, l, r)
  373. }
  374. // newIntermediate takes the left & right keys of a intermediate node, and
  375. // computes its hash. Returns the hash of the node, which is the node key, and a
  376. // byte array that contains the value (which contains the left & right child
  377. // keys) to store in the DB.
  378. // [ 1 byte | 1 byte | N bytes | N bytes ]
  379. // [ type of node | length of key | left key | right key ]
  380. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  381. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  382. b[0] = 2
  383. b[1] = byte(len(l))
  384. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  385. copy(b[PrefixValueLen+hashFunc.Len():], r)
  386. key, err := hashFunc.Hash(l, r)
  387. if err != nil {
  388. return nil, nil, err
  389. }
  390. return key, b, nil
  391. }
  392. // ReadIntermediateChilds reads from a byte array the two childs keys
  393. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  394. if len(b) < PrefixValueLen {
  395. return []byte{}, []byte{}
  396. }
  397. lLen := b[1]
  398. if len(b) < PrefixValueLen+int(lLen) {
  399. return []byte{}, []byte{}
  400. }
  401. l := b[PrefixValueLen : PrefixValueLen+lLen]
  402. r := b[PrefixValueLen+lLen:]
  403. return l, r
  404. }
  405. func getPath(numLevels int, k []byte) []bool {
  406. path := make([]bool, numLevels)
  407. for n := 0; n < numLevels; n++ {
  408. path[n] = k[n/8]&(1<<(n%8)) != 0
  409. }
  410. return path
  411. }
  412. // Update updates the value for a given existing key. If the given key does not
  413. // exist, returns an error.
  414. func (t *Tree) Update(k, v []byte) error {
  415. t.Lock()
  416. defer t.Unlock()
  417. var err error
  418. t.tx, err = t.db.NewTx()
  419. if err != nil {
  420. return err
  421. }
  422. keyPath := make([]byte, t.hashFunction.Len())
  423. copy(keyPath[:], k)
  424. path := getPath(t.maxLevels, keyPath)
  425. var siblings [][]byte
  426. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  427. if err != nil {
  428. return err
  429. }
  430. oldKey, _ := ReadLeafValue(valueAtBottom)
  431. if !bytes.Equal(oldKey, k) {
  432. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  433. }
  434. leafKey, leafValue, err := t.newLeafValue(k, v)
  435. if err != nil {
  436. return err
  437. }
  438. if err := t.dbPut(leafKey, leafValue); err != nil {
  439. return err
  440. }
  441. // go up to the root
  442. if len(siblings) == 0 {
  443. t.root = leafKey
  444. return t.tx.Commit()
  445. }
  446. root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
  447. if err != nil {
  448. return err
  449. }
  450. t.root = root
  451. // store root to db
  452. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  453. return err
  454. }
  455. return t.tx.Commit()
  456. }
  457. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  458. // the Tree, the proof will be of existence, if the key does not exist in the
  459. // tree, the proof will be of non-existence.
  460. func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) {
  461. keyPath := make([]byte, t.hashFunction.Len())
  462. copy(keyPath[:], k)
  463. path := getPath(t.maxLevels, keyPath)
  464. // go down to the leaf
  465. var siblings [][]byte
  466. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  467. if err != nil {
  468. return nil, nil, err
  469. }
  470. leafK, leafV := ReadLeafValue(value)
  471. if !bytes.Equal(k, leafK) {
  472. fmt.Println("key not in Tree")
  473. fmt.Println(leafK)
  474. fmt.Println(leafV)
  475. // TODO proof of non-existence
  476. panic("unimplemented")
  477. }
  478. s := PackSiblings(t.hashFunction, siblings)
  479. return leafV, s, nil
  480. }
  481. // PackSiblings packs the siblings into a byte array.
  482. // [ 1 byte | L bytes | S * N bytes ]
  483. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  484. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  485. // array. And S is the size of the output of the hash function used for the
  486. // Tree.
  487. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  488. var b []byte
  489. var bitmap []bool
  490. emptySibling := make([]byte, hashFunc.Len())
  491. for i := 0; i < len(siblings); i++ {
  492. if bytes.Equal(siblings[i], emptySibling) {
  493. bitmap = append(bitmap, false)
  494. } else {
  495. bitmap = append(bitmap, true)
  496. b = append(b, siblings[i]...)
  497. }
  498. }
  499. bitmapBytes := bitmapToBytes(bitmap)
  500. l := len(bitmapBytes)
  501. res := make([]byte, l+1+len(b))
  502. res[0] = byte(l) // set the bitmapBytes length
  503. copy(res[1:1+l], bitmapBytes)
  504. copy(res[1+l:], b)
  505. return res
  506. }
  507. // UnpackSiblings unpacks the siblings from a byte array.
  508. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  509. l := b[0]
  510. bitmapBytes := b[1 : 1+l]
  511. bitmap := bytesToBitmap(bitmapBytes)
  512. siblingsBytes := b[1+l:]
  513. iSibl := 0
  514. emptySibl := make([]byte, hashFunc.Len())
  515. var siblings [][]byte
  516. for i := 0; i < len(bitmap); i++ {
  517. if iSibl >= len(siblingsBytes) {
  518. break
  519. }
  520. if bitmap[i] {
  521. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  522. iSibl += hashFunc.Len()
  523. } else {
  524. siblings = append(siblings, emptySibl)
  525. }
  526. }
  527. return siblings, nil
  528. }
  529. func bitmapToBytes(bitmap []bool) []byte {
  530. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  531. b := make([]byte, bitmapBytesLen)
  532. for i := 0; i < len(bitmap); i++ {
  533. if bitmap[i] {
  534. b[i/8] |= 1 << (i % 8)
  535. }
  536. }
  537. return b
  538. }
  539. func bytesToBitmap(b []byte) []bool {
  540. var bitmap []bool
  541. for i := 0; i < len(b); i++ {
  542. for j := 0; j < 8; j++ {
  543. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  544. }
  545. }
  546. return bitmap
  547. }
  548. // Get returns the value for a given key
  549. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  550. keyPath := make([]byte, t.hashFunction.Len())
  551. copy(keyPath[:], k)
  552. path := getPath(t.maxLevels, keyPath)
  553. // go down to the leaf
  554. var siblings [][]byte
  555. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  556. if err != nil {
  557. return nil, nil, err
  558. }
  559. leafK, leafV := ReadLeafValue(value)
  560. if !bytes.Equal(k, leafK) {
  561. return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  562. BytesToBigInt(k), BytesToBigInt(leafK))
  563. }
  564. return leafK, leafV, nil
  565. }
  566. // CheckProof verifies the given proof. The proof verification depends on the
  567. // HashFunction passed as parameter.
  568. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  569. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  570. if err != nil {
  571. return false, err
  572. }
  573. keyPath := make([]byte, hashFunc.Len())
  574. copy(keyPath[:], k)
  575. key, _, err := newLeafValue(hashFunc, k, v)
  576. if err != nil {
  577. return false, err
  578. }
  579. path := getPath(len(siblings), keyPath)
  580. for i := len(siblings) - 1; i >= 0; i-- {
  581. if path[i] {
  582. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  583. if err != nil {
  584. return false, err
  585. }
  586. } else {
  587. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  588. if err != nil {
  589. return false, err
  590. }
  591. }
  592. }
  593. if bytes.Equal(key[:], root) {
  594. return true, nil
  595. }
  596. return false, nil
  597. }
  598. func (t *Tree) dbPut(k, v []byte) error {
  599. if t.tx == nil {
  600. return ErrDBNoTx
  601. }
  602. t.dbg.incDbPut()
  603. return t.tx.Put(k, v)
  604. }
  605. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  606. // if key is empty, return empty as value
  607. if bytes.Equal(k, t.emptyHash) {
  608. return t.emptyHash, nil
  609. }
  610. t.dbg.incDbGet()
  611. v, err := t.db.Get(k)
  612. if err == nil {
  613. return v, nil
  614. }
  615. if t.tx != nil {
  616. return t.tx.Get(k)
  617. }
  618. return nil, db.ErrNotFound
  619. }
  620. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  621. // after the setNLeafs call.
  622. func (t *Tree) incNLeafs(nLeafs int) error {
  623. oldNLeafs, err := t.GetNLeafs()
  624. if err != nil {
  625. return err
  626. }
  627. newNLeafs := oldNLeafs + nLeafs
  628. return t.setNLeafs(newNLeafs)
  629. }
  630. // Warning: should be called with a Tree.tx created, and with a Tree.tx.Commit
  631. // after the setNLeafs call.
  632. func (t *Tree) setNLeafs(nLeafs int) error {
  633. b := make([]byte, 8)
  634. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  635. if err := t.dbPut(dbKeyNLeafs, b); err != nil {
  636. return err
  637. }
  638. return nil
  639. }
  640. // GetNLeafs returns the number of Leafs of the Tree.
  641. func (t *Tree) GetNLeafs() (int, error) {
  642. b, err := t.dbGet(dbKeyNLeafs)
  643. if err != nil {
  644. return 0, err
  645. }
  646. nLeafs := binary.LittleEndian.Uint64(b)
  647. return int(nLeafs), nil
  648. }
  649. // Iterate iterates through the full Tree, executing the given function on each
  650. // node of the Tree.
  651. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error {
  652. // allow to define which root to use
  653. if rootKey == nil {
  654. rootKey = t.Root()
  655. }
  656. return t.iter(rootKey, f)
  657. }
  658. // IterateWithStop does the same than Iterate, but with int for the current
  659. // level, and a boolean parameter used by the passed function, is to indicate to
  660. // stop iterating on the branch when the method returns 'true'.
  661. func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error {
  662. // allow to define which root to use
  663. if rootKey == nil {
  664. rootKey = t.Root()
  665. }
  666. return t.iterWithStop(rootKey, 0, f)
  667. }
  668. func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
  669. v, err := t.dbGet(k)
  670. if err != nil {
  671. return err
  672. }
  673. currLevel++
  674. switch v[0] {
  675. case PrefixValueEmpty:
  676. f(currLevel, k, v)
  677. case PrefixValueLeaf:
  678. f(currLevel, k, v)
  679. case PrefixValueIntermediate:
  680. stop := f(currLevel, k, v)
  681. if stop {
  682. return nil
  683. }
  684. l, r := ReadIntermediateChilds(v)
  685. if err = t.iterWithStop(l, currLevel, f); err != nil {
  686. return err
  687. }
  688. if err = t.iterWithStop(r, currLevel, f); err != nil {
  689. return err
  690. }
  691. default:
  692. return ErrInvalidValuePrefix
  693. }
  694. return nil
  695. }
  696. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  697. f2 := func(currLvl int, k, v []byte) bool {
  698. f(k, v)
  699. return false
  700. }
  701. return t.iterWithStop(k, 0, f2)
  702. }
  703. // Dump exports all the Tree leafs in a byte array of length:
  704. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  705. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  706. // [ len(k) | len(v) | key | value ]
  707. // Where S is the size of the output of the hash function used for the Tree.
  708. func (t *Tree) Dump(rootKey []byte) ([]byte, error) {
  709. // allow to define which root to use
  710. if rootKey == nil {
  711. rootKey = t.Root()
  712. }
  713. // WARNING current encoding only supports key & values of 255 bytes each
  714. // (due using only 1 byte for the length headers).
  715. var b []byte
  716. err := t.Iterate(rootKey, func(k, v []byte) {
  717. if v[0] != PrefixValueLeaf {
  718. return
  719. }
  720. leafK, leafV := ReadLeafValue(v)
  721. kv := make([]byte, 2+len(leafK)+len(leafV))
  722. kv[0] = byte(len(leafK))
  723. kv[1] = byte(len(leafV))
  724. copy(kv[2:2+len(leafK)], leafK)
  725. copy(kv[2+len(leafK):], leafV)
  726. b = append(b, kv...)
  727. })
  728. return b, err
  729. }
  730. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  731. // method) in the Tree.
  732. func (t *Tree) ImportDump(b []byte) error {
  733. r := bytes.NewReader(b)
  734. var err error
  735. var keys, values [][]byte
  736. for {
  737. l := make([]byte, 2)
  738. _, err = io.ReadFull(r, l)
  739. if err == io.EOF {
  740. break
  741. } else if err != nil {
  742. return err
  743. }
  744. k := make([]byte, l[0])
  745. _, err = io.ReadFull(r, k)
  746. if err != nil {
  747. return err
  748. }
  749. v := make([]byte, l[1])
  750. _, err = io.ReadFull(r, v)
  751. if err != nil {
  752. return err
  753. }
  754. keys = append(keys, k)
  755. values = append(values, v)
  756. }
  757. if _, err = t.AddBatch(keys, values); err != nil {
  758. return err
  759. }
  760. return nil
  761. }
  762. // Graphviz iterates across the full tree to generate a string Graphviz
  763. // representation of the tree and writes it to w
  764. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  765. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  766. }
  767. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  768. // generate a string Graphviz representation of the first NLevels of the tree
  769. // and writes it to w
  770. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  771. fmt.Fprintf(w, `digraph hierarchy {
  772. node [fontname=Monospace,fontsize=10,shape=box]
  773. `)
  774. if rootKey == nil {
  775. rootKey = t.Root()
  776. }
  777. nEmpties := 0
  778. err := t.iterWithStop(rootKey, 0, func(currLvl int, k, v []byte) bool {
  779. if currLvl == untilLvl {
  780. return true // to stop the iter from going down
  781. }
  782. switch v[0] {
  783. case PrefixValueEmpty:
  784. case PrefixValueLeaf:
  785. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  786. // key & value from the leaf
  787. kB, vB := ReadLeafValue(v)
  788. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  789. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  790. hex.EncodeToString(vB[:nChars]))
  791. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  792. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  793. case PrefixValueIntermediate:
  794. l, r := ReadIntermediateChilds(v)
  795. lStr := hex.EncodeToString(l[:nChars])
  796. rStr := hex.EncodeToString(r[:nChars])
  797. eStr := ""
  798. if bytes.Equal(l, t.emptyHash) {
  799. lStr = fmt.Sprintf("empty%v", nEmpties)
  800. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  801. lStr)
  802. nEmpties++
  803. }
  804. if bytes.Equal(r, t.emptyHash) {
  805. rStr = fmt.Sprintf("empty%v", nEmpties)
  806. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  807. rStr)
  808. nEmpties++
  809. }
  810. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  811. lStr, rStr)
  812. fmt.Fprint(w, eStr)
  813. default:
  814. }
  815. return false
  816. })
  817. fmt.Fprintf(w, "}\n")
  818. return err
  819. }
  820. // PrintGraphviz prints the output of Tree.Graphviz
  821. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  822. if rootKey == nil {
  823. rootKey = t.Root()
  824. }
  825. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  826. }
  827. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  828. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  829. if rootKey == nil {
  830. rootKey = t.Root()
  831. }
  832. w := bytes.NewBufferString("")
  833. fmt.Fprintf(w,
  834. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  835. err := t.GraphvizFirstNLevels(w, rootKey, untilLvl)
  836. if err != nil {
  837. fmt.Println(w)
  838. return err
  839. }
  840. fmt.Fprintf(w,
  841. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  842. fmt.Println(w)
  843. return nil
  844. }
  845. // Purge WIP: unimplemented TODO
  846. func (t *Tree) Purge(keys [][]byte) error {
  847. return nil
  848. }
  849. // TODO circom proofs