You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

974 lines
26 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
It also allows choosing which hash function to use. For example, when working
with zkSnarks the Poseidon hash function can be used; otherwise the Blake3
hash function can be used instead, which improves the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "crypto/sha256"
  15. "encoding/binary"
  16. "encoding/hex"
  17. "fmt"
  18. "io"
  19. "math"
  20. "sync"
  21. "go.vocdoni.io/dvote/db"
  22. )
const (
	// PrefixValueLen is the number of header bytes placed in front of every
	// Value stored in the db: one byte for the node type followed by one
	// byte holding a length.
	PrefixValueLen = 2
	// PrefixValueEmpty is the value of the first byte of a Value that
	// represents an empty node.
	PrefixValueEmpty = 0
	// PrefixValueLeaf is the value of the first byte of a Value that
	// represents a leaf node.
	PrefixValueLeaf = 1
	// PrefixValueIntermediate is the value of the first byte of a Value
	// that represents an intermediate node.
	PrefixValueIntermediate = 2

	// nChars is the number of bytes kept from keys and values when building
	// the Graphviz node labels.
	nChars = 4
)
var (
	// dbKeyRoot is the db key under which the current root is stored.
	dbKeyRoot = []byte("root")
	// dbKeyNLeafs is the db key under which the number of leafs is stored.
	dbKeyNLeafs = []byte("nleafs")
	// emptyValue is the Value reported for an empty node.
	emptyValue = []byte{0}

	// ErrKeyNotFound is used when a key is found neither in the db nor in
	// the current db Batch.
	ErrKeyNotFound = fmt.Errorf("key not found")
	// ErrKeyAlreadyExists is used when trying to add a key as a leaf to the
	// tree and that key already exists.
	ErrKeyAlreadyExists = fmt.Errorf("key already exists")
	// ErrInvalidValuePrefix is used when, while going down into the tree, a
	// value read from the db has an unrecognized prefix.
	ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
	// ErrDBNoTx is used when trying to use Tree.dbPut while
	// Tree.dbBatch==nil.
	ErrDBNoTx = fmt.Errorf("dbPut error: no db Batch")
	// ErrMaxLevel indicates that the max level was reached while going down
	// into the tree.
	ErrMaxLevel = fmt.Errorf("max level reached")
	// ErrMaxVirtualLevel indicates that the max virtual level was reached
	// while going down into the tree.
	ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
)
// Tree defines the struct that implements the MerkleTree functionalities.
type Tree struct {
	// The embedded lock is taken for writing by Add, AddBatch and Update.
	sync.RWMutex

	// dbBatch is the write batch of the operation in progress; dbPut
	// refuses to work while it is nil.
	dbBatch db.Batch
	// batchMemory mirrors what has been put into dbBatch so that dbGet can
	// read values that are not committed yet.
	batchMemory kvMap // TODO TMP
	// db is the underlying key-value store.
	db db.Database
	// maxLevels bounds how deep the tree can be traversed.
	maxLevels int
	// root is the key of the current root node.
	root []byte
	// hashFunction is used to compute leaf and intermediate node keys.
	hashFunction HashFunction
	// TODO in the methods that use it, check if emptyHash param is len>0
	// (check if it has been initialized)
	//
	// emptyHash is a zeroed byte slice of hashFunction.Len() bytes, used as
	// the key of an empty node.
	emptyHash []byte

	// dbg collects counters (hashes, db gets, db puts).
	// NOTE(review): presumably nil-safe, since it is never initialized in
	// this file — confirm against dbgStats.
	dbg *dbgStats
}
// bmSize is the size in bytes of the keys of kvMap (a SHA-256 digest).
const bmSize = sha256.Size

