You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1272 lines
34 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree, following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.

The hash function used by the tree is configurable. For example, when working
with zkSnarks the Poseidon hash function can be used; otherwise the Blake2b
hash function, which has a much faster computation time, can be used instead.
*/
  10. package arbo
  11. import (
  12. "bytes"
  13. "encoding/binary"
  14. "encoding/hex"
  15. "fmt"
  16. "io"
  17. "math"
  18. "sync"
  19. "go.vocdoni.io/dvote/db"
  20. )
const (
	// PrefixValueLen defines the bytes-prefix length used for the Value
	// bytes representation stored in the db
	PrefixValueLen = 2

	// PrefixValueEmpty is used for the first byte of a Value to indicate
	// that is an Empty value
	PrefixValueEmpty = 0

	// PrefixValueLeaf is used for the first byte of a Value to indicate
	// that is a Leaf value
	PrefixValueLeaf = 1

	// PrefixValueIntermediate is used for the first byte of a Value to
	// indicate that is a Intermediate value
	PrefixValueIntermediate = 2

	// nChars is used to crop the Graphviz nodes labels
	nChars = 4

	// maxUint8 is the largest length that fits in the single length byte of
	// a stored leaf or intermediate value.
	maxUint8 = int(^uint8(0)) // 2**8 -1
	// maxUint16 is the largest length that fits in the 2-byte length fields
	// of the packed siblings encoding.
	maxUint16 = int(^uint16(0)) // 2**16 -1
)
var (
	// dbKeyRoot is the db key under which the current root of the tree is
	// stored.
	dbKeyRoot = []byte("root")
	// dbKeyNLeafs is the db key under which the number of leafs is stored.
	dbKeyNLeafs = []byte("nleafs")
	// emptyValue is the value returned when going down the tree reaches an
	// empty node.
	emptyValue = []byte{0}

	// ErrKeyNotFound is used when a key is not found in the db neither in
	// the current db Batch.
	ErrKeyNotFound = fmt.Errorf("key not found")
	// ErrKeyAlreadyExists is used when trying to add a key as leaf to the
	// tree that already exists.
	ErrKeyAlreadyExists = fmt.Errorf("key already exists")
	// ErrInvalidValuePrefix is used when going down into the tree, a value
	// is read from the db and has an unrecognized prefix.
	ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
	// ErrDBNoTx is used when trying to use Tree.dbPut but Tree.dbBatch==nil
	ErrDBNoTx = fmt.Errorf("dbPut error: no db Batch")
	// ErrMaxLevel indicates when going down into the tree, the max level is
	// reached
	ErrMaxLevel = fmt.Errorf("max level reached")
	// ErrMaxVirtualLevel indicates when going down into the tree, the max
	// virtual level is reached
	ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
	// ErrSnapshotNotEditable indicates when the tree is a non writable
	// snapshot, thus can not be modified
	ErrSnapshotNotEditable = fmt.Errorf("snapshot tree can not be edited")
	// ErrTreeNotEmpty indicates when the tree was expected to be empty and
	// it is not
	ErrTreeNotEmpty = fmt.Errorf("tree is not empty")
)
// Tree defines the struct that implements the MerkleTree functionalities
type Tree struct {
	// Mutex is locked by the methods that modify the tree.
	sync.Mutex

	// db is the database where the nodes, the root and the number of leafs
	// are stored.
	db db.Database
	// maxLevels is the maximum depth of the tree.
	maxLevels int
	// snapshotRoot, when not nil, marks the tree as a read-only snapshot
	// whose root is this value instead of the one stored in the db.
	snapshotRoot []byte

	// hashFunction is the hash function used to compute the node keys.
	hashFunction HashFunction
	// TODO in the methods that use it, check if emptyHash param is len>0
	// (check if it has been initialized)
	// emptyHash is a zero-filled byte array of the hash function output
	// length, used as the key of the empty nodes.
	emptyHash []byte

	// dbg holds the optional debug statistics.
	dbg *dbgStats
}
  79. // NewTree returns a new Tree, if there is a Tree still in the given database, it
  80. // will load it.
  81. func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, error) {
  82. wTx := database.WriteTx()
  83. defer wTx.Discard()
  84. t, err := NewTreeWithTx(wTx, database, maxLevels, hash)
  85. if err != nil {
  86. return nil, err
  87. }
  88. if err = wTx.Commit(); err != nil {
  89. return nil, err
  90. }
  91. return t, nil
  92. }
  93. // NewTreeWithTx returns a new Tree using the given db.WriteTx, which will not
  94. // be ccommited inside this method, if there is a Tree still in the given
  95. // database, it will load it.
  96. func NewTreeWithTx(wTx db.WriteTx, database db.Database,
  97. maxLevels int, hash HashFunction) (*Tree, error) {
  98. t := Tree{db: database, maxLevels: maxLevels, hashFunction: hash}
  99. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  100. _, err := wTx.Get(dbKeyRoot)
  101. if err == db.ErrKeyNotFound {
  102. // store new root 0 (empty)
  103. if err = wTx.Set(dbKeyRoot, t.emptyHash); err != nil {
  104. return nil, err
  105. }
  106. if err = t.setNLeafs(wTx, 0); err != nil {
  107. return nil, err
  108. }
  109. return &t, nil
  110. } else if err != nil {
  111. return nil, err
  112. }
  113. return &t, nil
  114. }
  115. // Root returns the root of the Tree
  116. func (t *Tree) Root() ([]byte, error) {
  117. rTx := t.db.ReadTx()
  118. defer rTx.Discard()
  119. return t.RootWithTx(rTx)
  120. }
  121. // RootWithTx returns the root of the Tree using the given db.ReadTx
  122. func (t *Tree) RootWithTx(rTx db.ReadTx) ([]byte, error) {
  123. // if snapshotRoot is defined, means that the tree is a snapshot, and
  124. // the root is not obtained from the db, but from the snapshotRoot
  125. // parameter
  126. if t.snapshotRoot != nil {
  127. return t.snapshotRoot, nil
  128. }
  129. // get db root
  130. return rTx.Get(dbKeyRoot)
  131. }
  132. func (t *Tree) setRoot(wTx db.WriteTx, root []byte) error {
  133. return wTx.Set(dbKeyRoot, root)
  134. }
  135. // HashFunction returns Tree.hashFunction
  136. func (t *Tree) HashFunction() HashFunction {
  137. return t.hashFunction
  138. }
  139. // editable returns true if the tree is editable, and false when is not
  140. // editable (because is a snapshot tree)
  141. func (t *Tree) editable() bool {
  142. return t.snapshotRoot == nil
  143. }
  144. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  145. // the indexes of the keys failed to add. Supports empty values as input
  146. // parameters, which is equivalent to 0 valued byte array.
  147. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  148. wTx := t.db.WriteTx()
  149. defer wTx.Discard()
  150. invalids, err := t.AddBatchWithTx(wTx, keys, values)
  151. if err != nil {
  152. return invalids, err
  153. }
  154. return invalids, wTx.Commit()
  155. }
  156. // AddBatchWithTx does the same than the AddBatch method, but allowing to pass
  157. // the db.WriteTx that is used. The db.WriteTx will not be committed inside
  158. // this method.
  159. func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, error) {
  160. t.Lock()
  161. defer t.Unlock()
  162. if !t.editable() {
  163. return nil, ErrSnapshotNotEditable
  164. }
  165. vt, err := t.loadVT(wTx)
  166. if err != nil {
  167. return nil, err
  168. }
  169. e := []byte{}
  170. // equal the number of keys & values
  171. if len(keys) > len(values) {
  172. // add missing values
  173. for i := len(values); i < len(keys); i++ {
  174. values = append(values, e)
  175. }
  176. } else if len(keys) < len(values) {
  177. // crop extra values
  178. values = values[:len(keys)]
  179. }
  180. invalids, err := vt.addBatch(keys, values)
  181. if err != nil {
  182. return nil, err
  183. }
  184. // once the VirtualTree is build, compute the hashes
  185. pairs, err := vt.computeHashes()
  186. if err != nil {
  187. // currently invalids in computeHashes are not counted,
  188. // but should not be needed, as if there is an error there is
  189. // nothing stored in the db and the error is returned
  190. return nil, err
  191. }
  192. // store pairs in db
  193. for i := 0; i < len(pairs); i++ {
  194. if err := wTx.Set(pairs[i][0], pairs[i][1]); err != nil {
  195. return nil, err
  196. }
  197. }
  198. // store root (from the vt) to db
  199. if vt.root != nil {
  200. if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil {
  201. return nil, err
  202. }
  203. }
  204. // update nLeafs
  205. if err := t.incNLeafs(wTx, len(keys)-len(invalids)); err != nil {
  206. return nil, err
  207. }
  208. return invalids, nil
  209. }
  210. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  211. // the same leafs.
  212. func (t *Tree) loadVT(rTx db.ReadTx) (vt, error) {
  213. vt := newVT(t.maxLevels, t.hashFunction)
  214. vt.params.dbg = t.dbg
  215. var callbackErr error
  216. err := t.IterateWithStopWithTx(rTx, nil, func(_ int, k, v []byte) bool {
  217. if v[0] != PrefixValueLeaf {
  218. return false
  219. }
  220. leafK, leafV := ReadLeafValue(v)
  221. if err := vt.add(0, leafK, leafV); err != nil {
  222. callbackErr = err
  223. return true
  224. }
  225. return false
  226. })
  227. if callbackErr != nil {
  228. return vt, callbackErr
  229. }
  230. return vt, err
  231. }
  232. // Add inserts the key-value into the Tree. If the inputs come from a
  233. // *big.Int, is expected that are represented by a Little-Endian byte array
  234. // (for circom compatibility).
  235. func (t *Tree) Add(k, v []byte) error {
  236. wTx := t.db.WriteTx()
  237. defer wTx.Discard()
  238. if err := t.AddWithTx(wTx, k, v); err != nil {
  239. return err
  240. }
  241. return wTx.Commit()
  242. }
  243. // AddWithTx does the same than the Add method, but allowing to pass the
  244. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  245. // method.
  246. func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
  247. t.Lock()
  248. defer t.Unlock()
  249. if !t.editable() {
  250. return ErrSnapshotNotEditable
  251. }
  252. root, err := t.RootWithTx(wTx)
  253. if err != nil {
  254. return err
  255. }
  256. root, err = t.add(wTx, root, 0, k, v) // add from level 0
  257. if err != nil {
  258. return err
  259. }
  260. // store root to db
  261. if err := t.setRoot(wTx, root); err != nil {
  262. return err
  263. }
  264. // update nLeafs
  265. if err = t.incNLeafs(wTx, 1); err != nil {
  266. return err
  267. }
  268. return nil
  269. }
  270. // keyPathFromKey returns the keyPath and checks that the key is not bigger
  271. // than maximum key length for the tree maxLevels size.
  272. // This is because if the key bits length is bigger than the maxLevels of the
  273. // tree, two different keys that their difference is at the end, will collision
  274. // in the same leaf of the tree (at the max depth).
  275. func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) {
  276. maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
  277. if len(k) > maxKeyLen {
  278. return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+
  279. " len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+
  280. " a bigger tree depth (maxLevels>=%d) in order to input keys of length %d",
  281. len(k), maxLevels, maxKeyLen, len(k)*8, len(k)) //nolint:gomnd
  282. }
  283. keyPath := make([]byte, maxKeyLen) //nolint:gomnd
  284. copy(keyPath[:], k)
  285. return keyPath, nil
  286. }
  287. func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
  288. keyPath, err := keyPathFromKey(t.maxLevels, k)
  289. if err != nil {
  290. return nil, err
  291. }
  292. path := getPath(t.maxLevels, keyPath)
  293. // go down to the leaf
  294. var siblings [][]byte
  295. _, _, siblings, err = t.down(wTx, k, root, siblings, path, fromLvl, false)
  296. if err != nil {
  297. return nil, err
  298. }
  299. leafKey, leafValue, err := t.newLeafValue(k, v)
  300. if err != nil {
  301. return nil, err
  302. }
  303. if err := wTx.Set(leafKey, leafValue); err != nil {
  304. return nil, err
  305. }
  306. // go up to the root
  307. if len(siblings) == 0 {
  308. // return the leafKey as root
  309. return leafKey, nil
  310. }
  311. root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl)
  312. if err != nil {
  313. return nil, err
  314. }
  315. return root, nil
  316. }
