You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1277 lines
34 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree, following the specification from
  4. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  5. https://eprint.iacr.org/2018/955.
  6. It allows the caller to choose which hash function to use. For example, when
  7. working with zkSnarks the Poseidon hash function can be used; otherwise the
  8. Blake2b hash function, which is much faster to compute, can be used instead.
  9. */
  10. package arbo
  11. import (
  12. "bytes"
  13. "encoding/binary"
  14. "encoding/hex"
  15. "fmt"
  16. "io"
  17. "math"
  18. "sync"
  19. "go.vocdoni.io/dvote/db"
  20. )
  // PrefixValueLen defines the bytes-prefix length used for the Value bytes
  // representation stored in the db.
  const PrefixValueLen = 2

  // Node type markers, written in the first byte of a Value stored in the db.
  const (
  	// PrefixValueEmpty indicates that the Value is an Empty value.
  	PrefixValueEmpty = iota
  	// PrefixValueLeaf indicates that the Value is a Leaf value.
  	PrefixValueLeaf
  	// PrefixValueIntermediate indicates that the Value is an Intermediate
  	// value.
  	PrefixValueIntermediate
  )

  const (
  	// nChars is used to crop the Graphviz nodes labels.
  	nChars = 4

  	// maxUint8 is the biggest value that fits in a uint8 (2**8 - 1).
  	maxUint8 = int(^uint8(0))
  	// maxUint16 is the biggest value that fits in a uint16 (2**16 - 1).
  	maxUint16 = int(^uint16(0))
  )
  // Internal db keys and the byte representation of an empty node.
  var (
  	// dbKeyRoot is the db key under which the root of the tree is stored.
  	dbKeyRoot = []byte("root")
  	// dbKeyNLeafs is the db key under which the number of leafs is stored.
  	dbKeyNLeafs = []byte("nleafs")
  	// emptyValue is the value returned when an empty node is reached.
  	emptyValue = []byte{0}
  )

  // Errors returned by the package.
  var (
  	// ErrKeyNotFound is used when a key is not found in the db neither in
  	// the current db Batch.
  	ErrKeyNotFound = fmt.Errorf("key not found")

  	// ErrKeyAlreadyExists is used when trying to add a key as leaf to the
  	// tree that already exists.
  	ErrKeyAlreadyExists = fmt.Errorf("key already exists")

  	// ErrInvalidValuePrefix is used when, going down into the tree, a value
  	// is read from the db and has an unrecognized prefix.
  	ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")

  	// ErrDBNoTx is used when trying to use Tree.dbPut but Tree.dbBatch==nil.
  	ErrDBNoTx = fmt.Errorf("dbPut error: no db Batch")

  	// ErrMaxLevel indicates that, going down into the tree, the max level
  	// is reached.
  	ErrMaxLevel = fmt.Errorf("max level reached")

  	// ErrMaxVirtualLevel indicates that, going down into the tree, the max
  	// virtual level is reached.
  	ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")

  	// ErrSnapshotNotEditable indicates that the tree is a non writable
  	// snapshot, thus it can not be modified.
  	ErrSnapshotNotEditable = fmt.Errorf("snapshot tree can not be edited")

  	// ErrTreeNotEmpty indicates that the tree was expected to be empty and
  	// it is not.
  	ErrTreeNotEmpty = fmt.Errorf("tree is not empty")
  )
  67. // Tree defines the struct that implements the MerkleTree functionalities
  68. type Tree struct {
  69. sync.Mutex
  70. db db.Database
  71. maxLevels int
  72. snapshotRoot []byte
  73. hashFunction HashFunction
  74. // TODO in the methods that use it, check if emptyHash param is len>0
  75. // (check if it has been initialized)
  76. emptyHash []byte
  77. dbg *dbgStats
  78. }
  79. // NewTree returns a new Tree, if there is a Tree still in the given database, it
  80. // will load it.
  81. func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, error) {
  82. wTx := database.WriteTx()
  83. defer wTx.Discard()
  84. t, err := NewTreeWithTx(wTx, database, maxLevels, hash)
  85. if err != nil {
  86. return nil, err
  87. }
  88. if err = wTx.Commit(); err != nil {
  89. return nil, err
  90. }
  91. return t, nil
  92. }
  93. // NewTreeWithTx returns a new Tree using the given db.WriteTx, which will not
  94. // be ccommited inside this method, if there is a Tree still in the given
  95. // database, it will load it.
  96. func NewTreeWithTx(wTx db.WriteTx, database db.Database,
  97. maxLevels int, hash HashFunction) (*Tree, error) {
  98. t := Tree{db: database, maxLevels: maxLevels, hashFunction: hash}
  99. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  100. _, err := wTx.Get(dbKeyRoot)
  101. if err == db.ErrKeyNotFound {
  102. // store new root 0 (empty)
  103. if err = wTx.Set(dbKeyRoot, t.emptyHash); err != nil {
  104. return nil, err
  105. }
  106. if err = t.setNLeafs(wTx, 0); err != nil {
  107. return nil, err
  108. }
  109. return &t, nil
  110. } else if err != nil {
  111. return nil, err
  112. }
  113. return &t, nil
  114. }
  115. // Root returns the root of the Tree
  116. func (t *Tree) Root() ([]byte, error) {
  117. rTx := t.db.ReadTx()
  118. defer rTx.Discard()
  119. return t.RootWithTx(rTx)
  120. }
  121. // RootWithTx returns the root of the Tree using the given db.ReadTx
  122. func (t *Tree) RootWithTx(rTx db.ReadTx) ([]byte, error) {
  123. // if snapshotRoot is defined, means that the tree is a snapshot, and
  124. // the root is not obtained from the db, but from the snapshotRoot
  125. // parameter
  126. if t.snapshotRoot != nil {
  127. return t.snapshotRoot, nil
  128. }
  129. // get db root
  130. return rTx.Get(dbKeyRoot)
  131. }
  132. func (t *Tree) setRoot(wTx db.WriteTx, root []byte) error {
  133. return wTx.Set(dbKeyRoot, root)
  134. }
  135. // HashFunction returns Tree.hashFunction
  136. func (t *Tree) HashFunction() HashFunction {
  137. return t.hashFunction
  138. }
  139. // editable returns true if the tree is editable, and false when is not
  140. // editable (because is a snapshot tree)
  141. func (t *Tree) editable() bool {
  142. return t.snapshotRoot == nil
  143. }
  // Invalid describes a key-value that could not be added through AddBatch.
  type Invalid struct {
  	// Index is the position of the key-value in the AddBatch input.
  	Index int
  	// Error is the reason why the key-value was not added.
  	Error error
  }
  150. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  151. // the indexes of the keys failed to add. Supports empty values as input
  152. // parameters, which is equivalent to 0 valued byte array.
  153. func (t *Tree) AddBatch(keys, values [][]byte) ([]Invalid, error) {
  154. wTx := t.db.WriteTx()
  155. defer wTx.Discard()
  156. invalids, err := t.AddBatchWithTx(wTx, keys, values)
  157. if err != nil {
  158. return invalids, err
  159. }
  160. return invalids, wTx.Commit()
  161. }
  162. // AddBatchWithTx does the same than the AddBatch method, but allowing to pass
  163. // the db.WriteTx that is used. The db.WriteTx will not be committed inside
  164. // this method.
  165. func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]Invalid, error) {
  166. t.Lock()
  167. defer t.Unlock()
  168. if !t.editable() {
  169. return nil, ErrSnapshotNotEditable
  170. }
  171. vt, err := t.loadVT(wTx)
  172. if err != nil {
  173. return nil, err
  174. }
  175. e := []byte{}
  176. // equal the number of keys & values
  177. if len(keys) > len(values) {
  178. // add missing values
  179. for i := len(values); i < len(keys); i++ {
  180. values = append(values, e)
  181. }
  182. } else if len(keys) < len(values) {
  183. // crop extra values
  184. values = values[:len(keys)]
  185. }
  186. invalids, err := vt.addBatch(keys, values)
  187. if err != nil {
  188. return nil, err
  189. }
  190. // once the VirtualTree is build, compute the hashes
  191. pairs, err := vt.computeHashes()
  192. if err != nil {
  193. // currently invalids in computeHashes are not counted,
  194. // but should not be needed, as if there is an error there is
  195. // nothing stored in the db and the error is returned
  196. return nil, err
  197. }
  198. // store pairs in db
  199. for i := 0; i < len(pairs); i++ {
  200. if err := wTx.Set(pairs[i][0], pairs[i][1]); err != nil {
  201. return nil, err
  202. }
  203. }
  204. // store root (from the vt) to db
  205. if vt.root != nil {
  206. if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil {
  207. return nil, err
  208. }
  209. }
  210. // update nLeafs
  211. if err := t.incNLeafs(wTx, len(keys)-len(invalids)); err != nil {
  212. return nil, err
  213. }
  214. return invalids, nil
  215. }
  216. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  217. // the same leafs.
  218. func (t *Tree) loadVT(rTx db.ReadTx) (vt, error) {
  219. vt := newVT(t.maxLevels, t.hashFunction)
  220. vt.params.dbg = t.dbg
  221. var callbackErr error
  222. err := t.IterateWithStopWithTx(rTx, nil, func(_ int, k, v []byte) bool {
  223. if v[0] != PrefixValueLeaf {
  224. return false
  225. }
  226. leafK, leafV := ReadLeafValue(v)
  227. if err := vt.add(0, leafK, leafV); err != nil {
  228. callbackErr = err
  229. return true
  230. }
  231. return false
  232. })
  233. if callbackErr != nil {
  234. return vt, callbackErr
  235. }
  236. return vt, err
  237. }
  238. // Add inserts the key-value into the Tree. If the inputs come from a
  239. // *big.Int, is expected that are represented by a Little-Endian byte array
  240. // (for circom compatibility).
  241. func (t *Tree) Add(k, v []byte) error {
  242. wTx := t.db.WriteTx()
  243. defer wTx.Discard()
  244. if err := t.AddWithTx(wTx, k, v); err != nil {
  245. return err
  246. }
  247. return wTx.Commit()
  248. }
  249. // AddWithTx does the same than the Add method, but allowing to pass the
  250. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  251. // method.
  252. func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
  253. t.Lock()
  254. defer t.Unlock()
  255. if !t.editable() {
  256. return ErrSnapshotNotEditable
  257. }
  258. root, err := t.RootWithTx(wTx)
  259. if err != nil {
  260. return err
  261. }
  262. root, err = t.add(wTx, root, 0, k, v) // add from level 0
  263. if err != nil {
  264. return err
  265. }
  266. // store root to db
  267. if err := t.setRoot(wTx, root); err != nil {
  268. return err
  269. }
  270. // update nLeafs
  271. if err = t.incNLeafs(wTx, 1); err != nil {
  272. return err
  273. }
  274. return nil
  275. }
  // keyPathFromKey returns the keyPath: a copy of the key, zero-padded on the
  // right up to ceil(maxLevels/8) bytes. It returns an error when the key is
  // longer than that maximum length, because if the key bits length is bigger
  // than the maxLevels of the tree, two different keys that differ only at
  // the end would collide in the same leaf of the tree (at the max depth).
  func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) {
  	maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd

  	if keyLen := len(k); keyLen > maxKeyLen {
  		return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+
  			" len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+
  			" a bigger tree depth (maxLevels>=%d) in order to input keys of length %d",
  			keyLen, maxLevels, maxKeyLen, keyLen*8, keyLen) //nolint:gomnd
  	}

  	padded := make([]byte, maxKeyLen)
  	copy(padded, k)
  	return padded, nil
  }
  293. func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
  294. keyPath, err := keyPathFromKey(t.maxLevels, k)
  295. if err != nil {
  296. return nil, err
  297. }
  298. path := getPath(t.maxLevels, keyPath)
  299. // go down to the leaf
  300. var siblings [][]byte
  301. _, _, siblings, err = t.down(wTx, k, root, siblings, path, fromLvl, false)
  302. if err != nil {
  303. return nil, err
  304. }
  305. leafKey, leafValue, err := t.newLeafValue(k, v)
  306. if err != nil {
  307. return nil, err
  308. }
  309. if err := wTx.Set(leafKey, leafValue); err != nil {
  310. return nil, err
  311. }
  312. // go up to the root
  313. if len(siblings) == 0 {
  314. // return the leafKey as root
  315. return leafKey, nil
  316. }
  317. root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, fromLvl)
  318. if err != nil {
  319. return nil, err
  320. }
  321. return root, nil
  322. }
  323. // down goes down to the leaf recursively
  324. func (t *Tree) down(rTx db.ReadTx, newKey, currKey []byte, siblings [][]byte,
  325. path []bool, currLvl int, getLeaf bool) (
  326. []byte, []byte, [][]byte, error) {
  327. if currLvl > t.maxLevels {
  328. return nil, nil, nil, ErrMaxLevel
  329. }
  330. var err error
  331. var currValue []byte
  332. if bytes.Equal(currKey, t.emptyHash) {
  333. // empty value
  334. return currKey, emptyValue, siblings, nil
  335. }
  336. currValue, err = rTx.Get(currKey)
  337. if err != nil {
  338. return nil, nil, nil, err
  339. }
  340. switch currValue[0] {
  341. case PrefixValueEmpty: // empty
  342. fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
  343. hex.EncodeToString(newKey), hex.EncodeToString(currKey),
  344. currLvl, hex.EncodeToString(currValue))
  345. panic("This point should not be reached, as the 'if currKey==t.emptyHash'" +
  346. " above should avoid reaching this point. This panic is temporary" +
  347. " for reporting purposes, will be deleted in future versions." +
  348. " Please paste this log (including the previous log lines) in a" +
  349. " new issue: https://github.com/vocdoni/arbo/issues/new") // TMP
  350. case PrefixValueLeaf: // leaf
  351. if !bytes.Equal(currValue, emptyValue) {
  352. if getLeaf {
  353. return currKey, currValue, siblings, nil
  354. }
  355. oldLeafKey, _ := ReadLeafValue(currValue)
  356. if bytes.Equal(newKey, oldLeafKey) {
  357. return nil, nil, nil, ErrKeyAlreadyExists
  358. }
  359. oldLeafKeyFull, err := keyPathFromKey(t.maxLevels, oldLeafKey)
  360. if err != nil {
  361. return nil, nil, nil, err
  362. }
  363. // if currKey is already used, go down until paths diverge
  364. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  365. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  366. if err != nil {
  367. return nil, nil, nil, err
  368. }
  369. }
  370. return currKey, currValue, siblings, nil
  371. case PrefixValueIntermediate: // intermediate
  372. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  373. return nil, nil, nil,
  374. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  375. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  376. }
  377. // collect siblings while going down
  378. if path[currLvl] {
  379. // right
  380. lChild, rChild := ReadIntermediateChilds(currValue)
  381. siblings = append(siblings, lChild)
  382. return t.down(rTx, newKey, rChild, siblings, path, currLvl+1, getLeaf)
  383. }
  384. // left
  385. lChild, rChild := ReadIntermediateChilds(currValue)
  386. siblings = append(siblings, rChild)
  387. return t.down(rTx, newKey, lChild, siblings, path, currLvl+1, getLeaf)
  388. default:
  389. return nil, nil, nil, ErrInvalidValuePrefix
  390. }
  391. }
  392. // downVirtually is used when in a leaf already exists, and a new leaf which
  393. // shares the path until the existing leaf is being added
  394. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  395. newPath []bool, currLvl int) ([][]byte, error) {
  396. var err error
  397. if currLvl > t.maxLevels-1 {
  398. return nil, ErrMaxVirtualLevel
  399. }
  400. if oldPath[currLvl] == newPath[currLvl] {
  401. siblings = append(siblings, t.emptyHash)
  402. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  403. if err != nil {
  404. return nil, err
  405. }
  406. return siblings, nil
  407. }
  408. // reached the divergence
  409. siblings = append(siblings, oldKey)
  410. return siblings, nil
  411. }
  412. // up goes up recursively updating the intermediate nodes
  413. func (t *Tree) up(wTx db.WriteTx, key []byte, siblings [][]byte, path []bool,
  414. currLvl, toLvl int) ([]byte, error) {
  415. var k, v []byte
  416. var err error
  417. if path[currLvl+toLvl] {
  418. k, v, err = t.newIntermediate(siblings[currLvl], key)
  419. if err != nil {
  420. return nil, err
  421. }
  422. } else {
  423. k, v, err = t.newIntermediate(key, siblings[currLvl])
  424. if err != nil {
  425. return nil, err
  426. }
  427. }
  428. // store k-v to db
  429. if err = wTx.Set(k, v); err != nil {
  430. return nil, err
  431. }
  432. if currLvl == 0 {
  433. // reached the root
  434. return k, nil
  435. }
  436. return t.up(wTx, k, siblings, path, currLvl-1, toLvl)
  437. }
  438. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  439. t.dbg.incHash()
  440. return newLeafValue(t.hashFunction, k, v)
  441. }
  442. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  443. // which is used as the leaf key. And the value is the concatenation of the
  444. // inputed key & value. The output of this function is used as key-value to
  445. // store the leaf in the DB.
  446. // [ 1 byte | 1 byte | N bytes | M bytes ]
  447. // [ type of node | length of key | key | value ]
  448. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  449. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  450. if err != nil {
  451. return nil, nil, err
  452. }
  453. var leafValue []byte
  454. leafValue = append(leafValue, byte(PrefixValueLeaf))
  455. if len(k) > maxUint8 {
  456. return nil, nil, fmt.Errorf("newLeafValue: len(k) > %v", maxUint8)
  457. }
  458. leafValue = append(leafValue, byte(len(k)))
  459. leafValue = append(leafValue, k...)
  460. leafValue = append(leafValue, v...)
  461. return leafKey, leafValue, nil
  462. }
  463. // ReadLeafValue reads from a byte array the leaf key & value
  464. func ReadLeafValue(b []byte) ([]byte, []byte) {
  465. if len(b) < PrefixValueLen {
  466. return []byte{}, []byte{}
  467. }
  468. kLen := b[1]
  469. if len(b) < PrefixValueLen+int(kLen) {
  470. return []byte{}, []byte{}
  471. }
  472. k := b[PrefixValueLen : PrefixValueLen+kLen]
  473. v := b[PrefixValueLen+kLen:]
  474. return k, v
  475. }
  476. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  477. t.dbg.incHash()
  478. return newIntermediate(t.hashFunction, l, r)
  479. }
  480. // newIntermediate takes the left & right keys of a intermediate node, and
  481. // computes its hash. Returns the hash of the node, which is the node key, and a
  482. // byte array that contains the value (which contains the left & right child
  483. // keys) to store in the DB.
  484. // [ 1 byte | 1 byte | N bytes | N bytes ]
  485. // [ type of node | length of left key | left key | right key ]
  486. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  487. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  488. b[0] = PrefixValueIntermediate
  489. if len(l) > maxUint8 {
  490. return nil, nil, fmt.Errorf("newIntermediate: len(l) > %v", maxUint8)
  491. }
  492. b[1] = byte(len(l))
  493. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  494. copy(b[PrefixValueLen+hashFunc.Len():], r)
  495. key, err := hashFunc.Hash(l, r)
  496. if err != nil {
  497. return nil, nil, err
  498. }
  499. return key, b, nil
  500. }
  501. // ReadIntermediateChilds reads from a byte array the two childs keys
  502. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  503. if len(b) < PrefixValueLen {
  504. return []byte{}, []byte{}
  505. }
  506. lLen := b[1]
  507. if len(b) < PrefixValueLen+int(lLen) {
  508. return []byte{}, []byte{}
  509. }
  510. l := b[PrefixValueLen : PrefixValueLen+lLen]
  511. r := b[PrefixValueLen+lLen:]
  512. return l, r
  513. }
  // getPath returns the first numLevels bits of k as booleans. The bit of
  // level n is read from byte n/8 of k, starting by the least significant bit
  // of each byte. k must hold at least ceil(numLevels/8) bytes.
  func getPath(numLevels int, k []byte) []bool {
  	bits := make([]bool, numLevels)
  	for lvl := range bits {
  		byteIdx, bitIdx := lvl/8, uint(lvl%8)
  		bits[lvl] = (k[byteIdx]>>bitIdx)&1 == 1
  	}
  	return bits
  }
  521. // Update updates the value for a given existing key. If the given key does not
  522. // exist, returns an error.
  523. func (t *Tree) Update(k, v []byte) error {
  524. wTx := t.db.WriteTx()
  525. defer wTx.Discard()
  526. if err := t.UpdateWithTx(wTx, k, v); err != nil {
  527. return err
  528. }
  529. return wTx.Commit()
  530. }
  531. // UpdateWithTx does the same than the Update method, but allowing to pass the
  532. // db.WriteTx that is used. The db.WriteTx will not be committed inside this
  533. // method.
  534. func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
  535. t.Lock()
  536. defer t.Unlock()
  537. if !t.editable() {
  538. return ErrSnapshotNotEditable
  539. }
  540. keyPath, err := keyPathFromKey(t.maxLevels, k)
  541. if err != nil {
  542. return err
  543. }
  544. path := getPath(t.maxLevels, keyPath)
  545. root, err := t.RootWithTx(wTx)
  546. if err != nil {
  547. return err
  548. }
  549. var siblings [][]byte
  550. _, valueAtBottom, siblings, err := t.down(wTx, k, root, siblings, path, 0, true)
  551. if err != nil {
  552. return err
  553. }
  554. oldKey, _ := ReadLeafValue(valueAtBottom)
  555. if !bytes.Equal(oldKey, k) {
  556. return ErrKeyNotFound
  557. }
  558. leafKey, leafValue, err := t.newLeafValue(k, v)
  559. if err != nil {
  560. return err
  561. }
  562. if err := wTx.Set(leafKey, leafValue); err != nil {
  563. return err
  564. }
  565. // go up to the root
  566. if len(siblings) == 0 {
  567. return t.setRoot(wTx, leafKey)
  568. }
  569. root, err = t.up(wTx, leafKey, siblings, path, len(siblings)-1, 0)
  570. if err != nil {
  571. return err
  572. }
  573. // store root to db
  574. if err := t.setRoot(wTx, root); err != nil {
  575. return err
  576. }
  577. return nil
  578. }
  579. // GenProof generates a MerkleTree proof for the given key. The leaf value is
  580. // returned, together with the packed siblings of the proof, and a boolean
  581. // parameter that indicates if the proof is of existence (true) or not (false).
  582. func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
  583. rTx := t.db.ReadTx()
  584. defer rTx.Discard()
  585. return t.GenProofWithTx(rTx, k)
  586. }
  587. // GenProofWithTx does the same than the GenProof method, but allowing to pass
  588. // the db.ReadTx that is used.
  589. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
  590. keyPath, err := keyPathFromKey(t.maxLevels, k)
  591. if err != nil {
  592. return nil, nil, nil, false, err
  593. }
  594. path := getPath(t.maxLevels, keyPath)
  595. root, err := t.RootWithTx(rTx)
  596. if err != nil {
  597. return nil, nil, nil, false, err
  598. }
  599. // go down to the leaf
  600. var siblings [][]byte
  601. _, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true)
  602. if err != nil {
  603. return nil, nil, nil, false, err
  604. }
  605. s, err := PackSiblings(t.hashFunction, siblings)
  606. if err != nil {
  607. return nil, nil, nil, false, err
  608. }
  609. leafK, leafV := ReadLeafValue(value)
  610. if !bytes.Equal(k, leafK) {
  611. // key not in tree, proof of non-existence
  612. return leafK, leafV, s, false, nil
  613. }
  614. return leafK, leafV, s, true, nil
  615. }
  616. // PackSiblings packs the siblings into a byte array.
  617. // [ 2 byte | 2 byte | L bytes | S * N bytes ]
  618. // [ full length | bitmap length (L) | bitmap | N non-zero siblings ]
  619. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  620. // array. And S is the size of the output of the hash function used for the
  621. // Tree. The 2 2-byte that define the full length and bitmap length, are
  622. // encoded in little-endian.
  623. func PackSiblings(hashFunc HashFunction, siblings [][]byte) ([]byte, error) {
  624. var b []byte
  625. var bitmap []bool
  626. emptySibling := make([]byte, hashFunc.Len())
  627. for i := 0; i < len(siblings); i++ {
  628. if bytes.Equal(siblings[i], emptySibling) {
  629. bitmap = append(bitmap, false)
  630. } else {
  631. bitmap = append(bitmap, true)
  632. b = append(b, siblings[i]...)
  633. }
  634. }
  635. bitmapBytes := bitmapToBytes(bitmap)
  636. l := len(bitmapBytes)
  637. if l > maxUint16 {
  638. return nil, fmt.Errorf("PackSiblings: bitmapBytes length > %v", maxUint16)
  639. }
  640. fullLen := 4 + l + len(b) //nolint:gomnd
  641. if fullLen > maxUint16 {
  642. return nil, fmt.Errorf("PackSiblings: fullLen > %v", maxUint16)
  643. }
  644. res := make([]byte, fullLen)
  645. binary.LittleEndian.PutUint16(res[0:2], uint16(fullLen)) // set full length
  646. binary.LittleEndian.PutUint16(res[2:4], uint16(l)) // set the bitmapBytes length
  647. copy(res[4:4+l], bitmapBytes)
  648. copy(res[4+l:], b)
  649. return res, nil
  650. }
  651. // UnpackSiblings unpacks the siblings from a byte array.
  652. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  653. fullLen := binary.LittleEndian.Uint16(b[0:2])
  654. l := binary.LittleEndian.Uint16(b[2:4]) // bitmap bytes length
  655. if len(b) != int(fullLen) {
  656. return nil,
  657. fmt.Errorf("expected len: %d, current len: %d",
  658. fullLen, len(b))
  659. }
  660. bitmapBytes := b[4 : 4+l]
  661. bitmap := bytesToBitmap(bitmapBytes)
  662. siblingsBytes := b[4+l:]
  663. iSibl := 0
  664. emptySibl := make([]byte, hashFunc.Len())
  665. var siblings [][]byte
  666. for i := 0; i < len(bitmap); i++ {
  667. if iSibl >= len(siblingsBytes) {
  668. break
  669. }
  670. if bitmap[i] {
  671. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  672. iSibl += hashFunc.Len()
  673. } else {
  674. siblings = append(siblings, emptySibl)
  675. }
  676. }
  677. return siblings, nil
  678. }
  // bitmapToBytes packs the bitmap into ceil(len(bitmap)/8) bytes. The bit at
  // position i is stored in the byte i/8, starting by the least significant
  // bit of each byte.
  func bitmapToBytes(bitmap []bool) []byte {
  	nBytes := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  	packed := make([]byte, nBytes)
  	for pos, set := range bitmap {
  		if !set {
  			continue
  		}
  		packed[pos/8] |= 1 << (pos % 8)
  	}
  	return packed
  }
  // bytesToBitmap unpacks each byte of b into 8 booleans, starting by the
  // least significant bit of each byte. Returns nil when b is empty.
  func bytesToBitmap(b []byte) []bool {
  	var bits []bool
  	for _, octet := range b {
  		// mask takes the values 1, 2, 4, ..., 128 and then overflows to 0
  		for mask := byte(1); mask != 0; mask <<= 1 {
  			bits = append(bits, octet&mask != 0)
  		}
  	}
  	return bits
  }
  698. // Get returns the value in the Tree for a given key. If the key is not found,
  699. // will return the error ErrKeyNotFound, and in the leafK & leafV parameters
  700. // will be placed the data found in the tree in the leaf that was on the path
  701. // going to the input key.
  702. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  703. rTx := t.db.ReadTx()
  704. defer rTx.Discard()
  705. return t.GetWithTx(rTx, k)
  706. }
  707. // GetWithTx does the same than the Get method, but allowing to pass the
  708. // db.ReadTx that is used. If the key is not found, will return the error
  709. // ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data
  710. // found in the tree in the leaf that was on the path going to the input key.
  711. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
  712. keyPath, err := keyPathFromKey(t.maxLevels, k)
  713. if err != nil {
  714. return nil, nil, err
  715. }
  716. path := getPath(t.maxLevels, keyPath)
  717. root, err := t.RootWithTx(rTx)
  718. if err != nil {
  719. return nil, nil, err
  720. }
  721. // go down to the leaf
  722. var siblings [][]byte
  723. _, value, _, err := t.down(rTx, k, root, siblings, path, 0, true)
  724. if err != nil {
  725. return nil, nil, err
  726. }
  727. leafK, leafV := ReadLeafValue(value)
  728. if !bytes.Equal(k, leafK) {
  729. return leafK, leafV, ErrKeyNotFound
  730. }
  731. return leafK, leafV, nil
  732. }
  733. // CheckProof verifies the given proof. The proof verification depends on the
  734. // HashFunction passed as parameter.
  735. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  736. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  737. if err != nil {
  738. return false, err
  739. }
  740. keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8)))) //nolint:gomnd
  741. copy(keyPath[:], k)
  742. key, _, err := newLeafValue(hashFunc, k, v)
  743. if err != nil {
  744. return false, err
  745. }
  746. path := getPath(len(siblings), keyPath)
  747. for i := len(siblings) - 1; i >= 0; i-- {
  748. if path[i] {
  749. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  750. if err != nil {
  751. return false, err
  752. }
  753. } else {
  754. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  755. if err != nil {
  756. return false, err
  757. }
  758. }
  759. }
  760. if bytes.Equal(key[:], root) {
  761. return true, nil
  762. }
  763. return false, nil
  764. }
  765. func (t *Tree) incNLeafs(wTx db.WriteTx, nLeafs int) error {
  766. oldNLeafs, err := t.GetNLeafsWithTx(wTx)
  767. if err != nil {
  768. return err
  769. }
  770. newNLeafs := oldNLeafs + nLeafs
  771. return t.setNLeafs(wTx, newNLeafs)
  772. }
  773. func (t *Tree) setNLeafs(wTx db.WriteTx, nLeafs int) error {
  774. b := make([]byte, 8)
  775. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  776. if err := wTx.Set(dbKeyNLeafs, b); err != nil {
  777. return err
  778. }
  779. return nil
  780. }
  781. // GetNLeafs returns the number of Leafs of the Tree.
  782. func (t *Tree) GetNLeafs() (int, error) {
  783. rTx := t.db.ReadTx()
  784. defer rTx.Discard()
  785. return t.GetNLeafsWithTx(rTx)
  786. }
  787. // GetNLeafsWithTx does the same than the GetNLeafs method, but allowing to
  788. // pass the db.ReadTx that is used.
  789. func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) {
  790. b, err := rTx.Get(dbKeyNLeafs)
  791. if err != nil {
  792. return 0, err
  793. }
  794. nLeafs := binary.LittleEndian.Uint64(b)
  795. return int(nLeafs), nil
  796. }
  797. // SetRoot sets the root to the given root
  798. func (t *Tree) SetRoot(root []byte) error {
  799. wTx := t.db.WriteTx()
  800. defer wTx.Discard()
  801. if err := t.SetRootWithTx(wTx, root); err != nil {
  802. return err
  803. }
  804. return wTx.Commit()
  805. }
  806. // SetRootWithTx sets the root to the given root using the given db.WriteTx
  807. func (t *Tree) SetRootWithTx(wTx db.WriteTx, root []byte) error {
  808. if !t.editable() {
  809. return ErrSnapshotNotEditable
  810. }
  811. if root == nil {
  812. return fmt.Errorf("can not SetRoot with nil root")
  813. }
  814. // check that the root exists in the db
  815. if !bytes.Equal(root, t.emptyHash) {
  816. if _, err := wTx.Get(root); err == ErrKeyNotFound {
  817. return fmt.Errorf("can not SetRoot with root %x, as it does not exist in the db", root)
  818. } else if err != nil {
  819. return err
  820. }
  821. }
  822. return wTx.Set(dbKeyRoot, root)
  823. }
  824. // Snapshot returns a read-only copy of the Tree from the given root
  825. func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
  826. // allow to define which root to use
  827. if fromRoot == nil {
  828. var err error
  829. fromRoot, err = t.Root()
  830. if err != nil {
  831. return nil, err
  832. }
  833. }
  834. rTx := t.db.ReadTx()
  835. defer rTx.Discard()
  836. // check that the root exists in the db
  837. if !bytes.Equal(fromRoot, t.emptyHash) {
  838. if _, err := rTx.Get(fromRoot); err == ErrKeyNotFound {
  839. return nil,
  840. fmt.Errorf("can not do a Snapshot with root %x, as it does not exist in the db",
  841. fromRoot)
  842. } else if err != nil {
  843. return nil, err
  844. }
  845. }
  846. return &Tree{
  847. db: t.db,
  848. maxLevels: t.maxLevels,
  849. snapshotRoot: fromRoot,
  850. emptyHash: t.emptyHash,
  851. hashFunction: t.hashFunction,
  852. dbg: t.dbg,
  853. }, nil
  854. }
  855. // Iterate iterates through the full Tree, executing the given function on each
  856. // node of the Tree.
  857. func (t *Tree) Iterate(fromRoot []byte, f func([]byte, []byte)) error {
  858. rTx := t.db.ReadTx()
  859. defer rTx.Discard()
  860. return t.IterateWithTx(rTx, fromRoot, f)
  861. }
  862. // IterateWithTx does the same than the Iterate method, but allowing to pass
  863. // the db.ReadTx that is used.
  864. func (t *Tree) IterateWithTx(rTx db.ReadTx, fromRoot []byte, f func([]byte, []byte)) error {
  865. // allow to define which root to use
  866. if fromRoot == nil {
  867. var err error
  868. fromRoot, err = t.RootWithTx(rTx)
  869. if err != nil {
  870. return err
  871. }
  872. }
  873. return t.iter(rTx, fromRoot, f)
  874. }
  875. // IterateWithStop does the same than Iterate, but with int for the current
  876. // level, and a boolean parameter used by the passed function, is to indicate to
  877. // stop iterating on the branch when the method returns 'true'.
  878. func (t *Tree) IterateWithStop(fromRoot []byte, f func(int, []byte, []byte) bool) error {
  879. rTx := t.db.ReadTx()
  880. defer rTx.Discard()
  881. // allow to define which root to use
  882. if fromRoot == nil {
  883. var err error
  884. fromRoot, err = t.RootWithTx(rTx)
  885. if err != nil {
  886. return err
  887. }
  888. }
  889. return t.iterWithStop(rTx, fromRoot, 0, f)
  890. }
  891. // IterateWithStopWithTx does the same than the IterateWithStop method, but
  892. // allowing to pass the db.ReadTx that is used.
  893. func (t *Tree) IterateWithStopWithTx(rTx db.ReadTx, fromRoot []byte,
  894. f func(int, []byte, []byte) bool) error {
  895. // allow to define which root to use
  896. if fromRoot == nil {
  897. var err error
  898. fromRoot, err = t.RootWithTx(rTx)
  899. if err != nil {
  900. return err
  901. }
  902. }
  903. return t.iterWithStop(rTx, fromRoot, 0, f)
  904. }
  905. func (t *Tree) iterWithStop(rTx db.ReadTx, k []byte, currLevel int,
  906. f func(int, []byte, []byte) bool) error {
  907. var v []byte
  908. var err error
  909. if bytes.Equal(k, t.emptyHash) {
  910. v = t.emptyHash
  911. } else {
  912. v, err = rTx.Get(k)
  913. if err != nil {
  914. return err
  915. }
  916. }
  917. currLevel++
  918. switch v[0] {
  919. case PrefixValueEmpty:
  920. f(currLevel, k, v)
  921. case PrefixValueLeaf:
  922. f(currLevel, k, v)
  923. case PrefixValueIntermediate:
  924. stop := f(currLevel, k, v)
  925. if stop {
  926. return nil
  927. }
  928. l, r := ReadIntermediateChilds(v)
  929. if err = t.iterWithStop(rTx, l, currLevel, f); err != nil {
  930. return err
  931. }
  932. if err = t.iterWithStop(rTx, r, currLevel, f); err != nil {
  933. return err
  934. }
  935. default:
  936. return ErrInvalidValuePrefix
  937. }
  938. return nil
  939. }
  940. func (t *Tree) iter(rTx db.ReadTx, k []byte, f func([]byte, []byte)) error {
  941. f2 := func(currLvl int, k, v []byte) bool {
  942. f(k, v)
  943. return false
  944. }
  945. return t.iterWithStop(rTx, k, 0, f2)
  946. }
  947. // Dump exports all the Tree leafs in a byte array of length:
  948. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  949. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  950. // [ len(k) | len(v) | key | value ]
  951. // Where S is the size of the output of the hash function used for the Tree.
  952. func (t *Tree) Dump(fromRoot []byte) ([]byte, error) {
  953. // allow to define which root to use
  954. if fromRoot == nil {
  955. var err error
  956. fromRoot, err = t.Root()
  957. if err != nil {
  958. return nil, err
  959. }
  960. }
  961. // WARNING current encoding only supports key & values of 255 bytes each
  962. // (due using only 1 byte for the length headers).
  963. var b []byte
  964. var callbackErr error
  965. err := t.IterateWithStop(fromRoot, func(_ int, k, v []byte) bool {
  966. if v[0] != PrefixValueLeaf {
  967. return false
  968. }
  969. leafK, leafV := ReadLeafValue(v)
  970. kv := make([]byte, 2+len(leafK)+len(leafV))
  971. if len(leafK) > maxUint8 {
  972. callbackErr = fmt.Errorf("len(leafK) > %v", maxUint8)
  973. return true
  974. }
  975. kv[0] = byte(len(leafK))
  976. if len(leafV) > maxUint8 {
  977. callbackErr = fmt.Errorf("len(leafV) > %v", maxUint8)
  978. return true
  979. }
  980. kv[1] = byte(len(leafV))
  981. copy(kv[2:2+len(leafK)], leafK)
  982. copy(kv[2+len(leafK):], leafV)
  983. b = append(b, kv...)
  984. return false
  985. })
  986. if callbackErr != nil {
  987. return nil, callbackErr
  988. }
  989. return b, err
  990. }
  991. // ImportDump imports the leafs (that have been exported with the Dump method)
  992. // in the Tree.
  993. func (t *Tree) ImportDump(b []byte) error {
  994. if !t.editable() {
  995. return ErrSnapshotNotEditable
  996. }
  997. root, err := t.Root()
  998. if err != nil {
  999. return err
  1000. }
  1001. if !bytes.Equal(root, t.emptyHash) {
  1002. return ErrTreeNotEmpty
  1003. }
  1004. r := bytes.NewReader(b)
  1005. var keys, values [][]byte
  1006. for {
  1007. l := make([]byte, 2)
  1008. _, err = io.ReadFull(r, l)
  1009. if err == io.EOF {
  1010. break
  1011. } else if err != nil {
  1012. return err
  1013. }
  1014. k := make([]byte, l[0])
  1015. _, err = io.ReadFull(r, k)
  1016. if err != nil {
  1017. return err
  1018. }
  1019. v := make([]byte, l[1])
  1020. _, err = io.ReadFull(r, v)
  1021. if err != nil {
  1022. return err
  1023. }
  1024. keys = append(keys, k)
  1025. values = append(values, v)
  1026. }
  1027. if _, err = t.AddBatch(keys, values); err != nil {
  1028. return err
  1029. }
  1030. return nil
  1031. }
  1032. // Graphviz iterates across the full tree to generate a string Graphviz
  1033. // representation of the tree and writes it to w
  1034. func (t *Tree) Graphviz(w io.Writer, fromRoot []byte) error {
  1035. return t.GraphvizFirstNLevels(w, fromRoot, t.maxLevels)
  1036. }
  1037. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  1038. // generate a string Graphviz representation of the first NLevels of the tree
  1039. // and writes it to w
  1040. func (t *Tree) GraphvizFirstNLevels(w io.Writer, fromRoot []byte, untilLvl int) error {
  1041. fmt.Fprintf(w, `digraph hierarchy {
  1042. node [fontname=Monospace,fontsize=10,shape=box]
  1043. `)
  1044. rTx := t.db.ReadTx()
  1045. defer rTx.Discard()
  1046. if fromRoot == nil {
  1047. var err error
  1048. fromRoot, err = t.RootWithTx(rTx)
  1049. if err != nil {
  1050. return err
  1051. }
  1052. }
  1053. nEmpties := 0
  1054. err := t.iterWithStop(rTx, fromRoot, 0, func(currLvl int, k, v []byte) bool {
  1055. if currLvl == untilLvl {
  1056. return true // to stop the iter from going down
  1057. }
  1058. switch v[0] {
  1059. case PrefixValueEmpty:
  1060. case PrefixValueLeaf:
  1061. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  1062. // key & value from the leaf
  1063. kB, vB := ReadLeafValue(v)
  1064. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  1065. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  1066. hex.EncodeToString(vB[:nChars]))
  1067. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  1068. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  1069. case PrefixValueIntermediate:
  1070. l, r := ReadIntermediateChilds(v)
  1071. lStr := hex.EncodeToString(l[:nChars])
  1072. rStr := hex.EncodeToString(r[:nChars])
  1073. eStr := ""
  1074. if bytes.Equal(l, t.emptyHash) {
  1075. lStr = fmt.Sprintf("empty%v", nEmpties)
  1076. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  1077. lStr)
  1078. nEmpties++
  1079. }
  1080. if bytes.Equal(r, t.emptyHash) {
  1081. rStr = fmt.Sprintf("empty%v", nEmpties)
  1082. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  1083. rStr)
  1084. nEmpties++
  1085. }
  1086. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  1087. lStr, rStr)
  1088. fmt.Fprint(w, eStr)
  1089. default:
  1090. }
  1091. return false
  1092. })
  1093. fmt.Fprintf(w, "}\n")
  1094. return err
  1095. }
  1096. // PrintGraphviz prints the output of Tree.Graphviz
  1097. func (t *Tree) PrintGraphviz(fromRoot []byte) error {
  1098. if fromRoot == nil {
  1099. var err error
  1100. fromRoot, err = t.Root()
  1101. if err != nil {
  1102. return err
  1103. }
  1104. }
  1105. return t.PrintGraphvizFirstNLevels(fromRoot, t.maxLevels)
  1106. }
  1107. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  1108. func (t *Tree) PrintGraphvizFirstNLevels(fromRoot []byte, untilLvl int) error {
  1109. if fromRoot == nil {
  1110. var err error
  1111. fromRoot, err = t.Root()
  1112. if err != nil {
  1113. return err
  1114. }
  1115. }
  1116. w := bytes.NewBufferString("")
  1117. fmt.Fprintf(w,
  1118. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+":\n")
  1119. err := t.GraphvizFirstNLevels(w, fromRoot, untilLvl)
  1120. if err != nil {
  1121. fmt.Println(w)
  1122. return err
  1123. }
  1124. fmt.Fprintf(w,
  1125. "End of Graphviz of the Tree with Root "+hex.EncodeToString(fromRoot)+"\n--------\n")
  1126. fmt.Println(w)
  1127. return nil
  1128. }
  1129. // TODO circom proofs
  1130. // TODO data structure for proofs (including root, key, value, siblings,
  1131. // hashFunction) + method to verify that data structure