// TMP
//
// kvMap is an in-memory key-value map indexed by the SHA-256 digest of the
// original key.
type kvMap map[[bmSize]byte]kv
  78. // Get retreives the value respective to a key from the KvMap
  79. func (m kvMap) Get(k []byte) ([]byte, bool) {
  80. v, ok := m[sha256.Sum256(k)]
  81. return v.v, ok
  82. }
  83. // Put stores a key and a value in the KvMap
  84. func (m kvMap) Put(k, v []byte) {
  85. m[sha256.Sum256(k)] = kv{k: k, v: v}
  86. }
  87. // NewTree returns a new Tree, if there is a Tree still in the given database, it
  88. // will load it.
  89. func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, error) {
  90. t := Tree{db: database, maxLevels: maxLevels, hashFunction: hash}
  91. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  92. root, err := t.dbGet(dbKeyRoot)
  93. if err == ErrKeyNotFound {
  94. // store new root 0
  95. t.dbBatch = t.db.NewBatch()
  96. t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP
  97. t.root = t.emptyHash
  98. if err = t.dbPut(dbKeyRoot, t.root); err != nil {
  99. return nil, err
  100. }
  101. if err = t.setNLeafs(0); err != nil {
  102. return nil, err
  103. }
  104. if err = t.dbBatch.Write(); err != nil {
  105. return nil, err
  106. }
  107. return &t, err
  108. } else if err != nil {
  109. return nil, err
  110. }
  111. t.root = root
  112. return &t, nil
  113. }
  114. // Root returns the root of the Tree
  115. func (t *Tree) Root() []byte {
  116. return t.root
  117. }
  118. // HashFunction returns Tree.hashFunction
  119. func (t *Tree) HashFunction() HashFunction {
  120. return t.hashFunction
  121. }
  122. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  123. // the indexes of the keys failed to add.
  124. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  125. t.Lock()
  126. defer t.Unlock()
  127. vt, err := t.loadVT()
  128. if err != nil {
  129. return nil, err
  130. }
  131. // TODO check validity of keys & values for Tree.hashFunction
  132. invalids, err := vt.addBatch(keys, values)
  133. if err != nil {
  134. return nil, err
  135. }
  136. // once the VirtualTree is build, compute the hashes
  137. pairs, err := vt.computeHashes()
  138. if err != nil {
  139. // TODO currently invalids in computeHashes are not counted
  140. return nil, err
  141. }
  142. t.root = vt.root.h
  143. // store pairs in db
  144. t.dbBatch = t.db.NewBatch()
  145. t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP
  146. for i := 0; i < len(pairs); i++ {
  147. if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil {
  148. return nil, err
  149. }
  150. }
  151. // store root to db
  152. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  153. return nil, err
  154. }
  155. // update nLeafs
  156. if err := t.incNLeafs(len(keys) - len(invalids)); err != nil {
  157. return nil, err
  158. }
  159. // commit db dbBatch
  160. if err := t.dbBatch.Write(); err != nil {
  161. return nil, err
  162. }
  163. return invalids, nil
  164. }
  165. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  166. // the same leafs.
  167. func (t *Tree) loadVT() (vt, error) {
  168. vt := newVT(t.maxLevels, t.hashFunction)
  169. vt.params.dbg = t.dbg
  170. err := t.Iterate(nil, func(k, v []byte) {
  171. if v[0] != PrefixValueLeaf {
  172. return
  173. }
  174. leafK, leafV := ReadLeafValue(v)
  175. if err := vt.add(0, leafK, leafV); err != nil {
  176. panic(err)
  177. }
  178. })
  179. return vt, err
  180. }
  181. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  182. // is expected that are represented by a Little-Endian byte array (for circom
  183. // compatibility).
  184. func (t *Tree) Add(k, v []byte) error {
  185. t.Lock()
  186. defer t.Unlock()
  187. var err error
  188. t.dbBatch = t.db.NewBatch()
  189. t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP
  190. // TODO check validity of key & value for Tree.hashFunction
  191. err = t.add(0, k, v) // add from level 0
  192. if err != nil {
  193. return err
  194. }
  195. // store root to db
  196. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  197. return err
  198. }
  199. // update nLeafs
  200. if err = t.incNLeafs(1); err != nil {
  201. return err
  202. }
  203. return t.dbBatch.Write()
  204. }
  205. func (t *Tree) add(fromLvl int, k, v []byte) error {
  206. keyPath := make([]byte, t.hashFunction.Len())
  207. copy(keyPath[:], k)
  208. path := getPath(t.maxLevels, keyPath)
  209. // go down to the leaf
  210. var siblings [][]byte
  211. _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
  212. if err != nil {
  213. return err
  214. }
  215. leafKey, leafValue, err := t.newLeafValue(k, v)
  216. if err != nil {
  217. return err
  218. }
  219. if err := t.dbPut(leafKey, leafValue); err != nil {
  220. return err
  221. }
  222. // go up to the root
  223. if len(siblings) == 0 {
  224. t.root = leafKey
  225. return nil
  226. }
  227. root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
  228. if err != nil {
  229. return err
  230. }
  231. t.root = root
  232. return nil
  233. }
  234. // down goes down to the leaf recursively
  235. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
  236. path []bool, currLvl int, getLeaf bool) (
  237. []byte, []byte, [][]byte, error) {
  238. if currLvl > t.maxLevels-1 {
  239. return nil, nil, nil, ErrMaxLevel
  240. }
  241. var err error
  242. var currValue []byte
  243. if bytes.Equal(currKey, t.emptyHash) {
  244. // empty value
  245. return currKey, emptyValue, siblings, nil
  246. }
  247. currValue, err = t.dbGet(currKey)
  248. if err != nil {
  249. return nil, nil, nil, err
  250. }
  251. switch currValue[0] {
  252. case PrefixValueEmpty: // empty
  253. fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
  254. hex.EncodeToString(newKey), hex.EncodeToString(currKey),
  255. currLvl, hex.EncodeToString(currValue))
  256. panic("This point should not be reached, as the 'if' above" +
  257. " should avoid reaching this point. This panic is temporary" +
  258. " for reporting purposes, will be deleted in future versions." +
  259. " Please paste this log (including the previous lines) in a" +
  260. " new issue: https://github.com/arnaucube/arbo/issues/new") // TMP
  261. case PrefixValueLeaf: // leaf
  262. if !bytes.Equal(currValue, emptyValue) {
  263. if getLeaf {
  264. return currKey, currValue, siblings, nil
  265. }
  266. oldLeafKey, _ := ReadLeafValue(currValue)
  267. if bytes.Equal(newKey, oldLeafKey) {
  268. return nil, nil, nil, ErrKeyAlreadyExists
  269. }
  270. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  271. copy(oldLeafKeyFull[:], oldLeafKey)
  272. // if currKey is already used, go down until paths diverge
  273. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  274. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  275. if err != nil {
  276. return nil, nil, nil, err
  277. }
  278. }
  279. return currKey, currValue, siblings, nil
  280. case PrefixValueIntermediate: // intermediate
  281. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  282. return nil, nil, nil,
  283. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  284. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  285. }
  286. // collect siblings while going down
  287. if path[currLvl] {
  288. // right
  289. lChild, rChild := ReadIntermediateChilds(currValue)
  290. siblings = append(siblings, lChild)
  291. return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
  292. }
  293. // left
  294. lChild, rChild := ReadIntermediateChilds(currValue)
  295. siblings = append(siblings, rChild)
  296. return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
  297. default:
  298. return nil, nil, nil, ErrInvalidValuePrefix
  299. }
  300. }
  301. // downVirtually is used when in a leaf already exists, and a new leaf which
  302. // shares the path until the existing leaf is being added
  303. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  304. newPath []bool, currLvl int) ([][]byte, error) {
  305. var err error
  306. if currLvl > t.maxLevels-1 {
  307. return nil, ErrMaxVirtualLevel
  308. }
  309. if oldPath[currLvl] == newPath[currLvl] {
  310. siblings = append(siblings, t.emptyHash)
  311. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  312. if err != nil {
  313. return nil, err
  314. }
  315. return siblings, nil
  316. }
  317. // reached the divergence
  318. siblings = append(siblings, oldKey)
  319. return siblings, nil
  320. }
  321. // up goes up recursively updating the intermediate nodes
  322. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
  323. var k, v []byte
  324. var err error
  325. if path[currLvl+toLvl] {
  326. k, v, err = t.newIntermediate(siblings[currLvl], key)
  327. if err != nil {
  328. return nil, err
  329. }
  330. } else {
  331. k, v, err = t.newIntermediate(key, siblings[currLvl])
  332. if err != nil {
  333. return nil, err
  334. }
  335. }
  336. // store k-v to db
  337. if err = t.dbPut(k, v); err != nil {
  338. return nil, err
  339. }
  340. if currLvl == 0 {
  341. // reached the root
  342. return k, nil
  343. }
  344. return t.up(k, siblings, path, currLvl-1, toLvl)
  345. }
  346. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  347. t.dbg.incHash()
  348. return newLeafValue(t.hashFunction, k, v)
  349. }
  350. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  351. // which is used as the leaf key. And the value is the concatenation of the
  352. // inputed key & value. The output of this function is used as key-value to
  353. // store the leaf in the DB.
  354. // [ 1 byte | 1 byte | N bytes | M bytes ]
  355. // [ type of node | length of key | key | value ]
  356. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  357. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  358. if err != nil {
  359. return nil, nil, err
  360. }
  361. var leafValue []byte
  362. leafValue = append(leafValue, byte(1))
  363. leafValue = append(leafValue, byte(len(k)))
  364. leafValue = append(leafValue, k...)
  365. leafValue = append(leafValue, v...)
  366. return leafKey, leafValue, nil
  367. }
  368. // ReadLeafValue reads from a byte array the leaf key & value
  369. func ReadLeafValue(b []byte) ([]byte, []byte) {
  370. if len(b) < PrefixValueLen {
  371. return []byte{}, []byte{}
  372. }
  373. kLen := b[1]
  374. if len(b) < PrefixValueLen+int(kLen) {
  375. return []byte{}, []byte{}
  376. }
  377. k := b[PrefixValueLen : PrefixValueLen+kLen]
  378. v := b[PrefixValueLen+kLen:]
  379. return k, v
  380. }
  381. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  382. t.dbg.incHash()
  383. return newIntermediate(t.hashFunction, l, r)
  384. }
  385. // newIntermediate takes the left & right keys of a intermediate node, and
  386. // computes its hash. Returns the hash of the node, which is the node key, and a
  387. // byte array that contains the value (which contains the left & right child
  388. // keys) to store in the DB.
  389. // [ 1 byte | 1 byte | N bytes | N bytes ]
  390. // [ type of node | length of key | left key | right key ]
  391. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  392. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  393. b[0] = 2
  394. b[1] = byte(len(l))
  395. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  396. copy(b[PrefixValueLen+hashFunc.Len():], r)
  397. key, err := hashFunc.Hash(l, r)
  398. if err != nil {
  399. return nil, nil, err
  400. }
  401. return key, b, nil
  402. }
  403. // ReadIntermediateChilds reads from a byte array the two childs keys
  404. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  405. if len(b) < PrefixValueLen {
  406. return []byte{}, []byte{}
  407. }
  408. lLen := b[1]
  409. if len(b) < PrefixValueLen+int(lLen) {
  410. return []byte{}, []byte{}
  411. }
  412. l := b[PrefixValueLen : PrefixValueLen+lLen]
  413. r := b[PrefixValueLen+lLen:]
  414. return l, r
  415. }
  416. func getPath(numLevels int, k []byte) []bool {
  417. path := make([]bool, numLevels)
  418. for n := 0; n < numLevels; n++ {
  419. path[n] = k[n/8]&(1<<(n%8)) != 0
  420. }
  421. return path
  422. }
  423. // Update updates the value for a given existing key. If the given key does not
  424. // exist, returns an error.
  425. func (t *Tree) Update(k, v []byte) error {
  426. t.Lock()
  427. defer t.Unlock()
  428. var err error
  429. t.dbBatch = t.db.NewBatch()
  430. t.batchMemory = make(map[[bmSize]byte]kv) // TODO TMP
  431. keyPath := make([]byte, t.hashFunction.Len())
  432. copy(keyPath[:], k)
  433. path := getPath(t.maxLevels, keyPath)
  434. var siblings [][]byte
  435. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  436. if err != nil {
  437. return err
  438. }
  439. oldKey, _ := ReadLeafValue(valueAtBottom)
  440. if !bytes.Equal(oldKey, k) {
  441. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  442. }
  443. leafKey, leafValue, err := t.newLeafValue(k, v)
  444. if err != nil {
  445. return err
  446. }
  447. if err := t.dbPut(leafKey, leafValue); err != nil {
  448. return err
  449. }
  450. // go up to the root
  451. if len(siblings) == 0 {
  452. t.root = leafKey
  453. return t.dbBatch.Write()
  454. }
  455. root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
  456. if err != nil {
  457. return err
  458. }
  459. t.root = root
  460. // store root to db
  461. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  462. return err
  463. }
  464. return t.dbBatch.Write()
  465. }
  466. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  467. // the Tree, the proof will be of existence, if the key does not exist in the
  468. // tree, the proof will be of non-existence.
  469. func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) {
  470. keyPath := make([]byte, t.hashFunction.Len())
  471. copy(keyPath[:], k)
  472. path := getPath(t.maxLevels, keyPath)
  473. // go down to the leaf
  474. var siblings [][]byte
  475. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  476. if err != nil {
  477. return nil, nil, err
  478. }
  479. leafK, leafV := ReadLeafValue(value)
  480. if !bytes.Equal(k, leafK) {
  481. fmt.Println("key not in Tree")
  482. fmt.Println(leafK)
  483. fmt.Println(leafV)
  484. // TODO proof of non-existence
  485. panic("unimplemented")
  486. }
  487. s := PackSiblings(t.hashFunction, siblings)
  488. return leafV, s, nil
  489. }
  490. // PackSiblings packs the siblings into a byte array.
  491. // [ 1 byte | L bytes | S * N bytes ]
  492. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  493. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  494. // array. And S is the size of the output of the hash function used for the
  495. // Tree.
  496. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  497. var b []byte
  498. var bitmap []bool
  499. emptySibling := make([]byte, hashFunc.Len())
  500. for i := 0; i < len(siblings); i++ {
  501. if bytes.Equal(siblings[i], emptySibling) {
  502. bitmap = append(bitmap, false)
  503. } else {
  504. bitmap = append(bitmap, true)
  505. b = append(b, siblings[i]...)
  506. }
  507. }
  508. bitmapBytes := bitmapToBytes(bitmap)
  509. l := len(bitmapBytes)
  510. res := make([]byte, l+1+len(b))
  511. res[0] = byte(l) // set the bitmapBytes length
  512. copy(res[1:1+l], bitmapBytes)
  513. copy(res[1+l:], b)
  514. return res
  515. }
  516. // UnpackSiblings unpacks the siblings from a byte array.
  517. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  518. l := b[0]
  519. bitmapBytes := b[1 : 1+l]
  520. bitmap := bytesToBitmap(bitmapBytes)
  521. siblingsBytes := b[1+l:]
  522. iSibl := 0
  523. emptySibl := make([]byte, hashFunc.Len())
  524. var siblings [][]byte
  525. for i := 0; i < len(bitmap); i++ {
  526. if iSibl >= len(siblingsBytes) {
  527. break
  528. }
  529. if bitmap[i] {
  530. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  531. iSibl += hashFunc.Len()
  532. } else {
  533. siblings = append(siblings, emptySibl)
  534. }
  535. }
  536. return siblings, nil
  537. }
  538. func bitmapToBytes(bitmap []bool) []byte {
  539. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  540. b := make([]byte, bitmapBytesLen)
  541. for i := 0; i < len(bitmap); i++ {
  542. if bitmap[i] {
  543. b[i/8] |= 1 << (i % 8)
  544. }
  545. }
  546. return b
  547. }
  548. func bytesToBitmap(b []byte) []bool {
  549. var bitmap []bool
  550. for i := 0; i < len(b); i++ {
  551. for j := 0; j < 8; j++ {
  552. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  553. }
  554. }
  555. return bitmap
  556. }
  557. // Get returns the value for a given key
  558. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  559. keyPath := make([]byte, t.hashFunction.Len())
  560. copy(keyPath[:], k)
  561. path := getPath(t.maxLevels, keyPath)
  562. // go down to the leaf
  563. var siblings [][]byte
  564. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  565. if err != nil {
  566. return nil, nil, err
  567. }
  568. leafK, leafV := ReadLeafValue(value)
  569. if !bytes.Equal(k, leafK) {
  570. return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  571. BytesToBigInt(k), BytesToBigInt(leafK))
  572. }
  573. return leafK, leafV, nil
  574. }
  575. // CheckProof verifies the given proof. The proof verification depends on the
  576. // HashFunction passed as parameter.
  577. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  578. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  579. if err != nil {
  580. return false, err
  581. }
  582. keyPath := make([]byte, hashFunc.Len())
  583. copy(keyPath[:], k)
  584. key, _, err := newLeafValue(hashFunc, k, v)
  585. if err != nil {
  586. return false, err
  587. }
  588. path := getPath(len(siblings), keyPath)
  589. for i := len(siblings) - 1; i >= 0; i-- {
  590. if path[i] {
  591. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  592. if err != nil {
  593. return false, err
  594. }
  595. } else {
  596. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  597. if err != nil {
  598. return false, err
  599. }
  600. }
  601. }
  602. if bytes.Equal(key[:], root) {
  603. return true, nil
  604. }
  605. return false, nil
  606. }
  607. func (t *Tree) dbPut(k, v []byte) error {
  608. if t.dbBatch == nil {
  609. return ErrDBNoTx
  610. }
  611. t.dbg.incDbPut()
  612. t.batchMemory.Put(k, v) // TODO TMP
  613. return t.dbBatch.Put(k, v)
  614. }
  615. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  616. // if key is empty, return empty as value
  617. if bytes.Equal(k, t.emptyHash) {
  618. return t.emptyHash, nil
  619. }
  620. t.dbg.incDbGet()
  621. v, err := t.db.Get(k)
  622. if err == nil {
  623. return v, nil
  624. }
  625. if t.dbBatch != nil {
  626. // TODO TMP
  627. v, ok := t.batchMemory.Get(k)
  628. if !ok {
  629. return nil, ErrKeyNotFound
  630. }
  631. // /TMP
  632. return v, nil
  633. }
  634. return nil, ErrKeyNotFound
  635. }
  636. // Warning: should be called with a Tree.dbBatch created, and with a
  637. // Tree.dbBatch.Write after the incNLeafs call.
  638. func (t *Tree) incNLeafs(nLeafs int) error {
  639. oldNLeafs, err := t.GetNLeafs()
  640. if err != nil {
  641. return err
  642. }
  643. newNLeafs := oldNLeafs + nLeafs
  644. return t.setNLeafs(newNLeafs)
  645. }
  646. // Warning: should be called with a Tree.dbBatch created, and with a
  647. // Tree.dbBatch.Write after the setNLeafs call.
  648. func (t *Tree) setNLeafs(nLeafs int) error {
  649. b := make([]byte, 8)
  650. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  651. if err := t.dbPut(dbKeyNLeafs, b); err != nil {
  652. return err
  653. }
  654. return nil
  655. }
  656. // GetNLeafs returns the number of Leafs of the Tree.
  657. func (t *Tree) GetNLeafs() (int, error) {
  658. b, err := t.dbGet(dbKeyNLeafs)
  659. if err != nil {
  660. return 0, err
  661. }
  662. nLeafs := binary.LittleEndian.Uint64(b)
  663. return int(nLeafs), nil
  664. }
  665. // Iterate iterates through the full Tree, executing the given function on each
  666. // node of the Tree.
  667. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error {
  668. // allow to define which root to use
  669. if rootKey == nil {
  670. rootKey = t.Root()
  671. }
  672. return t.iter(rootKey, f)
  673. }
  674. // IterateWithStop does the same than Iterate, but with int for the current
  675. // level, and a boolean parameter used by the passed function, is to indicate to
  676. // stop iterating on the branch when the method returns 'true'.
  677. func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error {
  678. // allow to define which root to use
  679. if rootKey == nil {
  680. rootKey = t.Root()
  681. }
  682. return t.iterWithStop(rootKey, 0, f)
  683. }
  684. func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
  685. v, err := t.dbGet(k)
  686. if err != nil {
  687. return err
  688. }
  689. currLevel++
  690. switch v[0] {
  691. case PrefixValueEmpty:
  692. f(currLevel, k, v)
  693. case PrefixValueLeaf:
  694. f(currLevel, k, v)
  695. case PrefixValueIntermediate:
  696. stop := f(currLevel, k, v)
  697. if stop {
  698. return nil
  699. }
  700. l, r := ReadIntermediateChilds(v)
  701. if err = t.iterWithStop(l, currLevel, f); err != nil {
  702. return err
  703. }
  704. if err = t.iterWithStop(r, currLevel, f); err != nil {
  705. return err
  706. }
  707. default:
  708. return ErrInvalidValuePrefix
  709. }
  710. return nil
  711. }
  712. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  713. f2 := func(currLvl int, k, v []byte) bool {
  714. f(k, v)
  715. return false
  716. }
  717. return t.iterWithStop(k, 0, f2)
  718. }
  719. // Dump exports all the Tree leafs in a byte array of length:
  720. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  721. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  722. // [ len(k) | len(v) | key | value ]
  723. // Where S is the size of the output of the hash function used for the Tree.
  724. func (t *Tree) Dump(rootKey []byte) ([]byte, error) {
  725. // allow to define which root to use
  726. if rootKey == nil {
  727. rootKey = t.Root()
  728. }
  729. // WARNING current encoding only supports key & values of 255 bytes each
  730. // (due using only 1 byte for the length headers).
  731. var b []byte
  732. err := t.Iterate(rootKey, func(k, v []byte) {
  733. if v[0] != PrefixValueLeaf {
  734. return
  735. }
  736. leafK, leafV := ReadLeafValue(v)
  737. kv := make([]byte, 2+len(leafK)+len(leafV))
  738. kv[0] = byte(len(leafK))
  739. kv[1] = byte(len(leafV))
  740. copy(kv[2:2+len(leafK)], leafK)
  741. copy(kv[2+len(leafK):], leafV)
  742. b = append(b, kv...)
  743. })
  744. return b, err
  745. }
  746. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  747. // method) in the Tree.
  748. func (t *Tree) ImportDump(b []byte) error {
  749. r := bytes.NewReader(b)
  750. var err error
  751. var keys, values [][]byte
  752. for {
  753. l := make([]byte, 2)
  754. _, err = io.ReadFull(r, l)
  755. if err == io.EOF {
  756. break
  757. } else if err != nil {
  758. return err
  759. }
  760. k := make([]byte, l[0])
  761. _, err = io.ReadFull(r, k)
  762. if err != nil {
  763. return err
  764. }
  765. v := make([]byte, l[1])
  766. _, err = io.ReadFull(r, v)
  767. if err != nil {
  768. return err
  769. }
  770. keys = append(keys, k)
  771. values = append(values, v)
  772. }
  773. if _, err = t.AddBatch(keys, values); err != nil {
  774. return err
  775. }
  776. return nil
  777. }
  778. // Graphviz iterates across the full tree to generate a string Graphviz
  779. // representation of the tree and writes it to w
  780. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  781. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  782. }
  783. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  784. // generate a string Graphviz representation of the first NLevels of the tree
  785. // and writes it to w
  786. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  787. fmt.Fprintf(w, `digraph hierarchy {
  788. node [fontname=Monospace,fontsize=10,shape=box]
  789. `)
  790. if rootKey == nil {
  791. rootKey = t.Root()
  792. }
  793. nEmpties := 0
  794. err := t.iterWithStop(rootKey, 0, func(currLvl int, k, v []byte) bool {
  795. if currLvl == untilLvl {
  796. return true // to stop the iter from going down
  797. }
  798. switch v[0] {
  799. case PrefixValueEmpty:
  800. case PrefixValueLeaf:
  801. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  802. // key & value from the leaf
  803. kB, vB := ReadLeafValue(v)
  804. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  805. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  806. hex.EncodeToString(vB[:nChars]))
  807. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  808. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  809. case PrefixValueIntermediate:
  810. l, r := ReadIntermediateChilds(v)
  811. lStr := hex.EncodeToString(l[:nChars])
  812. rStr := hex.EncodeToString(r[:nChars])
  813. eStr := ""
  814. if bytes.Equal(l, t.emptyHash) {
  815. lStr = fmt.Sprintf("empty%v", nEmpties)
  816. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  817. lStr)
  818. nEmpties++
  819. }
  820. if bytes.Equal(r, t.emptyHash) {
  821. rStr = fmt.Sprintf("empty%v", nEmpties)
  822. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  823. rStr)
  824. nEmpties++
  825. }
  826. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  827. lStr, rStr)
  828. fmt.Fprint(w, eStr)
  829. default:
  830. }
  831. return false
  832. })
  833. fmt.Fprintf(w, "}\n")
  834. return err
  835. }
  836. // PrintGraphviz prints the output of Tree.Graphviz
  837. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  838. if rootKey == nil {
  839. rootKey = t.Root()
  840. }
  841. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  842. }
  843. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  844. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  845. if rootKey == nil {
  846. rootKey = t.Root()
  847. }
  848. w := bytes.NewBufferString("")
  849. fmt.Fprintf(w,
  850. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  851. err := t.GraphvizFirstNLevels(w, rootKey, untilLvl)
  852. if err != nil {
  853. fmt.Println(w)
  854. return err
  855. }
  856. fmt.Fprintf(w,
  857. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  858. fmt.Println(w)
  859. return nil
  860. }
  861. // Purge WIP: unimplemented TODO
  862. func (t *Tree) Purge(keys [][]byte) error {
  863. return nil
  864. }
  865. // TODO circom proofs