You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1045 lines
28 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree, following the specification from
  4. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  5. https://eprint.iacr.org/2018/955.
  6. Allows to define which hash function to use. So for example, when working with
  7. zkSnarks the Poseidon hash function can be used, but when not, it can be used
  8. the Blake2b hash function, which has much faster computation time.
  9. */
  10. package arbo
  11. import (
  12. "bytes"
  13. "encoding/binary"
  14. "encoding/hex"
  15. "fmt"
  16. "io"
  17. "math"
  18. "sync"
  19. "go.vocdoni.io/dvote/db"
  20. )
  21. const (
  22. // PrefixValueLen defines the bytes-prefix length used for the Value
  23. // bytes representation stored in the db
  24. PrefixValueLen = 2
  25. // PrefixValueEmpty is used for the first byte of a Value to indicate
  26. // that is an Empty value
  27. PrefixValueEmpty = 0
  28. // PrefixValueLeaf is used for the first byte of a Value to indicate
  29. // that is a Leaf value
  30. PrefixValueLeaf = 1
  31. // PrefixValueIntermediate is used for the first byte of a Value to
  32. // indicate that is a Intermediate value
  33. PrefixValueIntermediate = 2
  34. // nChars is used to crop the Graphviz nodes labels
  35. nChars = 4
  36. )
  37. var (
  38. dbKeyRoot = []byte("root")
  39. dbKeyNLeafs = []byte("nleafs")
  40. emptyValue = []byte{0}
  41. // ErrKeyNotFound is used when a key is not found in the db neither in
  42. // the current db Batch.
  43. ErrKeyNotFound = fmt.Errorf("key not found")
  44. // ErrKeyAlreadyExists is used when trying to add a key as leaf to the
  45. // tree that already exists.
  46. ErrKeyAlreadyExists = fmt.Errorf("key already exists")
  47. // ErrInvalidValuePrefix is used when going down into the tree, a value
  48. // is read from the db and has an unrecognized prefix.
  49. ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
  50. // ErrDBNoTx is used when trying to use Tree.dbPut but Tree.dbBatch==nil
  51. ErrDBNoTx = fmt.Errorf("dbPut error: no db Batch")
  52. // ErrMaxLevel indicates when going down into the tree, the max level is
  53. // reached
  54. ErrMaxLevel = fmt.Errorf("max level reached")
  55. // ErrMaxVirtualLevel indicates when going down into the tree, the max
  56. // virtual level is reached
  57. ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
  58. )
  59. // Tree defines the struct that implements the MerkleTree functionalities
  60. type Tree struct {
  61. sync.RWMutex
  62. db db.Database
  63. maxLevels int
  64. root []byte
  65. hashFunction HashFunction
  66. // TODO in the methods that use it, check if emptyHash param is len>0
  67. // (check if it has been initialized)
  68. emptyHash []byte
  69. dbg *dbgStats
  70. }
  71. // NewTree returns a new Tree, if there is a Tree still in the given database, it
  72. // will load it.
  73. func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, error) {
  74. t := Tree{db: database, maxLevels: maxLevels, hashFunction: hash}
  75. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  76. wTx := t.db.WriteTx()
  77. defer wTx.Discard()
  78. root, err := wTx.Get(dbKeyRoot)
  79. if err == db.ErrKeyNotFound {
  80. // store new root 0
  81. t.root = t.emptyHash
  82. if err = wTx.Set(dbKeyRoot, t.root); err != nil {
  83. return nil, err
  84. }
  85. if err = t.setNLeafs(wTx, 0); err != nil {
  86. return nil, err
  87. }
  88. if err = wTx.Commit(); err != nil {
  89. return nil, err
  90. }
  91. return &t, nil
  92. } else if err != nil {
  93. return nil, err
  94. }
  95. if err = wTx.Commit(); err != nil {
  96. return nil, err
  97. }
  98. t.root = root
  99. return &t, nil
  100. }
  101. // Root returns the root of the Tree
  102. func (t *Tree) Root() []byte {
  103. // TODO get Root from db
  104. return t.root
  105. }
  106. // HashFunction returns Tree.hashFunction
  107. func (t *Tree) HashFunction() HashFunction {
  108. return t.hashFunction
  109. }
  110. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  111. // the indexes of the keys failed to add. Supports empty values as input
  112. // parameters, which is equivalent to 0 valued byte array.
  113. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  114. wTx := t.db.WriteTx()
  115. defer wTx.Discard()
  116. invalids, err := t.AddBatchWithTx(wTx, keys, values)
  117. if err != nil {
  118. return invalids, err
  119. }
  120. return invalids, wTx.Commit()
  121. }
  122. // AddBatchWithTx does the same than the AddBatch method, but allowing to pass
  123. // the db.WriteTx that is used. The db.WriteTx will not be committed inside
  124. // this method.
  125. func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, error) {
  126. t.Lock()
  127. defer t.Unlock()
  128. vt, err := t.loadVT(wTx)
  129. if err != nil {
  130. return nil, err
  131. }
  132. e := []byte{}
  133. // equal the number of keys & values
  134. if len(keys) > len(values) {
  135. // add missing values
  136. for i := len(values); i < len(keys); i++ {
  137. values = append(values, e)
  138. }
  139. } else if len(keys) < len(values) {
  140. // crop extra values
  141. values = values[:len(keys)]
  142. }
  143. invalids, err := vt.addBatch(keys, values)
  144. if err != nil {
  145. return nil, err
  146. }
  147. // once the VirtualTree is build, compute the hashes
  148. pairs, err := vt.computeHashes()
  149. if err != nil {
  150. // currently invalids in computeHashes are not counted,
  151. // but should not be needed, as if there is an error there is
  152. // nothing stored in the db and the error is returned
  153. return nil, err
  154. }
  155. t.root = vt.root.h
  156. // store pairs in db
  157. for i := 0; i < len(pairs); i++ {
  158. if err := wTx.Set(pairs[i][0], pairs[i][1]); err != nil {
  159. return nil, err
  160. }
  161. }
  162. // store root to db
  163. if err := wTx.Set(dbKeyRoot, t.root); err != nil {
  164. return nil, err
  165. }
  166. // update nLeafs
  167. if err := t.incNLeafs(wTx, len(keys)-len(invalids)); err != nil {
  168. return nil, err
  169. }
  170. return invalids, nil
  171. }
  172. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  173. // the same leafs.
  174. func (t *Tree) loadVT(rTx db.ReadTx) (vt, error) {
  175. vt := newVT(t.maxLevels, t.hashFunction)
  176. vt.params.dbg = t.dbg
  177. err := t.IterateWithTx(rTx, nil, func(k, v []byte) {
  178. if v[0] != PrefixValueLeaf {
  179. return
  180. }
  181. leafK, leafV := ReadLeafValue(v)
  182. if err := vt.add(0, leafK, leafV); err != nil {
  183. // TODO instead of panic, return this error
  184. panic(err)
  185. }
  186. })
  187. return vt, err
  188. }
  189. // Add inserts the key-value into the Tree. If the inputs come from a
  190. // *big.Int, is expected that are represented by a Little-Endian byte array
  191. // (for circom compatibility).
  192. func (t *Tree) Add(k, v []byte) error {
  193. wTx := t.db.WriteTx()
  194. defer wTx.Discard()
  195. if err := t.AddWithTx(wTx, k, v); err != nil {
  196. return err
  197. }
  198. return wTx.Commit()
  199. }
  200. // AddWithTx does the same than the Add method, but allowing to pass the
  201. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  202. // method.
  203. func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
  204. t.Lock()
  205. defer t.Unlock()
  206. err := t.add(wTx, 0, k, v) // add from level 0
  207. if err != nil {
  208. return err
  209. }
  210. // store root to db
  211. if err := wTx.Set(dbKeyRoot, t.root); err != nil {
  212. return err
  213. }
  214. // update nLeafs
  215. if err = t.incNLeafs(wTx, 1); err != nil {
  216. return err
  217. }
  218. return nil
  219. }
  220. func (t *Tree) add(wTx db.WriteTx, fromLvl int, k, v []byte) error {
  221. keyPath := make([]byte, t.hashFunction.Len())
  222. copy(keyPath[:], k)
  223. path := getPath(t.maxLevels, keyPath)
  224. // go down to the leaf
  225. var siblings [][]byte
  226. _, _, siblings, err := t.down(wTx, k, t.root, siblings, path, fromLvl, false)
  227. if err != nil {
  228. return err
  229. }
  230. leafKey, leafValue, err := t.newLeafValue(k, v)
  231. if err != nil {
  232. return err
  233. }
  234. if err := wTx.Set(leafKey, leafValue); err != nil {
  235. return err
  236. }
  237. // go up to the root
  238. if len(siblings) == 0 {
  239. t.root = leafKey
  240. return nil
  241. }
  242. root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl)
  243. if err != nil {
  244. return err
  245. }
  246. t.root = root
  247. return nil
  248. }
  249. // down goes down to the leaf recursively
  250. func (t *Tree) down(rTx db.ReadTx, newKey, currKey []byte, siblings [][]byte,
  251. path []bool, currLvl int, getLeaf bool) (
  252. []byte, []byte, [][]byte, error) {
  253. if currLvl > t.maxLevels-1 {
  254. return nil, nil, nil, ErrMaxLevel
  255. }
  256. var err error
  257. var currValue []byte
  258. if bytes.Equal(currKey, t.emptyHash) {
  259. // empty value
  260. return currKey, emptyValue, siblings, nil
  261. }
  262. currValue, err = rTx.Get(currKey)
  263. if err != nil {
  264. return nil, nil, nil, err
  265. }
  266. switch currValue[0] {
  267. case PrefixValueEmpty: // empty
  268. fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
  269. hex.EncodeToString(newKey), hex.EncodeToString(currKey),
  270. currLvl, hex.EncodeToString(currValue))
  271. panic("This point should not be reached, as the 'if' above" +
  272. " should avoid reaching this point. This panic is temporary" +
  273. " for reporting purposes, will be deleted in future versions." +
  274. " Please paste this log (including the previous lines) in a" +
  275. " new issue: https://github.com/vocdoni/arbo/issues/new") // TMP
  276. case PrefixValueLeaf: // leaf
  277. if !bytes.Equal(currValue, emptyValue) {
  278. if getLeaf {
  279. return currKey, currValue, siblings, nil
  280. }
  281. oldLeafKey, _ := ReadLeafValue(currValue)
  282. if bytes.Equal(newKey, oldLeafKey) {
  283. return nil, nil, nil, ErrKeyAlreadyExists
  284. }
  285. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  286. copy(oldLeafKeyFull[:], oldLeafKey)
  287. // if currKey is already used, go down until paths diverge
  288. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  289. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  290. if err != nil {
  291. return nil, nil, nil, err
  292. }
  293. }
  294. return currKey, currValue, siblings, nil
  295. case PrefixValueIntermediate: // intermediate
  296. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  297. return nil, nil, nil,
  298. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  299. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  300. }
  301. // collect siblings while going down
  302. if path[currLvl] {
  303. // right
  304. lChild, rChild := ReadIntermediateChilds(currValue)
  305. siblings = append(siblings, lChild)
  306. return t.down(rTx, newKey, rChild, siblings, path, currLvl+1, getLeaf)
  307. }
  308. // left
  309. lChild, rChild := ReadIntermediateChilds(currValue)
  310. siblings = append(siblings, rChild)
  311. return t.down(rTx, newKey, lChild, siblings, path, currLvl+1, getLeaf)
  312. default:
  313. return nil, nil, nil, ErrInvalidValuePrefix
  314. }
  315. }
  316. // downVirtually is used when in a leaf already exists, and a new leaf which
  317. // shares the path until the existing leaf is being added
  318. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  319. newPath []bool, currLvl int) ([][]byte, error) {
  320. var err error
  321. if currLvl > t.maxLevels-1 {
  322. return nil, ErrMaxVirtualLevel
  323. }
  324. if oldPath[currLvl] == newPath[currLvl] {
  325. siblings = append(siblings, t.emptyHash)
  326. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  327. if err != nil {
  328. return nil, err
  329. }
  330. return siblings, nil
  331. }
  332. // reached the divergence
  333. siblings = append(siblings, oldKey)
  334. return siblings, nil
  335. }
  336. // up goes up recursively updating the intermediate nodes
  337. func (t *Tree) up(wTx db.WriteTx, key []byte, siblings [][]byte, path []bool,
  338. currLvl, toLvl int) ([]byte, error) {
  339. var k, v []byte
  340. var err error
  341. if path[currLvl+toLvl] {
  342. k, v, err = t.newIntermediate(siblings[currLvl], key)
  343. if err != nil {
  344. return nil, err
  345. }
  346. } else {
  347. k, v, err = t.newIntermediate(key, siblings[currLvl])
  348. if err != nil {
  349. return nil, err
  350. }
  351. }
  352. // store k-v to db
  353. if err = wTx.Set(k, v); err != nil {
  354. return nil, err
  355. }
  356. if currLvl == 0 {
  357. // reached the root
  358. return k, nil
  359. }
  360. return t.up(wTx, k, siblings, path, currLvl-1, toLvl)
  361. }
  362. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  363. t.dbg.incHash()
  364. return newLeafValue(t.hashFunction, k, v)
  365. }
  366. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  367. // which is used as the leaf key. And the value is the concatenation of the
  368. // inputed key & value. The output of this function is used as key-value to
  369. // store the leaf in the DB.
  370. // [ 1 byte | 1 byte | N bytes | M bytes ]
  371. // [ type of node | length of key | key | value ]
  372. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  373. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  374. if err != nil {
  375. return nil, nil, err
  376. }
  377. var leafValue []byte
  378. leafValue = append(leafValue, byte(1))
  379. leafValue = append(leafValue, byte(len(k)))
  380. leafValue = append(leafValue, k...)
  381. leafValue = append(leafValue, v...)
  382. return leafKey, leafValue, nil
  383. }
  384. // ReadLeafValue reads from a byte array the leaf key & value
  385. func ReadLeafValue(b []byte) ([]byte, []byte) {
  386. if len(b) < PrefixValueLen {
  387. return []byte{}, []byte{}
  388. }
  389. kLen := b[1]
  390. if len(b) < PrefixValueLen+int(kLen) {
  391. return []byte{}, []byte{}
  392. }
  393. k := b[PrefixValueLen : PrefixValueLen+kLen]
  394. v := b[PrefixValueLen+kLen:]
  395. return k, v
  396. }
  397. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  398. t.dbg.incHash()
  399. return newIntermediate(t.hashFunction, l, r)
  400. }
  401. // newIntermediate takes the left & right keys of a intermediate node, and
  402. // computes its hash. Returns the hash of the node, which is the node key, and a
  403. // byte array that contains the value (which contains the left & right child
  404. // keys) to store in the DB.
  405. // [ 1 byte | 1 byte | N bytes | N bytes ]
  406. // [ type of node | length of key | left key | right key ]
  407. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  408. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  409. b[0] = 2
  410. b[1] = byte(len(l))
  411. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  412. copy(b[PrefixValueLen+hashFunc.Len():], r)
  413. key, err := hashFunc.Hash(l, r)
  414. if err != nil {
  415. return nil, nil, err
  416. }
  417. return key, b, nil
  418. }
  419. // ReadIntermediateChilds reads from a byte array the two childs keys
  420. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  421. if len(b) < PrefixValueLen {
  422. return []byte{}, []byte{}
  423. }
  424. lLen := b[1]
  425. if len(b) < PrefixValueLen+int(lLen) {
  426. return []byte{}, []byte{}
  427. }
  428. l := b[PrefixValueLen : PrefixValueLen+lLen]
  429. r := b[PrefixValueLen+lLen:]
  430. return l, r
  431. }
  432. func getPath(numLevels int, k []byte) []bool {
  433. path := make([]bool, numLevels)
  434. for n := 0; n < numLevels; n++ {
  435. path[n] = k[n/8]&(1<<(n%8)) != 0
  436. }
  437. return path
  438. }
  439. // Update updates the value for a given existing key. If the given key does not
  440. // exist, returns an error.
  441. func (t *Tree) Update(k, v []byte) error {
  442. wTx := t.db.WriteTx()
  443. defer wTx.Discard()
  444. if err := t.UpdateWithTx(wTx, k, v); err != nil {
  445. return err
  446. }
  447. return wTx.Commit()
  448. }
  449. // UpdateWithTx does the same than the Update method, but allowing to pass the
  450. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  451. // method.
  452. func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
  453. t.Lock()
  454. defer t.Unlock()
  455. var err error
  456. keyPath := make([]byte, t.hashFunction.Len())
  457. copy(keyPath[:], k)
  458. path := getPath(t.maxLevels, keyPath)
  459. var siblings [][]byte
  460. _, valueAtBottom, siblings, err := t.down(wTx, k, t.root, siblings, path, 0, true)
  461. if err != nil {
  462. return err
  463. }
  464. oldKey, _ := ReadLeafValue(valueAtBottom)
  465. if !bytes.Equal(oldKey, k) {
  466. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  467. }
  468. leafKey, leafValue, err := t.newLeafValue(k, v)
  469. if err != nil {
  470. return err
  471. }
  472. if err := wTx.Set(leafKey, leafValue); err != nil {
  473. return err
  474. }
  475. // go up to the root
  476. if len(siblings) == 0 {
  477. t.root = leafKey
  478. return nil
  479. }
  480. root, err := t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0)
  481. if err != nil {
  482. return err
  483. }
  484. t.root = root
  485. // store root to db
  486. if err := wTx.Set(dbKeyRoot, t.root); err != nil {
  487. return err
  488. }
  489. return nil
  490. }
  491. // GenProof generates a MerkleTree proof for the given key. The leaf value is
  492. // returned, together with the packed siblings of the proof, and a boolean
  493. // parameter that indicates if the proof is of existence (true) or not (false).
  494. func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
  495. rTx := t.db.ReadTx()
  496. defer rTx.Discard()
  497. return t.GenProofWithTx(rTx, k)
  498. }
  499. // GenProofWithTx does the same than the GenProof method, but allowing to pass
  500. // the db.ReadTx that is used.
  501. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
  502. keyPath := make([]byte, t.hashFunction.Len())
  503. copy(keyPath[:], k)
  504. path := getPath(t.maxLevels, keyPath)
  505. // go down to the leaf
  506. var siblings [][]byte
  507. _, value, siblings, err := t.down(rTx, k, t.root, siblings, path, 0, true)
  508. if err != nil {
  509. return nil, nil, nil, false, err
  510. }
  511. s := PackSiblings(t.hashFunction, siblings)
  512. leafK, leafV := ReadLeafValue(value)
  513. if !bytes.Equal(k, leafK) {
  514. // key not in tree, proof of non-existence
  515. return leafK, leafV, s, false, err
  516. }
  517. return leafK, leafV, s, true, nil
  518. }
  519. // PackSiblings packs the siblings into a byte array.
  520. // [ 1 byte | L bytes | S * N bytes ]
  521. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  522. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  523. // array. And S is the size of the output of the hash function used for the
  524. // Tree.
  525. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  526. var b []byte
  527. var bitmap []bool
  528. emptySibling := make([]byte, hashFunc.Len())
  529. for i := 0; i < len(siblings); i++ {
  530. if bytes.Equal(siblings[i], emptySibling) {
  531. bitmap = append(bitmap, false)
  532. } else {
  533. bitmap = append(bitmap, true)
  534. b = append(b, siblings[i]...)
  535. }
  536. }
  537. bitmapBytes := bitmapToBytes(bitmap)
  538. l := len(bitmapBytes)
  539. res := make([]byte, l+1+len(b))
  540. res[0] = byte(l) // set the bitmapBytes length
  541. copy(res[1:1+l], bitmapBytes)
  542. copy(res[1+l:], b)
  543. return res
  544. }
  545. // UnpackSiblings unpacks the siblings from a byte array.
  546. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  547. l := b[0]
  548. bitmapBytes := b[1 : 1+l]
  549. bitmap := bytesToBitmap(bitmapBytes)
  550. siblingsBytes := b[1+l:]
  551. iSibl := 0
  552. emptySibl := make([]byte, hashFunc.Len())
  553. var siblings [][]byte
  554. for i := 0; i < len(bitmap); i++ {
  555. if iSibl >= len(siblingsBytes) {
  556. break
  557. }
  558. if bitmap[i] {
  559. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  560. iSibl += hashFunc.Len()
  561. } else {
  562. siblings = append(siblings, emptySibl)
  563. }
  564. }
  565. return siblings, nil
  566. }
  567. func bitmapToBytes(bitmap []bool) []byte {
  568. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  569. b := make([]byte, bitmapBytesLen)
  570. for i := 0; i < len(bitmap); i++ {
  571. if bitmap[i] {
  572. b[i/8] |= 1 << (i % 8)
  573. }
  574. }
  575. return b
  576. }
  577. func bytesToBitmap(b []byte) []bool {
  578. var bitmap []bool
  579. for i := 0; i < len(b); i++ {
  580. for j := 0; j < 8; j++ {
  581. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  582. }
  583. }
  584. return bitmap
  585. }
  586. // Get returns the value for a given key
  587. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  588. rTx := t.db.ReadTx()
  589. defer rTx.Discard()
  590. return t.GetWithTx(rTx, k)
  591. }
  592. // GetWithTx does the same than the Get method, but allowing to pass the
  593. // db.ReadTx that is used.
  594. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
  595. keyPath := make([]byte, t.hashFunction.Len())
  596. copy(keyPath[:], k)
  597. path := getPath(t.maxLevels, keyPath)
  598. // go down to the leaf
  599. var siblings [][]byte
  600. _, value, _, err := t.down(rTx, k, t.root, siblings, path, 0, true)
  601. if err != nil {
  602. return nil, nil, err
  603. }
  604. leafK, leafV := ReadLeafValue(value)
  605. if !bytes.Equal(k, leafK) {
  606. return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  607. BytesToBigInt(k), BytesToBigInt(leafK))
  608. }
  609. return leafK, leafV, nil
  610. }
  611. // CheckProof verifies the given proof. The proof verification depends on the
  612. // HashFunction passed as parameter.
  613. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  614. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  615. if err != nil {
  616. return false, err
  617. }
  618. keyPath := make([]byte, hashFunc.Len())
  619. copy(keyPath[:], k)
  620. key, _, err := newLeafValue(hashFunc, k, v)
  621. if err != nil {
  622. return false, err
  623. }
  624. path := getPath(len(siblings), keyPath)
  625. for i := len(siblings) - 1; i >= 0; i-- {
  626. if path[i] {
  627. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  628. if err != nil {
  629. return false, err
  630. }
  631. } else {
  632. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  633. if err != nil {
  634. return false, err
  635. }
  636. }
  637. }
  638. if bytes.Equal(key[:], root) {
  639. return true, nil
  640. }
  641. return false, nil
  642. }
  643. func (t *Tree) incNLeafs(wTx db.WriteTx, nLeafs int) error {
  644. oldNLeafs, err := t.GetNLeafs()
  645. if err != nil {
  646. return err
  647. }
  648. newNLeafs := oldNLeafs + nLeafs
  649. return t.setNLeafs(wTx, newNLeafs)
  650. }
  651. func (t *Tree) setNLeafs(wTx db.WriteTx, nLeafs int) error {
  652. b := make([]byte, 8)
  653. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  654. if err := wTx.Set(dbKeyNLeafs, b); err != nil {
  655. return err
  656. }
  657. return nil
  658. }
  659. // GetNLeafs returns the number of Leafs of the Tree.
  660. func (t *Tree) GetNLeafs() (int, error) {
  661. rTx := t.db.ReadTx()
  662. defer rTx.Discard()
  663. return t.GetNLeafsWithTx(rTx)
  664. }
  665. // GetNLeafsWithTx does the same than the GetNLeafs method, but allowing to
  666. // pass the db.ReadTx that is used.
  667. func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) {
  668. b, err := rTx.Get(dbKeyNLeafs)
  669. if err != nil {
  670. return 0, err
  671. }
  672. nLeafs := binary.LittleEndian.Uint64(b)
  673. return int(nLeafs), nil
  674. }
  675. // Snapshot returns a copy of the Tree from the given root
  676. func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) {
  677. // TODO currently this method only changes the 'root pointer', but the
  678. // db continues being the same. In a future iteration, once the
  679. // db.Database interface allows to do database checkpoints, this method
  680. // could be updated to do a full checkpoint of the database for the
  681. // snapshot, to return a completly new independent tree containing the
  682. // snapshot.
  683. t.RLock()
  684. defer t.RUnlock()
  685. // allow to define which root to use
  686. if rootKey == nil {
  687. rootKey = t.Root()
  688. }
  689. return &Tree{
  690. db: t.db,
  691. maxLevels: t.maxLevels,
  692. root: rootKey,
  693. hashFunction: t.hashFunction,
  694. dbg: t.dbg,
  695. }, nil
  696. }
  697. // Iterate iterates through the full Tree, executing the given function on each
  698. // node of the Tree.
  699. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error {
  700. rTx := t.db.ReadTx()
  701. defer rTx.Discard()
  702. return t.IterateWithTx(rTx, rootKey, f)
  703. }
  704. // IterateWithTx does the same than the Iterate method, but allowing to pass
  705. // the db.ReadTx that is used.
  706. func (t *Tree) IterateWithTx(rTx db.ReadTx, rootKey []byte, f func([]byte, []byte)) error {
  707. // allow to define which root to use
  708. if rootKey == nil {
  709. rootKey = t.Root()
  710. }
  711. return t.iter(rTx, rootKey, f)
  712. }
  713. // IterateWithStop does the same than Iterate, but with int for the current
  714. // level, and a boolean parameter used by the passed function, is to indicate to
  715. // stop iterating on the branch when the method returns 'true'.
  716. func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error {
  717. // allow to define which root to use
  718. if rootKey == nil {
  719. rootKey = t.Root()
  720. }
  721. rTx := t.db.ReadTx()
  722. defer rTx.Discard()
  723. return t.iterWithStop(rTx, rootKey, 0, f)
  724. }
  725. // IterateWithStopWithTx does the same than the IterateWithStop method, but
  726. // allowing to pass the db.ReadTx that is used.
  727. func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, rootKey []byte,
  728. f func(int, []byte, []byte) bool) error {
  729. // allow to define which root to use
  730. if rootKey == nil {
  731. rootKey = t.Root()
  732. }
  733. return t.iterWithStop(rTx, rootKey, 0, f)
  734. }
  735. func (t *Tree) iterWithStop(rTx db.ReadTx, k []byte, currLevel int,
  736. f func(int, []byte, []byte) bool) error {
  737. var v []byte
  738. var err error
  739. if bytes.Equal(k, t.emptyHash) {
  740. v = t.emptyHash
  741. } else {
  742. v, err = rTx.Get(k)
  743. if err != nil {
  744. return err
  745. }
  746. }
  747. currLevel++
  748. switch v[0] {
  749. case PrefixValueEmpty:
  750. f(currLevel, k, v)
  751. case PrefixValueLeaf:
  752. f(currLevel, k, v)
  753. case PrefixValueIntermediate:
  754. stop := f(currLevel, k, v)
  755. if stop {
  756. return nil
  757. }
  758. l, r := ReadIntermediateChilds(v)
  759. if err = t.iterWithStop(rTx, l, currLevel, f); err != nil {
  760. return err
  761. }
  762. if err = t.iterWithStop(rTx, r, currLevel, f); err != nil {
  763. return err
  764. }
  765. default:
  766. return ErrInvalidValuePrefix
  767. }
  768. return nil
  769. }
  770. func (t *Tree) iter(rTx db.ReadTx, k []byte, f func([]byte, []byte)) error {
  771. f2 := func(currLvl int, k, v []byte) bool {
  772. f(k, v)
  773. return false
  774. }
  775. return t.iterWithStop(rTx, k, 0, f2)
  776. }
  777. // Dump exports all the Tree leafs in a byte array of length:
  778. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  779. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  780. // [ len(k) | len(v) | key | value ]
  781. // Where S is the size of the output of the hash function used for the Tree.
  782. func (t *Tree) Dump(rootKey []byte) ([]byte, error) {
  783. // allow to define which root to use
  784. if rootKey == nil {
  785. rootKey = t.Root()
  786. }
  787. // WARNING current encoding only supports key & values of 255 bytes each
  788. // (due using only 1 byte for the length headers).
  789. var b []byte
  790. err := t.Iterate(rootKey, func(k, v []byte) {
  791. if v[0] != PrefixValueLeaf {
  792. return
  793. }
  794. leafK, leafV := ReadLeafValue(v)
  795. kv := make([]byte, 2+len(leafK)+len(leafV))
  796. kv[0] = byte(len(leafK))
  797. kv[1] = byte(len(leafV))
  798. copy(kv[2:2+len(leafK)], leafK)
  799. copy(kv[2+len(leafK):], leafV)
  800. b = append(b, kv...)
  801. })
  802. return b, err
  803. }
  804. // ImportDump imports the leafs (that have been exported with the Dump method)
  805. // in the Tree.
  806. func (t *Tree) ImportDump(b []byte) error {
  807. r := bytes.NewReader(b)
  808. var err error
  809. var keys, values [][]byte
  810. for {
  811. l := make([]byte, 2)
  812. _, err = io.ReadFull(r, l)
  813. if err == io.EOF {
  814. break
  815. } else if err != nil {
  816. return err
  817. }
  818. k := make([]byte, l[0])
  819. _, err = io.ReadFull(r, k)
  820. if err != nil {
  821. return err
  822. }
  823. v := make([]byte, l[1])
  824. _, err = io.ReadFull(r, v)
  825. if err != nil {
  826. return err
  827. }
  828. keys = append(keys, k)
  829. values = append(values, v)
  830. }
  831. if _, err = t.AddBatch(keys, values); err != nil {
  832. return err
  833. }
  834. return nil
  835. }
  836. // Graphviz iterates across the full tree to generate a string Graphviz
  837. // representation of the tree and writes it to w
  838. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  839. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  840. }
  841. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  842. // generate a string Graphviz representation of the first NLevels of the tree
  843. // and writes it to w
  844. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  845. fmt.Fprintf(w, `digraph hierarchy {
  846. node [fontname=Monospace,fontsize=10,shape=box]
  847. `)
  848. if rootKey == nil {
  849. rootKey = t.Root()
  850. }
  851. rTx := t.db.ReadTx()
  852. defer rTx.Discard()
  853. nEmpties := 0
  854. err := t.iterWithStop(rTx, rootKey, 0, func(currLvl int, k, v []byte) bool {
  855. if currLvl == untilLvl {
  856. return true // to stop the iter from going down
  857. }
  858. switch v[0] {
  859. case PrefixValueEmpty:
  860. case PrefixValueLeaf:
  861. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  862. // key & value from the leaf
  863. kB, vB := ReadLeafValue(v)
  864. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  865. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  866. hex.EncodeToString(vB[:nChars]))
  867. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  868. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  869. case PrefixValueIntermediate:
  870. l, r := ReadIntermediateChilds(v)
  871. lStr := hex.EncodeToString(l[:nChars])
  872. rStr := hex.EncodeToString(r[:nChars])
  873. eStr := ""
  874. if bytes.Equal(l, t.emptyHash) {
  875. lStr = fmt.Sprintf("empty%v", nEmpties)
  876. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  877. lStr)
  878. nEmpties++
  879. }
  880. if bytes.Equal(r, t.emptyHash) {
  881. rStr = fmt.Sprintf("empty%v", nEmpties)
  882. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  883. rStr)
  884. nEmpties++
  885. }
  886. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  887. lStr, rStr)
  888. fmt.Fprint(w, eStr)
  889. default:
  890. }
  891. return false
  892. })
  893. fmt.Fprintf(w, "}\n")
  894. return err
  895. }
  896. // PrintGraphviz prints the output of Tree.Graphviz
  897. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  898. if rootKey == nil {
  899. rootKey = t.Root()
  900. }
  901. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  902. }
  903. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  904. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  905. if rootKey == nil {
  906. rootKey = t.Root()
  907. }
  908. w := bytes.NewBufferString("")
  909. fmt.Fprintf(w,
  910. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  911. err := t.GraphvizFirstNLevels(w, rootKey, untilLvl)
  912. if err != nil {
  913. fmt.Println(w)
  914. return err
  915. }
  916. fmt.Fprintf(w,
  917. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  918. fmt.Println(w)
  919. return nil
  920. }
  921. // TODO circom proofs
  922. // TODO data structure for proofs (including root, key, value, siblings,
  923. // hashFunction) + method to verify that data structure