// down goes down to the leaf recursively, starting at the node with key
// currKey at level currLvl and following the given path (true = right child,
// false = left child). The sibling of each traversed intermediate node is
// appended to siblings.
//
// It returns the key and the stored value of the node where the descent ends
// (an empty node or a leaf), together with the collected siblings.
//
// When getLeaf is true, a found leaf is returned as is. When getLeaf is false
// (insertion), reaching a leaf with the same key as newKey returns
// ErrKeyAlreadyExists, and reaching a leaf with a different key extends the
// siblings through downVirtually until the level where both paths diverge.
func (t *Tree) down(rTx db.ReadTx, newKey, currKey []byte, siblings [][]byte,
	path []bool, currLvl int, getLeaf bool) (
	[]byte, []byte, [][]byte, error) {
	// NOTE(review): path is indexed with currLvl in the intermediate case
	// below, and the callers in this file build path with maxLevels
	// elements, so this check still lets currLvl == maxLevels through —
	// confirm that an intermediate node can not be found at that level.
	if currLvl > t.maxLevels {
		return nil, nil, nil, ErrMaxLevel
	}

	var err error
	var currValue []byte
	if bytes.Equal(currKey, t.emptyHash) {
		// empty value
		return currKey, emptyValue, siblings, nil
	}
	currValue, err = rTx.Get(currKey)
	if err != nil {
		return nil, nil, nil, err
	}

	// the first byte of the stored value tells the type of node
	switch currValue[0] {
	case PrefixValueEmpty: // empty
		fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
			hex.EncodeToString(newKey), hex.EncodeToString(currKey),
			currLvl, hex.EncodeToString(currValue))
		panic("This point should not be reached, as the 'if currKey==t.emptyHash'" +
			" above should avoid reaching this point. This panic is temporary" +
			" for reporting purposes, will be deleted in future versions." +
			" Please paste this log (including the previous log lines) in a" +
			" new issue: https://github.com/vocdoni/arbo/issues/new") // TMP
	case PrefixValueLeaf: // leaf
		if !bytes.Equal(currValue, emptyValue) {
			if getLeaf {
				return currKey, currValue, siblings, nil
			}
			oldLeafKey, _ := ReadLeafValue(currValue)
			if bytes.Equal(newKey, oldLeafKey) {
				return nil, nil, nil, ErrKeyAlreadyExists
			}

			// pad the existing leaf key with zeroes up to the hash length
			// NOTE(review): getPath below reads maxLevels bits from this
			// buffer — presumably maxLevels never exceeds 8*hashFunction.Len();
			// verify.
			oldLeafKeyFull := make([]byte, t.hashFunction.Len())
			// if len(oldLeafKey) > t.hashFunction.Len() { // WIP
			// return nil, nil, nil,
			// fmt.Errorf("len(oldLeafKey) > hashFunction.Len()")
			// }
			copy(oldLeafKeyFull[:], oldLeafKey)

			// if currKey is already used, go down until paths diverge
			oldPath := getPath(t.maxLevels, oldLeafKeyFull)
			siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
			if err != nil {
				return nil, nil, nil, err
			}
		}
		return currKey, currValue, siblings, nil
	case PrefixValueIntermediate: // intermediate
		if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
			return nil, nil, nil,
				fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
					PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
		}
		// collect siblings while going down
		if path[currLvl] {
			// right
			lChild, rChild := ReadIntermediateChilds(currValue)
			siblings = append(siblings, lChild)
			return t.down(rTx, newKey, rChild, siblings, path, currLvl+1, getLeaf)
		}
		// left
		lChild, rChild := ReadIntermediateChilds(currValue)
		siblings = append(siblings, rChild)
		return t.down(rTx, newKey, lChild, siblings, path, currLvl+1, getLeaf)
	default:
		return nil, nil, nil, ErrInvalidValuePrefix
	}
}
  388. // downVirtually is used when in a leaf already exists, and a new leaf which
  389. // shares the path until the existing leaf is being added
  390. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  391. newPath []bool, currLvl int) ([][]byte, error) {
  392. var err error
  393. if currLvl > t.maxLevels-1 {
  394. return nil, ErrMaxVirtualLevel
  395. }
  396. if oldPath[currLvl] == newPath[currLvl] {
  397. siblings = append(siblings, t.emptyHash)
  398. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  399. if err != nil {
  400. return nil, err
  401. }
  402. return siblings, nil
  403. }
  404. // reached the divergence
  405. siblings = append(siblings, oldKey)
  406. return siblings, nil
  407. }
  408. // up goes up recursively updating the intermediate nodes
  409. func (t *Tree) up(wTx db.WriteTx, key []byte, siblings [][]byte, path []bool,
  410. currLvl, toLvl int) ([]byte, error) {
  411. var k, v []byte
  412. var err error
  413. if path[currLvl+toLvl] {
  414. k, v, err = t.newIntermediate(siblings[currLvl], key)
  415. if err != nil {
  416. return nil, err
  417. }
  418. } else {
  419. k, v, err = t.newIntermediate(key, siblings[currLvl])
  420. if err != nil {
  421. return nil, err
  422. }
  423. }
  424. // store k-v to db
  425. if err = wTx.Set(k, v); err != nil {
  426. return nil, err
  427. }
  428. if currLvl == 0 {
  429. // reached the root
  430. return k, nil
  431. }
  432. return t.up(wTx, k, siblings, path, currLvl-1, toLvl)
  433. }
  434. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  435. t.dbg.incHash()
  436. return newLeafValue(t.hashFunction, k, v)
  437. }
  438. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  439. // which is used as the leaf key. And the value is the concatenation of the
  440. // inputed key & value. The output of this function is used as key-value to
  441. // store the leaf in the DB.
  442. // [ 1 byte | 1 byte | N bytes | M bytes ]
  443. // [ type of node | length of key | key | value ]
  444. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  445. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  446. if err != nil {
  447. return nil, nil, err
  448. }
  449. var leafValue []byte
  450. leafValue = append(leafValue, byte(PrefixValueLeaf))
  451. if len(k) > maxUint8 {
  452. return nil, nil, fmt.Errorf("newLeafValue: len(k) > %v", maxUint8)
  453. }
  454. leafValue = append(leafValue, byte(len(k)))
  455. leafValue = append(leafValue, k...)
  456. leafValue = append(leafValue, v...)
  457. return leafKey, leafValue, nil
  458. }
  459. // ReadLeafValue reads from a byte array the leaf key & value
  460. func ReadLeafValue(b []byte) ([]byte, []byte) {
  461. if len(b) < PrefixValueLen {
  462. return []byte{}, []byte{}
  463. }
  464. kLen := b[1]
  465. if len(b) < PrefixValueLen+int(kLen) {
  466. return []byte{}, []byte{}
  467. }
  468. k := b[PrefixValueLen : PrefixValueLen+kLen]
  469. v := b[PrefixValueLen+kLen:]
  470. return k, v
  471. }
  472. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  473. t.dbg.incHash()
  474. return newIntermediate(t.hashFunction, l, r)
  475. }
  476. // newIntermediate takes the left & right keys of a intermediate node, and
  477. // computes its hash. Returns the hash of the node, which is the node key, and a
  478. // byte array that contains the value (which contains the left & right child
  479. // keys) to store in the DB.
  480. // [ 1 byte | 1 byte | N bytes | N bytes ]
  481. // [ type of node | length of left key | left key | right key ]
  482. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  483. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  484. b[0] = PrefixValueIntermediate
  485. if len(l) > maxUint8 {
  486. return nil, nil, fmt.Errorf("newIntermediate: len(l) > %v", maxUint8)
  487. }
  488. b[1] = byte(len(l))
  489. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  490. copy(b[PrefixValueLen+hashFunc.Len():], r)
  491. key, err := hashFunc.Hash(l, r)
  492. if err != nil {
  493. return nil, nil, err
  494. }
  495. return key, b, nil
  496. }
  497. // ReadIntermediateChilds reads from a byte array the two childs keys
  498. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  499. if len(b) < PrefixValueLen {
  500. return []byte{}, []byte{}
  501. }
  502. lLen := b[1]
  503. if len(b) < PrefixValueLen+int(lLen) {
  504. return []byte{}, []byte{}
  505. }
  506. l := b[PrefixValueLen : PrefixValueLen+lLen]
  507. r := b[PrefixValueLen+lLen:]
  508. return l, r
  509. }
  510. func getPath(numLevels int, k []byte) []bool {
  511. path := make([]bool, numLevels)
  512. for n := 0; n < numLevels; n++ {
  513. path[n] = k[n/8]&(1<<(n%8)) != 0
  514. }
  515. return path
  516. }
  517. // Update updates the value for a given existing key. If the given key does not
  518. // exist, returns an error.
  519. func (t *Tree) Update(k, v []byte) error {
  520. wTx := t.db.WriteTx()
  521. defer wTx.Discard()
  522. if err := t.UpdateWithTx(wTx, k, v); err != nil {
  523. return err
  524. }
  525. return wTx.Commit()
  526. }
  527. // UpdateWithTx does the same than the Update method, but allowing to pass the
  528. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  529. // method.
  530. func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
  531. t.Lock()
  532. defer t.Unlock()
  533. if !t.editable() {
  534. return ErrSnapshotNotEditable
  535. }
  536. keyPath, err := keyPathFromKey(t.maxLevels, k)
  537. if err != nil {
  538. return err
  539. }
  540. path := getPath(t.maxLevels, keyPath)
  541. root, err := t.RootWithTx(wTx)
  542. if err != nil {
  543. return err
  544. }
  545. var siblings [][]byte
  546. _, valueAtBottom, siblings, err := t.down(wTx, k, root, siblings, path, 0, true)
  547. if err != nil {
  548. return err
  549. }
  550. oldKey, _ := ReadLeafValue(valueAtBottom)
  551. if !bytes.Equal(oldKey, k) {
  552. return ErrKeyNotFound
  553. }
  554. leafKey, leafValue, err := t.newLeafValue(k, v)
  555. if err != nil {
  556. return err
  557. }
  558. if err := wTx.Set(leafKey, leafValue); err != nil {
  559. return err
  560. }
  561. // go up to the root
  562. if len(siblings) == 0 {
  563. return t.setRoot(wTx, leafKey)
  564. }
  565. root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0)
  566. if err != nil {
  567. return err
  568. }
  569. // store root to db
  570. if err := t.setRoot(wTx, root); err != nil {
  571. return err
  572. }
  573. return nil
  574. }
  575. // GenProof generates a MerkleTree proof for the given key. The leaf value is
  576. // returned, together with the packed siblings of the proof, and a boolean
  577. // parameter that indicates if the proof is of existence (true) or not (false).
  578. func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
  579. rTx := t.db.ReadTx()
  580. defer rTx.Discard()
  581. return t.GenProofWithTx(rTx, k)
  582. }
  583. // GenProofWithTx does the same than the GenProof method, but allowing to pass
  584. // the db.ReadTx that is used.
  585. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
  586. keyPath, err := keyPathFromKey(t.maxLevels, k)
  587. if err != nil {
  588. return nil, nil, nil, false, err
  589. }
  590. path := getPath(t.maxLevels, keyPath)
  591. root, err := t.RootWithTx(rTx)
  592. if err != nil {
  593. return nil, nil, nil, false, err
  594. }
  595. // go down to the leaf
  596. var siblings [][]byte
  597. _, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true)
  598. if err != nil {
  599. return nil, nil, nil, false, err
  600. }
  601. s, err := PackSiblings(t.hashFunction, siblings)
  602. if err != nil {
  603. return nil, nil, nil, false, err
  604. }
  605. leafK, leafV := ReadLeafValue(value)
  606. if !bytes.Equal(k, leafK) {
  607. // key not in tree, proof of non-existence
  608. return leafK, leafV, s, false, nil
  609. }
  610. return leafK, leafV, s, true, nil
  611. }
  612. // PackSiblings packs the siblings into a byte array.
  613. // [ 2 byte | 2 byte | L bytes | S * N bytes ]
  614. // [ full length | bitmap length (L) | bitmap | N non-zero siblings ]
  615. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  616. // array. And S is the size of the output of the hash function used for the
  617. // Tree. The 2 2-byte that define the full length and bitmap length, are
  618. // encoded in little-endian.
  619. func PackSiblings(hashFunc HashFunction, siblings [][]byte) ([]byte, error) {
  620. var b []byte
  621. var bitmap []bool
  622. emptySibling := make([]byte, hashFunc.Len())
  623. for i := 0; i < len(siblings); i++ {
  624. if bytes.Equal(siblings[i], emptySibling) {
  625. bitmap = append(bitmap, false)
  626. } else {
  627. bitmap = append(bitmap, true)
  628. b = append(b, siblings[i]...)
  629. }
  630. }
  631. bitmapBytes := bitmapToBytes(bitmap)
  632. l := len(bitmapBytes)
  633. if l > maxUint16 {
  634. return nil, fmt.Errorf("PackSiblings: bitmapBytes length > %v", maxUint16)
  635. }
  636. fullLen := 4 + l + len(b) //nolint:gomnd
  637. if fullLen > maxUint16 {
  638. return nil, fmt.Errorf("PackSiblings: fullLen > %v", maxUint16)
  639. }
  640. res := make([]byte, fullLen)
  641. binary.LittleEndian.PutUint16(res[0:2], uint16(fullLen)) // set full length
  642. binary.LittleEndian.PutUint16(res[2:4], uint16(l)) // set the bitmapBytes length
  643. copy(res[4:4+l], bitmapBytes)
  644. copy(res[4+l:], b)
  645. return res, nil
  646. }
  647. // UnpackSiblings unpacks the siblings from a byte array.
  648. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  649. fullLen := binary.LittleEndian.Uint16(b[0:2])
  650. l := binary.LittleEndian.Uint16(b[2:4]) // bitmap bytes length
  651. if len(b) != int(fullLen) {
  652. return nil,
  653. fmt.Errorf("expected len: %d, current len: %d",
  654. fullLen, len(b))
  655. }
  656. bitmapBytes := b[4 : 4+l]
  657. bitmap := bytesToBitmap(bitmapBytes)
  658. siblingsBytes := b[4+l:]
  659. iSibl := 0
  660. emptySibl := make([]byte, hashFunc.Len())
  661. var siblings [][]byte
  662. for i := 0; i < len(bitmap); i++ {
  663. if iSibl >= len(siblingsBytes) {
  664. break
  665. }
  666. if bitmap[i] {
  667. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  668. iSibl += hashFunc.Len()
  669. } else {
  670. siblings = append(siblings, emptySibl)
  671. }
  672. }
  673. return siblings, nil
  674. }
  675. func bitmapToBytes(bitmap []bool) []byte {
  676. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  677. b := make([]byte, bitmapBytesLen)
  678. for i := 0; i < len(bitmap); i++ {
  679. if bitmap[i] {
  680. b[i/8] |= 1 << (i % 8)
  681. }
  682. }
  683. return b
  684. }
  685. func bytesToBitmap(b []byte) []bool {
  686. var bitmap []bool
  687. for i := 0; i < len(b); i++ {
  688. for j := 0; j < 8; j++ {
  689. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  690. }
  691. }
  692. return bitmap
  693. }
  694. // Get returns the value in the Tree for a given key. If the key is not found,
  695. // will return the error ErrKeyNotFound, and in the leafK & leafV parameters
  696. // will be placed the data found in the tree in the leaf that was on the path
  697. // going to the input key.
  698. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  699. rTx := t.db.ReadTx()
  700. defer rTx.Discard()
  701. return t.GetWithTx(rTx, k)
  702. }
  703. // GetWithTx does the same than the Get method, but allowing to pass the
  704. // db.ReadTx that is used. If the key is not found, will return the error
  705. // ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data
  706. // found in the tree in the leaf that was on the path going to the input key.
  707. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
  708. keyPath, err := keyPathFromKey(t.maxLevels, k)
  709. if err != nil {
  710. return nil, nil, err
  711. }
  712. path := getPath(t.maxLevels, keyPath)
  713. root, err := t.RootWithTx(rTx)
  714. if err != nil {
  715. return nil, nil, err
  716. }
  717. // go down to the leaf
  718. var siblings [][]byte
  719. _, value, _, err := t.down(rTx, k, root, siblings, path, 0, true)
  720. if err != nil {
  721. return nil, nil, err
  722. }
  723. leafK, leafV := ReadLeafValue(value)
  724. if !bytes.Equal(k, leafK) {
  725. return leafK, leafV, ErrKeyNotFound
  726. }
  727. return leafK, leafV, nil
  728. }
  729. // CheckProof verifies the given proof. The proof verification depends on the
  730. // HashFunction passed as parameter.
  731. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  732. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  733. if err != nil {
  734. return false, err
  735. }
  736. keyPath := make([]byte, len(siblings))
  737. copy(keyPath[:], k)
  738. key, _, err := newLeafValue(hashFunc, k, v)
  739. if err != nil {
  740. return false, err
  741. }
  742. path := getPath(len(siblings), keyPath)
  743. for i := len(siblings) - 1; i >= 0; i-- {
  744. if path[i] {
  745. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  746. if err != nil {
  747. return false, err
  748. }
  749. } else {
  750. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  751. if err != nil {
  752. return false, err
  753. }
  754. }
  755. }
  756. if bytes.Equal(key[:], root) {
  757. return true, nil
  758. }
  759. return false, nil
  760. }
  761. func (t *Tree) incNLeafs(wTx db.WriteTx, nLeafs int) error {
  762. oldNLeafs, err := t.GetNLeafsWithTx(wTx)
  763. if err != nil {
  764. return err
  765. }
  766. newNLeafs := oldNLeafs + nLeafs
  767. return t.setNLeafs(wTx, newNLeafs)
  768. }
  769. func (t *Tree) setNLeafs(wTx db.WriteTx, nLeafs int) error {
  770. b := make([]byte, 8)
  771. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  772. if err := wTx.Set(dbKeyNLeafs, b); err != nil {
  773. return err
  774. }
  775. return nil
  776. }
  777. // GetNLeafs returns the number of Leafs of the Tree.
  778. func (t *Tree) GetNLeafs() (int, error) {
  779. rTx := t.db.ReadTx()
  780. defer rTx.Discard()
  781. return t.GetNLeafsWithTx(rTx)
  782. }
  783. // GetNLeafsWithTx does the same than the GetNLeafs method, but allowing to
  784. // pass the db.ReadTx that is used.
  785. func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) {
  786. b, err := rTx.Get(dbKeyNLeafs)
  787. if err != nil {
  788. return 0, err
  789. }
  790. nLeafs := binary.LittleEndian.Uint64(b)
  791. return int(nLeafs), nil
  792. }
  793. // SetRoot sets the root to the given root
  794. func (t *Tree) SetRoot(root []byte) error {
  795. wTx := t.db.WriteTx()
  796. defer wTx.Discard()
  797. if err := t.SetRootWithTx(wTx, root); err != nil {
  798. return err
  799. }
  800. return wTx.Commit()
  801. }
  802. // SetRootWithTx sets the root to the given root using the given db.WriteTx
  803. func (t *Tree) SetRootWithTx(wTx db.WriteTx, root []byte) error {
  804. if !t.editable() {
  805. return ErrSnapshotNotEditable
  806. }
  807. if root == nil {
  808. return fmt.Errorf("can not SetRoot with nil root")
  809. }
  810. // check that the root exists in the db
  811. if !bytes.Equal(root, t.emptyHash) {
  812. if _, err := wTx.Get(root); err == ErrKeyNotFound {
  813. return fmt.Errorf("can not SetRoot with root %x, as it does not exist in the db", root)
  814. } else if err != nil {
  815. return err
  816. }
  817. }
  818. return wTx.Set(dbKeyRoot, root)
  819. }
  820. // Snapshot returns a read-only copy of the Tree from the given root
  821. func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
  822. // allow to define which root to use
  823. if fromRoot == nil {
  824. var err error
  825. fromRoot, err = t.Root()
  826. if err != nil {
  827. return nil, err
  828. }
  829. }
  830. rTx := t.db.ReadTx()
  831. defer rTx.Discard()
  832. // check that the root exists in the db
  833. if !bytes.Equal(fromRoot, t.emptyHash) {
  834. if _, err := rTx.Get(fromRoot); err == ErrKeyNotFound {
  835. return nil,
  836. fmt.Errorf("can not do a Snapshot with root %x, as it does not exist in the db",
  837. fromRoot)
  838. } else if err != nil {
  839. return nil, err
  840. }
  841. }
  842. return &Tree{
  843. db: t.db,
  844. maxLevels: t.maxLevels,
  845. snapshotRoot: fromRoot,
  846. emptyHash: t.emptyHash,
  847. hashFunction: t.hashFunction,
  848. dbg: t.dbg,
  849. }, nil
  850. }
  851. // Iterate iterates through the full Tree, executing the given function on each
  852. // node of the Tree.
  853. func (t *Tree) Iterate(fromRoot []byte, f func([]byte, []byte)) error {
  854. rTx := t.db.ReadTx()
  855. defer rTx.Discard()
  856. return t.IterateWithTx(rTx, fromRoot, f)
  857. }
  858. // IterateWithTx does the same than the Iterate method, but allowing to pass
  859. // the db.ReadTx that is used.
  860. func (t *Tree) IterateWithTx(rTx db.ReadTx, fromRoot []byte, f func([]byte, []byte)) error {
  861. // allow to define which root to use
  862. if fromRoot == nil {
  863. var err error
  864. fromRoot, err = t.RootWithTx(rTx)
  865. if err != nil {
  866. return err
  867. }
  868. }
  869. return t.iter(rTx, fromRoot, f)
  870. }
  871. // IterateWithStop does the same than Iterate, but with int for the current
  872. // level, and a boolean parameter used by the passed function, is to indicate to
  873. // stop iterating on the branch when the method returns 'true'.
  874. func (t *Tree) IterateWithStop(fromRoot []byte, f func(int, []byte, []byte) bool) error {
  875. rTx := t.db.ReadTx()
  876. defer rTx.Discard()
  877. // allow to define which root to use
  878. if fromRoot == nil {
  879. var err error
  880. fromRoot, err = t.RootWithTx(rTx)
  881. if err != nil {
  882. return err
  883. }
  884. }
  885. return t.iterWithStop(rTx, fromRoot, 0, f)
  886. }
  887. // IterateWithStopWithTx does the same than the IterateWithStop method, but
  888. // allowing to pass the db.ReadTx that is used.
  889. func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, fromRoot []byte,
  890. f func(int, []byte, []byte) bool) error {
  891. // allow to define which root to use
  892. if fromRoot == nil {
  893. var err error
  894. fromRoot, err = t.RootWithTx(rTx)
  895. if err != nil {
  896. return err
  897. }
  898. }
  899. return t.iterWithStop(rTx, fromRoot, 0, f)
  900. }
  901. func (t *Tree) iterWithStop(rTx db.ReadTx, k []byte, currLevel int,
  902. f func(int, []byte, []byte) bool) error {
  903. var v []byte
  904. var err error
  905. if bytes.Equal(k, t.emptyHash) {
  906. v = t.emptyHash
  907. } else {
  908. v, err = rTx.Get(k)
  909. if err != nil {
  910. return err
  911. }
  912. }
  913. currLevel++
  914. switch v[0] {
  915. case PrefixValueEmpty:
  916. f(currLevel, k, v)
  917. case PrefixValueLeaf:
  918. f(currLevel, k, v)
  919. case PrefixValueIntermediate:
  920. stop := f(currLevel, k, v)
  921. if stop {
  922. return nil
  923. }
  924. l, r := ReadIntermediateChilds(v)
  925. if err = t.iterWithStop(rTx, l, currLevel, f); err != nil {
  926. return err
  927. }
  928. if err = t.iterWithStop(rTx, r, currLevel, f); err != nil {
  929. return err
  930. }
  931. default:
  932. return ErrInvalidValuePrefix
  933. }
  934. return nil
  935. }
  936. func (t *Tree) iter(rTx db.ReadTx, k []byte, f func([]byte, []byte)) error {
  937. f2 := func(currLvl int, k, v []byte) bool {
  938. f(k, v)
  939. return false
  940. }
  941. return t.iterWithStop(rTx, k, 0, f2)
  942. }
  943. // Dump exports all the Tree leafs in a byte array of length:
  944. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  945. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  946. // [ len(k) | len(v) | key | value ]
  947. // Where S is the size of the output of the hash function used for the Tree.
  948. func (t *Tree) Dump(fromRoot []byte) ([]byte, error) {
  949. // allow to define which root to use
  950. if fromRoot == nil {
  951. var err error
  952. fromRoot, err = t.Root()
  953. if err != nil {
  954. return nil, err
  955. }
  956. }
  957. // WARNING current encoding only supports key & values of 255 bytes each
  958. // (due using only 1 byte for the length headers).
  959. var b []byte
  960. var callbackErr error
  961. err := t.IterateWithStop(fromRoot, func(_ int, k, v []byte) bool {
  962. if v[0] != PrefixValueLeaf {
  963. return false
  964. }
  965. leafK, leafV := ReadLeafValue(v)
  966. kv := make([]byte, 2+len(leafK)+len(leafV))
  967. if len(leafK) > maxUint8 {
  968. callbackErr = fmt.Errorf("len(leafK) > %v", maxUint8)
  969. return true
  970. }
  971. kv[0] = byte(len(leafK))
  972. if len(leafV) > maxUint8 {
  973. callbackErr = fmt.Errorf("len(leafV) > %v", maxUint8)
  974. return true
  975. }
  976. kv[1] = byte(len(leafV))
  977. copy(kv[2:2+len(leafK)], leafK)
  978. copy(kv[2+len(leafK):], leafV)
  979. b = append(b, kv...)
  980. return false
  981. })
  982. if callbackErr != nil {
  983. return nil, callbackErr
  984. }
  985. return b, err
  986. }
  987. // ImportDump imports the leafs (that have been exported with the Dump method)
  988. // in the Tree.
  989. func (t *Tree) ImportDump(b []byte) error {
  990. if !t.editable() {
  991. return ErrSnapshotNotEditable
  992. }
  993. root, err := t.Root()
  994. if err != nil {
  995. return err
  996. }
  997. if !bytes.Equal(root, t.emptyHash) {
  998. return ErrTreeNotEmpty
  999. }
  1000. r := bytes.NewReader(b)
  1001. var keys, values [][]byte
  1002. for {
  1003. l := make([]byte, 2)
  1004. _, err = io.ReadFull(r, l)
  1005. if err == io.EOF {
  1006. break
  1007. } else if err != nil {
  1008. return err
  1009. }
  1010. k := make([]byte, l[0])
  1011. _, err = io.ReadFull(r, k)
  1012. if err != nil {
  1013. return err
  1014. }
  1015. v := make([]byte, l[1])
  1016. _, err = io.ReadFull(r, v)
  1017. if err != nil {
  1018. return err
  1019. }
  1020. keys = append(keys, k)
  1021. values = append(values, v)
  1022. }
  1023. if _, err = t.AddBatch(keys, values); err != nil {
  1024. return err
  1025. }
  1026. return nil
  1027. }
  1028. // Graphviz iterates across the full tree to generate a string Graphviz
  1029. // representation of the tree and writes it to w
  1030. func (t *Tree) Graphviz(w io.Writer, fromRoot []byte) error {
  1031. return t.GraphvizFirstNLevels(w, fromRoot, t.maxLevels)
  1032. }
  1033. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  1034. // generate a string Graphviz representation of the first NLevels of the tree
  1035. // and writes it to w
  1036. func (t *Tree) GraphvizFirstNLevels(w io.Writer, fromRoot []byte, untilLvl int) error {
  1037. fmt.Fprintf(w, `digraph hierarchy {
  1038. node [fontname=Monospace,fontsize=10,shape=box]
  1039. `)
  1040. rTx := t.db.ReadTx()
  1041. defer rTx.Discard()
  1042. if fromRoot == nil {
  1043. var err error
  1044. fromRoot, err = t.RootWithTx(rTx)
  1045. if err != nil {
  1046. return err
  1047. }
  1048. }
  1049. nEmpties := 0
  1050. err := t.iterWithStop(rTx, fromRoot, 0, func(currLvl int, k, v []byte) bool {
  1051. if currLvl == untilLvl {
  1052. return true // to stop the iter from going down
  1053. }
  1054. switch v[0] {
  1055. case PrefixValueEmpty:
  1056. case PrefixValueLeaf:
  1057. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  1058. // key & value from the leaf
  1059. kB, vB := ReadLeafValue(v)
  1060. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  1061. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  1062. hex.EncodeToString(vB[:nChars]))
  1063. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  1064. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  1065. case PrefixValueIntermediate:
  1066. l, r := ReadIntermediateChilds(v)
  1067. lStr := hex.EncodeToString(l[:nChars])
  1068. rStr := hex.EncodeToString(r[:nChars])
  1069. eStr := ""
  1070. if bytes.Equal(l, t.emptyHash) {
  1071. lStr = fmt.Sprintf("empty%v", nEmpties)
  1072. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  1073. lStr)
  1074. nEmpties++
  1075. }
  1076. if bytes.Equal(r, t.emptyHash) {
  1077. rStr = fmt.Sprintf("empty%v", nEmpties)
  1078. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  1079. rStr)
  1080. nEmpties++
  1081. }
  1082. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  1083. lStr, rStr)
  1084. fmt.Fprint(w, eStr)
  1085. default:
  1086. }
  1087. return false
  1088. })
  1089. fmt.Fprintf(w, "}\n")
  1090. return err
  1091. }
  1092. // PrintGraphviz prints the output of Tree.Graphviz
  1093. func (t *Tree) PrintGraphviz(fromRoot []byte) error {
  1094. if fromRoot == nil {
  1095. var err error
  1096. fromRoot, err = t.Root()
  1097. if err != nil {
  1098. return err
  1099. }
  1100. }
  1101. return t.PrintGraphvizFirstNLevels(fromRoot, t.maxLevels)
  1102. }
  1103. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  1104. func (t *Tree) PrintGraphvizFirstNLevels(fromRoot []byte, untilLvl int) error {
  1105. if fromRoot == nil {
  1106. var err error
  1107. fromRoot, err = t.Root()
  1108. if err != nil {
  1109. return err
  1110. }
  1111. }
  1112. w := bytes.NewBufferString("")
  1113. fmt.Fprintf(w,
  1114. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+":\n")
  1115. err := t.GraphvizFirstNLevels(w, fromRoot, untilLvl)
  1116. if err != nil {
  1117. fmt.Println(w)
  1118. return err
  1119. }
  1120. fmt.Fprintf(w,
  1121. "End of Graphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+"\n--------\n")
  1122. fmt.Println(w)
  1123. return nil
  1124. }
  1125. // TODO circom proofs
  1126. // TODO data structure for proofs (including root, key, value, siblings,
  1127. // hashFunction) + method to verify that data structure