You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1013 lines
27 KiB

3 years ago
3 years ago
3 years ago
3 years ago
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree, following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.

The hash function to use is configurable. For example, when working with
zkSNARKs the Poseidon hash function can be used; otherwise the Blake2b hash
function can be used, which is much faster to compute.
*/
  10. package arbo
  11. import (
  12. "bytes"
  13. "crypto/sha256"
  14. "encoding/binary"
  15. "encoding/hex"
  16. "fmt"
  17. "io"
  18. "math"
  19. "sync"
  20. "go.vocdoni.io/dvote/db"
  21. )
  22. const (
  23. // PrefixValueLen defines the bytes-prefix length used for the Value
  24. // bytes representation stored in the db
  25. PrefixValueLen = 2
  26. // PrefixValueEmpty is used for the first byte of a Value to indicate
  27. // that is an Empty value
  28. PrefixValueEmpty = 0
  29. // PrefixValueLeaf is used for the first byte of a Value to indicate
  30. // that is a Leaf value
  31. PrefixValueLeaf = 1
  32. // PrefixValueIntermediate is used for the first byte of a Value to
  33. // indicate that is a Intermediate value
  34. PrefixValueIntermediate = 2
  35. // nChars is used to crop the Graphviz nodes labels
  36. nChars = 4
  37. )
  38. var (
  39. dbKeyRoot = []byte("root")
  40. dbKeyNLeafs = []byte("nleafs")
  41. emptyValue = []byte{0}
  42. // ErrKeyNotFound is used when a key is not found in the db neither in
  43. // the current db Batch.
  44. ErrKeyNotFound = fmt.Errorf("key not found")
  45. // ErrKeyAlreadyExists is used when trying to add a key as leaf to the
  46. // tree that already exists.
  47. ErrKeyAlreadyExists = fmt.Errorf("key already exists")
  48. // ErrInvalidValuePrefix is used when going down into the tree, a value
  49. // is read from the db and has an unrecognized prefix.
  50. ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix")
  51. // ErrDBNoTx is used when trying to use Tree.dbPut but Tree.dbBatch==nil
  52. ErrDBNoTx = fmt.Errorf("dbPut error: no db Batch")
  53. // ErrMaxLevel indicates when going down into the tree, the max level is
  54. // reached
  55. ErrMaxLevel = fmt.Errorf("max level reached")
  56. // ErrMaxVirtualLevel indicates when going down into the tree, the max
  57. // virtual level is reached
  58. ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached")
  59. )
  60. // Tree defines the struct that implements the MerkleTree functionalities
  61. type Tree struct {
  62. sync.RWMutex
  63. dbBatch db.Batch
  64. batchMemory kvMap // TODO TMP
  65. db db.Database
  66. maxLevels int
  67. root []byte
  68. hashFunction HashFunction
  69. // TODO in the methods that use it, check if emptyHash param is len>0
  70. // (check if it has been initialized)
  71. emptyHash []byte
  72. dbg *dbgStats
  73. }
  74. // bmKeySize stands for batchMemoryKeySize
  75. const bmKeySize = sha256.Size
  76. // TMP
  77. type kvMap map[[bmKeySize]byte]kv
  78. // Get retreives the value respective to a key from the KvMap
  79. func (m kvMap) Get(k []byte) ([]byte, bool) {
  80. v, ok := m[sha256.Sum256(k)]
  81. return v.v, ok
  82. }
  83. // Put stores a key and a value in the KvMap
  84. func (m kvMap) Put(k, v []byte) {
  85. m[sha256.Sum256(k)] = kv{k: k, v: v}
  86. }
  87. // NewTree returns a new Tree, if there is a Tree still in the given database, it
  88. // will load it.
  89. func NewTree(database db.Database, maxLevels int, hash HashFunction) (*Tree, error) {
  90. t := Tree{db: database, maxLevels: maxLevels, hashFunction: hash}
  91. t.emptyHash = make([]byte, t.hashFunction.Len()) // empty
  92. root, err := t.dbGet(dbKeyRoot)
  93. if err == ErrKeyNotFound {
  94. // store new root 0
  95. t.dbBatch = t.db.NewBatch()
  96. t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP
  97. t.root = t.emptyHash
  98. if err = t.dbPut(dbKeyRoot, t.root); err != nil {
  99. return nil, err
  100. }
  101. if err = t.setNLeafs(0); err != nil {
  102. return nil, err
  103. }
  104. if err = t.dbBatch.Write(); err != nil {
  105. return nil, err
  106. }
  107. return &t, err
  108. } else if err != nil {
  109. return nil, err
  110. }
  111. t.root = root
  112. return &t, nil
  113. }
  114. // Root returns the root of the Tree
  115. func (t *Tree) Root() []byte {
  116. return t.root
  117. }
  118. // HashFunction returns Tree.hashFunction
  119. func (t *Tree) HashFunction() HashFunction {
  120. return t.hashFunction
  121. }
  122. // AddBatch adds a batch of key-values to the Tree. Returns an array containing
  123. // the indexes of the keys failed to add. Supports empty values as input
  124. // parameters, which is equivalent to 0 valued byte array.
  125. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  126. t.Lock()
  127. defer t.Unlock()
  128. vt, err := t.loadVT()
  129. if err != nil {
  130. return nil, err
  131. }
  132. // TODO check validity of keys & values for Tree.hashFunction (maybe do
  133. // not add the checks, as would need more time, and this could be
  134. // checked/ensured before calling this method)
  135. e := []byte{}
  136. // equal the number of keys & values
  137. if len(keys) > len(values) {
  138. // add missing values
  139. for i := len(values); i < len(keys); i++ {
  140. values = append(values, e)
  141. }
  142. } else if len(keys) < len(values) {
  143. // crop extra values
  144. values = values[:len(keys)]
  145. }
  146. invalids, err := vt.addBatch(keys, values)
  147. if err != nil {
  148. return nil, err
  149. }
  150. // once the VirtualTree is build, compute the hashes
  151. pairs, err := vt.computeHashes()
  152. if err != nil {
  153. // currently invalids in computeHashes are not counted,
  154. // but should not be needed, as if there is an error there is
  155. // nothing stored in the db and the error is returned
  156. return nil, err
  157. }
  158. t.root = vt.root.h
  159. // store pairs in db
  160. t.dbBatch = t.db.NewBatch()
  161. t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP
  162. for i := 0; i < len(pairs); i++ {
  163. if err := t.dbPut(pairs[i][0], pairs[i][1]); err != nil {
  164. return nil, err
  165. }
  166. }
  167. // store root to db
  168. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  169. return nil, err
  170. }
  171. // update nLeafs
  172. if err := t.incNLeafs(len(keys) - len(invalids)); err != nil {
  173. return nil, err
  174. }
  175. // commit db dbBatch
  176. if err := t.dbBatch.Write(); err != nil {
  177. return nil, err
  178. }
  179. return invalids, nil
  180. }
  181. // loadVT loads a new virtual tree (vt) from the current Tree, which contains
  182. // the same leafs.
  183. func (t *Tree) loadVT() (vt, error) {
  184. vt := newVT(t.maxLevels, t.hashFunction)
  185. vt.params.dbg = t.dbg
  186. err := t.Iterate(nil, func(k, v []byte) {
  187. if v[0] != PrefixValueLeaf {
  188. return
  189. }
  190. leafK, leafV := ReadLeafValue(v)
  191. if err := vt.add(0, leafK, leafV); err != nil {
  192. // TODO instead of panic, return this error
  193. panic(err)
  194. }
  195. })
  196. return vt, err
  197. }
  198. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  199. // is expected that are represented by a Little-Endian byte array (for circom
  200. // compatibility).
  201. func (t *Tree) Add(k, v []byte) error {
  202. t.Lock()
  203. defer t.Unlock()
  204. var err error
  205. t.dbBatch = t.db.NewBatch()
  206. t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP
  207. // TODO check validity of key & value for Tree.hashFunction (maybe do
  208. // not add the checks, as would need more time, and this could be
  209. // checked/ensured before calling this method)
  210. err = t.add(0, k, v) // add from level 0
  211. if err != nil {
  212. return err
  213. }
  214. // store root to db
  215. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  216. return err
  217. }
  218. // update nLeafs
  219. if err = t.incNLeafs(1); err != nil {
  220. return err
  221. }
  222. return t.dbBatch.Write()
  223. }
  224. func (t *Tree) add(fromLvl int, k, v []byte) error {
  225. keyPath := make([]byte, t.hashFunction.Len())
  226. copy(keyPath[:], k)
  227. path := getPath(t.maxLevels, keyPath)
  228. // go down to the leaf
  229. var siblings [][]byte
  230. _, _, siblings, err := t.down(k, t.root, siblings, path, fromLvl, false)
  231. if err != nil {
  232. return err
  233. }
  234. leafKey, leafValue, err := t.newLeafValue(k, v)
  235. if err != nil {
  236. return err
  237. }
  238. if err := t.dbPut(leafKey, leafValue); err != nil {
  239. return err
  240. }
  241. // go up to the root
  242. if len(siblings) == 0 {
  243. t.root = leafKey
  244. return nil
  245. }
  246. root, err := t.up(leafKey, siblings, path, len(siblings)-1, fromLvl)
  247. if err != nil {
  248. return err
  249. }
  250. t.root = root
  251. return nil
  252. }
  253. // down goes down to the leaf recursively
  254. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte,
  255. path []bool, currLvl int, getLeaf bool) (
  256. []byte, []byte, [][]byte, error) {
  257. if currLvl > t.maxLevels-1 {
  258. return nil, nil, nil, ErrMaxLevel
  259. }
  260. var err error
  261. var currValue []byte
  262. if bytes.Equal(currKey, t.emptyHash) {
  263. // empty value
  264. return currKey, emptyValue, siblings, nil
  265. }
  266. currValue, err = t.dbGet(currKey)
  267. if err != nil {
  268. return nil, nil, nil, err
  269. }
  270. switch currValue[0] {
  271. case PrefixValueEmpty: // empty
  272. fmt.Printf("newKey: %s, currKey: %s, currLvl: %d, currValue: %s\n",
  273. hex.EncodeToString(newKey), hex.EncodeToString(currKey),
  274. currLvl, hex.EncodeToString(currValue))
  275. panic("This point should not be reached, as the 'if' above" +
  276. " should avoid reaching this point. This panic is temporary" +
  277. " for reporting purposes, will be deleted in future versions." +
  278. " Please paste this log (including the previous lines) in a" +
  279. " new issue: https://github.com/vocdoni/arbo/issues/new") // TMP
  280. case PrefixValueLeaf: // leaf
  281. if !bytes.Equal(currValue, emptyValue) {
  282. if getLeaf {
  283. return currKey, currValue, siblings, nil
  284. }
  285. oldLeafKey, _ := ReadLeafValue(currValue)
  286. if bytes.Equal(newKey, oldLeafKey) {
  287. return nil, nil, nil, ErrKeyAlreadyExists
  288. }
  289. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  290. copy(oldLeafKeyFull[:], oldLeafKey)
  291. // if currKey is already used, go down until paths diverge
  292. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  293. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, currLvl)
  294. if err != nil {
  295. return nil, nil, nil, err
  296. }
  297. }
  298. return currKey, currValue, siblings, nil
  299. case PrefixValueIntermediate: // intermediate
  300. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  301. return nil, nil, nil,
  302. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  303. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  304. }
  305. // collect siblings while going down
  306. if path[currLvl] {
  307. // right
  308. lChild, rChild := ReadIntermediateChilds(currValue)
  309. siblings = append(siblings, lChild)
  310. return t.down(newKey, rChild, siblings, path, currLvl+1, getLeaf)
  311. }
  312. // left
  313. lChild, rChild := ReadIntermediateChilds(currValue)
  314. siblings = append(siblings, rChild)
  315. return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf)
  316. default:
  317. return nil, nil, nil, ErrInvalidValuePrefix
  318. }
  319. }
  320. // downVirtually is used when in a leaf already exists, and a new leaf which
  321. // shares the path until the existing leaf is being added
  322. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  323. newPath []bool, currLvl int) ([][]byte, error) {
  324. var err error
  325. if currLvl > t.maxLevels-1 {
  326. return nil, ErrMaxVirtualLevel
  327. }
  328. if oldPath[currLvl] == newPath[currLvl] {
  329. siblings = append(siblings, t.emptyHash)
  330. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, currLvl+1)
  331. if err != nil {
  332. return nil, err
  333. }
  334. return siblings, nil
  335. }
  336. // reached the divergence
  337. siblings = append(siblings, oldKey)
  338. return siblings, nil
  339. }
  340. // up goes up recursively updating the intermediate nodes
  341. func (t *Tree) up(key []byte, siblings [][]byte, path []bool, currLvl, toLvl int) ([]byte, error) {
  342. var k, v []byte
  343. var err error
  344. if path[currLvl+toLvl] {
  345. k, v, err = t.newIntermediate(siblings[currLvl], key)
  346. if err != nil {
  347. return nil, err
  348. }
  349. } else {
  350. k, v, err = t.newIntermediate(key, siblings[currLvl])
  351. if err != nil {
  352. return nil, err
  353. }
  354. }
  355. // store k-v to db
  356. if err = t.dbPut(k, v); err != nil {
  357. return nil, err
  358. }
  359. if currLvl == 0 {
  360. // reached the root
  361. return k, nil
  362. }
  363. return t.up(k, siblings, path, currLvl-1, toLvl)
  364. }
  365. func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
  366. t.dbg.incHash()
  367. return newLeafValue(t.hashFunction, k, v)
  368. }
  369. // newLeafValue takes a key & value from a leaf, and computes the leaf hash,
  370. // which is used as the leaf key. And the value is the concatenation of the
  371. // inputed key & value. The output of this function is used as key-value to
  372. // store the leaf in the DB.
  373. // [ 1 byte | 1 byte | N bytes | M bytes ]
  374. // [ type of node | length of key | key | value ]
  375. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  376. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  377. if err != nil {
  378. return nil, nil, err
  379. }
  380. var leafValue []byte
  381. leafValue = append(leafValue, byte(1))
  382. leafValue = append(leafValue, byte(len(k)))
  383. leafValue = append(leafValue, k...)
  384. leafValue = append(leafValue, v...)
  385. return leafKey, leafValue, nil
  386. }
  387. // ReadLeafValue reads from a byte array the leaf key & value
  388. func ReadLeafValue(b []byte) ([]byte, []byte) {
  389. if len(b) < PrefixValueLen {
  390. return []byte{}, []byte{}
  391. }
  392. kLen := b[1]
  393. if len(b) < PrefixValueLen+int(kLen) {
  394. return []byte{}, []byte{}
  395. }
  396. k := b[PrefixValueLen : PrefixValueLen+kLen]
  397. v := b[PrefixValueLen+kLen:]
  398. return k, v
  399. }
  400. func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
  401. t.dbg.incHash()
  402. return newIntermediate(t.hashFunction, l, r)
  403. }
  404. // newIntermediate takes the left & right keys of a intermediate node, and
  405. // computes its hash. Returns the hash of the node, which is the node key, and a
  406. // byte array that contains the value (which contains the left & right child
  407. // keys) to store in the DB.
  408. // [ 1 byte | 1 byte | N bytes | N bytes ]
  409. // [ type of node | length of key | left key | right key ]
  410. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  411. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  412. b[0] = 2
  413. b[1] = byte(len(l))
  414. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  415. copy(b[PrefixValueLen+hashFunc.Len():], r)
  416. key, err := hashFunc.Hash(l, r)
  417. if err != nil {
  418. return nil, nil, err
  419. }
  420. return key, b, nil
  421. }
  422. // ReadIntermediateChilds reads from a byte array the two childs keys
  423. func ReadIntermediateChilds(b []byte) ([]byte, []byte) {
  424. if len(b) < PrefixValueLen {
  425. return []byte{}, []byte{}
  426. }
  427. lLen := b[1]
  428. if len(b) < PrefixValueLen+int(lLen) {
  429. return []byte{}, []byte{}
  430. }
  431. l := b[PrefixValueLen : PrefixValueLen+lLen]
  432. r := b[PrefixValueLen+lLen:]
  433. return l, r
  434. }
  435. func getPath(numLevels int, k []byte) []bool {
  436. path := make([]bool, numLevels)
  437. for n := 0; n < numLevels; n++ {
  438. path[n] = k[n/8]&(1<<(n%8)) != 0
  439. }
  440. return path
  441. }
  442. // Update updates the value for a given existing key. If the given key does not
  443. // exist, returns an error.
  444. func (t *Tree) Update(k, v []byte) error {
  445. t.Lock()
  446. defer t.Unlock()
  447. var err error
  448. t.dbBatch = t.db.NewBatch()
  449. t.batchMemory = make(map[[bmKeySize]byte]kv) // TODO TMP
  450. keyPath := make([]byte, t.hashFunction.Len())
  451. copy(keyPath[:], k)
  452. path := getPath(t.maxLevels, keyPath)
  453. var siblings [][]byte
  454. _, valueAtBottom, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  455. if err != nil {
  456. return err
  457. }
  458. oldKey, _ := ReadLeafValue(valueAtBottom)
  459. if !bytes.Equal(oldKey, k) {
  460. return fmt.Errorf("key %s does not exist", hex.EncodeToString(k))
  461. }
  462. leafKey, leafValue, err := t.newLeafValue(k, v)
  463. if err != nil {
  464. return err
  465. }
  466. if err := t.dbPut(leafKey, leafValue); err != nil {
  467. return err
  468. }
  469. // go up to the root
  470. if len(siblings) == 0 {
  471. t.root = leafKey
  472. return t.dbBatch.Write()
  473. }
  474. root, err := t.up(leafKey, siblings, path, len(siblings)-1, 0)
  475. if err != nil {
  476. return err
  477. }
  478. t.root = root
  479. // store root to db
  480. if err := t.dbPut(dbKeyRoot, t.root); err != nil {
  481. return err
  482. }
  483. return t.dbBatch.Write()
  484. }
  485. // GenProof generates a MerkleTree proof for the given key. The leaf value is
  486. // returned, together with the packed siblings of the proof, and a boolean
  487. // parameter that indicates if the proof is of existence (true) or not (false).
  488. func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
  489. keyPath := make([]byte, t.hashFunction.Len())
  490. copy(keyPath[:], k)
  491. path := getPath(t.maxLevels, keyPath)
  492. // go down to the leaf
  493. var siblings [][]byte
  494. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  495. if err != nil {
  496. return nil, nil, nil, false, err
  497. }
  498. s := PackSiblings(t.hashFunction, siblings)
  499. leafK, leafV := ReadLeafValue(value)
  500. if !bytes.Equal(k, leafK) {
  501. // key not in tree, proof of non-existence
  502. return leafK, leafV, s, false, err
  503. }
  504. return leafK, leafV, s, true, nil
  505. }
  506. // PackSiblings packs the siblings into a byte array.
  507. // [ 1 byte | L bytes | S * N bytes ]
  508. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  509. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  510. // array. And S is the size of the output of the hash function used for the
  511. // Tree.
  512. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  513. var b []byte
  514. var bitmap []bool
  515. emptySibling := make([]byte, hashFunc.Len())
  516. for i := 0; i < len(siblings); i++ {
  517. if bytes.Equal(siblings[i], emptySibling) {
  518. bitmap = append(bitmap, false)
  519. } else {
  520. bitmap = append(bitmap, true)
  521. b = append(b, siblings[i]...)
  522. }
  523. }
  524. bitmapBytes := bitmapToBytes(bitmap)
  525. l := len(bitmapBytes)
  526. res := make([]byte, l+1+len(b))
  527. res[0] = byte(l) // set the bitmapBytes length
  528. copy(res[1:1+l], bitmapBytes)
  529. copy(res[1+l:], b)
  530. return res
  531. }
  532. // UnpackSiblings unpacks the siblings from a byte array.
  533. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  534. l := b[0]
  535. bitmapBytes := b[1 : 1+l]
  536. bitmap := bytesToBitmap(bitmapBytes)
  537. siblingsBytes := b[1+l:]
  538. iSibl := 0
  539. emptySibl := make([]byte, hashFunc.Len())
  540. var siblings [][]byte
  541. for i := 0; i < len(bitmap); i++ {
  542. if iSibl >= len(siblingsBytes) {
  543. break
  544. }
  545. if bitmap[i] {
  546. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  547. iSibl += hashFunc.Len()
  548. } else {
  549. siblings = append(siblings, emptySibl)
  550. }
  551. }
  552. return siblings, nil
  553. }
  554. func bitmapToBytes(bitmap []bool) []byte {
  555. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  556. b := make([]byte, bitmapBytesLen)
  557. for i := 0; i < len(bitmap); i++ {
  558. if bitmap[i] {
  559. b[i/8] |= 1 << (i % 8)
  560. }
  561. }
  562. return b
  563. }
  564. func bytesToBitmap(b []byte) []bool {
  565. var bitmap []bool
  566. for i := 0; i < len(b); i++ {
  567. for j := 0; j < 8; j++ {
  568. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  569. }
  570. }
  571. return bitmap
  572. }
  573. // Get returns the value for a given key
  574. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  575. keyPath := make([]byte, t.hashFunction.Len())
  576. copy(keyPath[:], k)
  577. path := getPath(t.maxLevels, keyPath)
  578. // go down to the leaf
  579. var siblings [][]byte
  580. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  581. if err != nil {
  582. return nil, nil, err
  583. }
  584. leafK, leafV := ReadLeafValue(value)
  585. if !bytes.Equal(k, leafK) {
  586. return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s",
  587. BytesToBigInt(k), BytesToBigInt(leafK))
  588. }
  589. return leafK, leafV, nil
  590. }
  591. // CheckProof verifies the given proof. The proof verification depends on the
  592. // HashFunction passed as parameter.
  593. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  594. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  595. if err != nil {
  596. return false, err
  597. }
  598. keyPath := make([]byte, hashFunc.Len())
  599. copy(keyPath[:], k)
  600. key, _, err := newLeafValue(hashFunc, k, v)
  601. if err != nil {
  602. return false, err
  603. }
  604. path := getPath(len(siblings), keyPath)
  605. for i := len(siblings) - 1; i >= 0; i-- {
  606. if path[i] {
  607. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  608. if err != nil {
  609. return false, err
  610. }
  611. } else {
  612. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  613. if err != nil {
  614. return false, err
  615. }
  616. }
  617. }
  618. if bytes.Equal(key[:], root) {
  619. return true, nil
  620. }
  621. return false, nil
  622. }
  623. func (t *Tree) dbPut(k, v []byte) error {
  624. if t.dbBatch == nil {
  625. return ErrDBNoTx
  626. }
  627. t.dbg.incDbPut()
  628. t.batchMemory.Put(k, v) // TODO TMP
  629. return t.dbBatch.Put(k, v)
  630. }
  631. func (t *Tree) dbGet(k []byte) ([]byte, error) {
  632. // if key is empty, return empty as value
  633. if bytes.Equal(k, t.emptyHash) {
  634. return t.emptyHash, nil
  635. }
  636. t.dbg.incDbGet()
  637. v, err := t.db.Get(k)
  638. if err == nil {
  639. return v, nil
  640. }
  641. if t.dbBatch != nil {
  642. // TODO TMP
  643. v, ok := t.batchMemory.Get(k)
  644. if !ok {
  645. return nil, ErrKeyNotFound
  646. }
  647. // /TMP
  648. return v, nil
  649. }
  650. return nil, ErrKeyNotFound
  651. }
  652. // Warning: should be called with a Tree.dbBatch created, and with a
  653. // Tree.dbBatch.Write after the incNLeafs call.
  654. func (t *Tree) incNLeafs(nLeafs int) error {
  655. oldNLeafs, err := t.GetNLeafs()
  656. if err != nil {
  657. return err
  658. }
  659. newNLeafs := oldNLeafs + nLeafs
  660. return t.setNLeafs(newNLeafs)
  661. }
  662. // Warning: should be called with a Tree.dbBatch created, and with a
  663. // Tree.dbBatch.Write after the setNLeafs call.
  664. func (t *Tree) setNLeafs(nLeafs int) error {
  665. b := make([]byte, 8)
  666. binary.LittleEndian.PutUint64(b, uint64(nLeafs))
  667. if err := t.dbPut(dbKeyNLeafs, b); err != nil {
  668. return err
  669. }
  670. return nil
  671. }
  672. // GetNLeafs returns the number of Leafs of the Tree.
  673. func (t *Tree) GetNLeafs() (int, error) {
  674. b, err := t.dbGet(dbKeyNLeafs)
  675. if err != nil {
  676. return 0, err
  677. }
  678. nLeafs := binary.LittleEndian.Uint64(b)
  679. return int(nLeafs), nil
  680. }
  681. // Snapshot returns a copy of the Tree from the given root
  682. func (t *Tree) Snapshot(rootKey []byte) (*Tree, error) {
  683. // TODO currently this method only changes the 'root pointer', but the
  684. // db continues being the same. In a future iteration, once the
  685. // db.Database interface allows to do database checkpoints, this method
  686. // could be updated to do a full checkpoint of the database for the
  687. // snapshot, to return a completly new independent tree containing the
  688. // snapshot.
  689. t.RLock()
  690. defer t.RUnlock()
  691. // allow to define which root to use
  692. if rootKey == nil {
  693. rootKey = t.Root()
  694. }
  695. return &Tree{
  696. db: t.db,
  697. maxLevels: t.maxLevels,
  698. root: rootKey,
  699. hashFunction: t.hashFunction,
  700. dbg: t.dbg,
  701. }, nil
  702. }
  703. // Iterate iterates through the full Tree, executing the given function on each
  704. // node of the Tree.
  705. func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error {
  706. // allow to define which root to use
  707. if rootKey == nil {
  708. rootKey = t.Root()
  709. }
  710. return t.iter(rootKey, f)
  711. }
  712. // IterateWithStop does the same than Iterate, but with int for the current
  713. // level, and a boolean parameter used by the passed function, is to indicate to
  714. // stop iterating on the branch when the method returns 'true'.
  715. func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error {
  716. // allow to define which root to use
  717. if rootKey == nil {
  718. rootKey = t.Root()
  719. }
  720. return t.iterWithStop(rootKey, 0, f)
  721. }
  722. func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error {
  723. v, err := t.dbGet(k)
  724. if err != nil {
  725. return err
  726. }
  727. currLevel++
  728. switch v[0] {
  729. case PrefixValueEmpty:
  730. f(currLevel, k, v)
  731. case PrefixValueLeaf:
  732. f(currLevel, k, v)
  733. case PrefixValueIntermediate:
  734. stop := f(currLevel, k, v)
  735. if stop {
  736. return nil
  737. }
  738. l, r := ReadIntermediateChilds(v)
  739. if err = t.iterWithStop(l, currLevel, f); err != nil {
  740. return err
  741. }
  742. if err = t.iterWithStop(r, currLevel, f); err != nil {
  743. return err
  744. }
  745. default:
  746. return ErrInvalidValuePrefix
  747. }
  748. return nil
  749. }
  750. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  751. f2 := func(currLvl int, k, v []byte) bool {
  752. f(k, v)
  753. return false
  754. }
  755. return t.iterWithStop(k, 0, f2)
  756. }
  757. // Dump exports all the Tree leafs in a byte array of length:
  758. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  759. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  760. // [ len(k) | len(v) | key | value ]
  761. // Where S is the size of the output of the hash function used for the Tree.
  762. func (t *Tree) Dump(rootKey []byte) ([]byte, error) {
  763. // allow to define which root to use
  764. if rootKey == nil {
  765. rootKey = t.Root()
  766. }
  767. // WARNING current encoding only supports key & values of 255 bytes each
  768. // (due using only 1 byte for the length headers).
  769. var b []byte
  770. err := t.Iterate(rootKey, func(k, v []byte) {
  771. if v[0] != PrefixValueLeaf {
  772. return
  773. }
  774. leafK, leafV := ReadLeafValue(v)
  775. kv := make([]byte, 2+len(leafK)+len(leafV))
  776. kv[0] = byte(len(leafK))
  777. kv[1] = byte(len(leafV))
  778. copy(kv[2:2+len(leafK)], leafK)
  779. copy(kv[2+len(leafK):], leafV)
  780. b = append(b, kv...)
  781. })
  782. return b, err
  783. }
  784. // ImportDump imports the leafs (that have been exported with the Dump method)
  785. // in the Tree.
  786. func (t *Tree) ImportDump(b []byte) error {
  787. r := bytes.NewReader(b)
  788. var err error
  789. var keys, values [][]byte
  790. for {
  791. l := make([]byte, 2)
  792. _, err = io.ReadFull(r, l)
  793. if err == io.EOF {
  794. break
  795. } else if err != nil {
  796. return err
  797. }
  798. k := make([]byte, l[0])
  799. _, err = io.ReadFull(r, k)
  800. if err != nil {
  801. return err
  802. }
  803. v := make([]byte, l[1])
  804. _, err = io.ReadFull(r, v)
  805. if err != nil {
  806. return err
  807. }
  808. keys = append(keys, k)
  809. values = append(values, v)
  810. }
  811. if _, err = t.AddBatch(keys, values); err != nil {
  812. return err
  813. }
  814. return nil
  815. }
  816. // Graphviz iterates across the full tree to generate a string Graphviz
  817. // representation of the tree and writes it to w
  818. func (t *Tree) Graphviz(w io.Writer, rootKey []byte) error {
  819. return t.GraphvizFirstNLevels(w, rootKey, t.maxLevels)
  820. }
  821. // GraphvizFirstNLevels iterates across the first NLevels of the tree to
  822. // generate a string Graphviz representation of the first NLevels of the tree
  823. // and writes it to w
  824. func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) error {
  825. fmt.Fprintf(w, `digraph hierarchy {
  826. node [fontname=Monospace,fontsize=10,shape=box]
  827. `)
  828. if rootKey == nil {
  829. rootKey = t.Root()
  830. }
  831. nEmpties := 0
  832. err := t.iterWithStop(rootKey, 0, func(currLvl int, k, v []byte) bool {
  833. if currLvl == untilLvl {
  834. return true // to stop the iter from going down
  835. }
  836. switch v[0] {
  837. case PrefixValueEmpty:
  838. case PrefixValueLeaf:
  839. fmt.Fprintf(w, "\"%v\" [style=filled];\n", hex.EncodeToString(k[:nChars]))
  840. // key & value from the leaf
  841. kB, vB := ReadLeafValue(v)
  842. fmt.Fprintf(w, "\"%v\" -> {\"k:%v\\nv:%v\"}\n",
  843. hex.EncodeToString(k[:nChars]), hex.EncodeToString(kB[:nChars]),
  844. hex.EncodeToString(vB[:nChars]))
  845. fmt.Fprintf(w, "\"k:%v\\nv:%v\" [style=dashed]\n",
  846. hex.EncodeToString(kB[:nChars]), hex.EncodeToString(vB[:nChars]))
  847. case PrefixValueIntermediate:
  848. l, r := ReadIntermediateChilds(v)
  849. lStr := hex.EncodeToString(l[:nChars])
  850. rStr := hex.EncodeToString(r[:nChars])
  851. eStr := ""
  852. if bytes.Equal(l, t.emptyHash) {
  853. lStr = fmt.Sprintf("empty%v", nEmpties)
  854. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  855. lStr)
  856. nEmpties++
  857. }
  858. if bytes.Equal(r, t.emptyHash) {
  859. rStr = fmt.Sprintf("empty%v", nEmpties)
  860. eStr += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n",
  861. rStr)
  862. nEmpties++
  863. }
  864. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", hex.EncodeToString(k[:nChars]),
  865. lStr, rStr)
  866. fmt.Fprint(w, eStr)
  867. default:
  868. }
  869. return false
  870. })
  871. fmt.Fprintf(w, "}\n")
  872. return err
  873. }
  874. // PrintGraphviz prints the output of Tree.Graphviz
  875. func (t *Tree) PrintGraphviz(rootKey []byte) error {
  876. if rootKey == nil {
  877. rootKey = t.Root()
  878. }
  879. return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels)
  880. }
  881. // PrintGraphvizFirstNLevels prints the output of Tree.GraphvizFirstNLevels
  882. func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error {
  883. if rootKey == nil {
  884. rootKey = t.Root()
  885. }
  886. w := bytes.NewBufferString("")
  887. fmt.Fprintf(w,
  888. "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n")
  889. err := t.GraphvizFirstNLevels(w, rootKey, untilLvl)
  890. if err != nil {
  891. fmt.Println(w)
  892. return err
  893. }
  894. fmt.Fprintf(w,
  895. "End of Graphviz of the Tree with Root "+hex.EncodeToString(rootKey)+"\n--------\n")
  896. fmt.Println(w)
  897. return nil
  898. }
  899. // TODO circom proofs
  900. // TODO data structure for proofs (including root, key, value, siblings,
  901. // hashFunction) + method to verify that data structure