You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

630 lines
16 KiB

3 years ago
3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
  7. It also allows choosing which hash function to use. For example, when working
  8. with zkSnarks the Poseidon hash function can be used; when not, the Blake3
  9. hash function can be used instead, which improves the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "fmt"
  15. "io"
  16. "math"
  17. "sync/atomic"
  18. "time"
  19. "github.com/iden3/go-merkletree/db"
  20. )
  21. const (
  22. // PrefixValueLen defines the bytes-prefix length used for the Value
  23. // bytes representation stored in the db
  24. PrefixValueLen = 2
  25. // PrefixValueEmpty is used for the first byte of a Value to indicate
  26. // that is an Empty value
  27. PrefixValueEmpty = 0
  28. // PrefixValueLeaf is used for the first byte of a Value to indicate
  29. // that is a Leaf value
  30. PrefixValueLeaf = 1
  31. // PrefixValueIntermediate is used for the first byte of a Value to
  32. // indicate that is a Intermediate value
  33. PrefixValueIntermediate = 2
  34. )
  35. var (
  36. dbKeyRoot = []byte("root")
  37. emptyValue = []byte{0}
  38. )
  39. // Tree defines the struct that implements the MerkleTree functionalities
  40. type Tree struct {
  41. db db.Storage
  42. lastAccess int64 // in unix time
  43. maxLevels int
  44. root []byte
  45. hashFunction HashFunction
  46. }
  47. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  48. // will load it.
  49. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  50. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  51. t.updateAccessTime()
  52. root, err := t.dbGet(nil, dbKeyRoot)
  53. if err == db.ErrNotFound {
  54. // store new root 0
  55. tx, err := t.db.NewTx()
  56. if err != nil {
  57. return nil, err
  58. }
  59. t.root = make([]byte, t.hashFunction.Len()) // empty
  60. if err = tx.Put(dbKeyRoot, t.root); err != nil {
  61. return nil, err
  62. }
  63. if err = tx.Commit(); err != nil {
  64. return nil, err
  65. }
  66. return &t, err
  67. } else if err != nil {
  68. return nil, err
  69. }
  70. t.root = root
  71. return &t, nil
  72. }
  73. func (t *Tree) updateAccessTime() {
  74. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  75. }
  76. // LastAccess returns the last access timestamp in Unixtime
  77. func (t *Tree) LastAccess() int64 {
  78. return atomic.LoadInt64(&t.lastAccess)
  79. }
  80. // Root returns the root of the Tree
  81. func (t *Tree) Root() []byte {
  82. return t.root
  83. }
  84. // AddBatch adds a batch of key-values to the Tree. This method will be
  85. // optimized to do some internal parallelization. Returns an array containing
  86. // the indexes of the keys failed to add.
  87. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  88. t.updateAccessTime()
  89. if len(keys) != len(values) {
  90. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  91. len(keys), len(values))
  92. }
  93. tx, err := t.db.NewTx()
  94. if err != nil {
  95. return nil, err
  96. }
  97. var indexes []int
  98. for i := 0; i < len(keys); i++ {
  99. tx, err = t.add(tx, keys[i], values[i])
  100. if err != nil {
  101. indexes = append(indexes, i)
  102. }
  103. }
  104. // store root to db
  105. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  106. return indexes, err
  107. }
  108. if err := tx.Commit(); err != nil {
  109. return nil, err
  110. }
  111. return indexes, nil
  112. }
  113. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  114. // is expected that are represented by a Little-Endian byte array (for circom
  115. // compatibility).
  116. func (t *Tree) Add(k, v []byte) error {
  117. t.updateAccessTime()
  118. tx, err := t.db.NewTx()
  119. if err != nil {
  120. return err
  121. }
  122. tx, err = t.add(tx, k, v)
  123. if err != nil {
  124. return err
  125. }
  126. // store root to db
  127. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  128. return err
  129. }
  130. return tx.Commit()
  131. }
  132. func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
  133. // TODO check validity of key & value (for the Tree.HashFunction type)
  134. keyPath := make([]byte, t.hashFunction.Len())
  135. copy(keyPath[:], k)
  136. path := getPath(t.maxLevels, keyPath)
  137. // go down to the leaf
  138. var siblings [][]byte
  139. _, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false)
  140. if err != nil {
  141. return tx, err
  142. }
  143. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  144. if err != nil {
  145. return tx, err
  146. }
  147. if err := tx.Put(leafKey, leafValue); err != nil {
  148. return tx, err
  149. }
  150. // go up to the root
  151. if len(siblings) == 0 {
  152. t.root = leafKey
  153. return tx, nil
  154. }
  155. root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
  156. if err != nil {
  157. return tx, err
  158. }
  159. t.root = root
  160. return tx, nil
  161. }
  162. // down goes down to the leaf recursively
  163. func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
  164. path []bool, l int, getLeaf bool) (
  165. []byte, []byte, [][]byte, error) {
  166. if l > t.maxLevels-1 {
  167. return nil, nil, nil, fmt.Errorf("max level")
  168. }
  169. var err error
  170. var currValue []byte
  171. emptyKey := make([]byte, t.hashFunction.Len())
  172. if bytes.Equal(currKey, emptyKey) {
  173. // empty value
  174. return currKey, emptyValue, siblings, nil
  175. }
  176. currValue, err = t.dbGet(tx, currKey)
  177. if err != nil {
  178. return nil, nil, nil, err
  179. }
  180. switch currValue[0] {
  181. case PrefixValueEmpty: // empty
  182. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  183. // reaching this point
  184. // return currKey, empty, siblings, nil
  185. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  186. case PrefixValueLeaf: // leaf
  187. if bytes.Equal(newKey, currKey) {
  188. return nil, nil, nil, fmt.Errorf("key already exists")
  189. }
  190. if !bytes.Equal(currValue, emptyValue) {
  191. if getLeaf {
  192. return currKey, currValue, siblings, nil
  193. }
  194. oldLeafKey, _ := readLeafValue(currValue)
  195. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  196. copy(oldLeafKeyFull[:], oldLeafKey)
  197. // if currKey is already used, go down until paths diverge
  198. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  199. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
  200. if err != nil {
  201. return nil, nil, nil, err
  202. }
  203. }
  204. return currKey, currValue, siblings, nil
  205. case PrefixValueIntermediate: // intermediate
  206. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  207. return nil, nil, nil,
  208. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  209. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  210. }
  211. // collect siblings while going down
  212. if path[l] {
  213. // right
  214. lChild, rChild := readIntermediateChilds(currValue)
  215. siblings = append(siblings, lChild)
  216. return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf)
  217. }
  218. // left
  219. lChild, rChild := readIntermediateChilds(currValue)
  220. siblings = append(siblings, rChild)
  221. return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf)
  222. default:
  223. return nil, nil, nil, fmt.Errorf("invalid value")
  224. }
  225. }
  226. // downVirtually is used when in a leaf already exists, and a new leaf which
  227. // shares the path until the existing leaf is being added
  228. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  229. newPath []bool, l int) ([][]byte, error) {
  230. var err error
  231. if l > t.maxLevels-1 {
  232. return nil, fmt.Errorf("max virtual level %d", l)
  233. }
  234. if oldPath[l] == newPath[l] {
  235. emptyKey := make([]byte, t.hashFunction.Len())
  236. siblings = append(siblings, emptyKey)
  237. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
  238. if err != nil {
  239. return nil, err
  240. }
  241. return siblings, nil
  242. }
  243. // reached the divergence
  244. siblings = append(siblings, oldKey)
  245. return siblings, nil
  246. }
  247. // up goes up recursively updating the intermediate nodes
  248. func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
  249. var k, v []byte
  250. var err error
  251. if path[l] {
  252. k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
  253. if err != nil {
  254. return nil, err
  255. }
  256. } else {
  257. k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
  258. if err != nil {
  259. return nil, err
  260. }
  261. }
  262. // store k-v to db
  263. if err = tx.Put(k, v); err != nil {
  264. return nil, err
  265. }
  266. if l == 0 {
  267. // reached the root
  268. return k, nil
  269. }
  270. return t.up(tx, k, siblings, path, l-1)
  271. }
  272. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  273. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  274. if err != nil {
  275. return nil, nil, err
  276. }
  277. var leafValue []byte
  278. leafValue = append(leafValue, byte(1))
  279. leafValue = append(leafValue, byte(len(k)))
  280. leafValue = append(leafValue, k...)
  281. leafValue = append(leafValue, v...)
  282. return leafKey, leafValue, nil
  283. }
  284. func readLeafValue(b []byte) ([]byte, []byte) {
  285. if len(b) < PrefixValueLen {
  286. return []byte{}, []byte{}
  287. }
  288. kLen := b[1]
  289. if len(b) < PrefixValueLen+int(kLen) {
  290. return []byte{}, []byte{}
  291. }
  292. k := b[PrefixValueLen : PrefixValueLen+kLen]
  293. v := b[PrefixValueLen+kLen:]
  294. return k, v
  295. }
  296. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  297. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  298. b[0] = 2
  299. b[1] = byte(len(l))
  300. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  301. copy(b[PrefixValueLen+hashFunc.Len():], r)
  302. key, err := hashFunc.Hash(l, r)
  303. if err != nil {
  304. return nil, nil, err
  305. }
  306. return key, b, nil
  307. }
  308. func readIntermediateChilds(b []byte) ([]byte, []byte) {
  309. if len(b) < PrefixValueLen {
  310. return []byte{}, []byte{}
  311. }
  312. lLen := b[1]
  313. if len(b) < PrefixValueLen+int(lLen) {
  314. return []byte{}, []byte{}
  315. }
  316. l := b[PrefixValueLen : PrefixValueLen+lLen]
  317. r := b[PrefixValueLen+lLen:]
  318. return l, r
  319. }
  320. func getPath(numLevels int, k []byte) []bool {
  321. path := make([]bool, numLevels)
  322. for n := 0; n < numLevels; n++ {
  323. path[n] = k[n/8]&(1<<(n%8)) != 0
  324. }
  325. return path
  326. }
  327. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  328. // the Tree, the proof will be of existence, if the key does not exist in the
  329. // tree, the proof will be of non-existence.
  330. func (t *Tree) GenProof(k []byte) ([]byte, error) {
  331. t.updateAccessTime()
  332. keyPath := make([]byte, t.hashFunction.Len())
  333. copy(keyPath[:], k)
  334. path := getPath(t.maxLevels, keyPath)
  335. // go down to the leaf
  336. var siblings [][]byte
  337. _, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true)
  338. if err != nil {
  339. return nil, err
  340. }
  341. leafK, leafV := readLeafValue(value)
  342. if !bytes.Equal(k, leafK) {
  343. fmt.Println("key not in Tree")
  344. fmt.Println(leafK)
  345. fmt.Println(leafV)
  346. // TODO proof of non-existence
  347. panic(fmt.Errorf("unimplemented"))
  348. }
  349. s := PackSiblings(t.hashFunction, siblings)
  350. return s, nil
  351. }
  352. // PackSiblings packs the siblings into a byte array.
  353. // [ 1 byte | L bytes | S * N bytes ]
  354. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  355. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  356. // array. And S is the size of the output of the hash function used for the
  357. // Tree.
  358. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  359. var b []byte
  360. var bitmap []bool
  361. emptySibling := make([]byte, hashFunc.Len())
  362. for i := 0; i < len(siblings); i++ {
  363. if bytes.Equal(siblings[i], emptySibling) {
  364. bitmap = append(bitmap, false)
  365. } else {
  366. bitmap = append(bitmap, true)
  367. b = append(b, siblings[i]...)
  368. }
  369. }
  370. bitmapBytes := bitmapToBytes(bitmap)
  371. l := len(bitmapBytes)
  372. res := make([]byte, l+1+len(b))
  373. res[0] = byte(l) // set the bitmapBytes length
  374. copy(res[1:1+l], bitmapBytes)
  375. copy(res[1+l:], b)
  376. return res
  377. }
  378. // UnpackSiblings unpacks the siblings from a byte array.
  379. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  380. l := b[0]
  381. bitmapBytes := b[1 : 1+l]
  382. bitmap := bytesToBitmap(bitmapBytes)
  383. siblingsBytes := b[1+l:]
  384. iSibl := 0
  385. emptySibl := make([]byte, hashFunc.Len())
  386. var siblings [][]byte
  387. for i := 0; i < len(bitmap); i++ {
  388. if iSibl >= len(siblingsBytes) {
  389. break
  390. }
  391. if bitmap[i] {
  392. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  393. iSibl += hashFunc.Len()
  394. } else {
  395. siblings = append(siblings, emptySibl)
  396. }
  397. }
  398. return siblings, nil
  399. }
  400. func bitmapToBytes(bitmap []bool) []byte {
  401. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  402. b := make([]byte, bitmapBytesLen)
  403. for i := 0; i < len(bitmap); i++ {
  404. if bitmap[i] {
  405. b[i/8] |= 1 << (i % 8)
  406. }
  407. }
  408. return b
  409. }
  410. func bytesToBitmap(b []byte) []bool {
  411. var bitmap []bool
  412. for i := 0; i < len(b); i++ {
  413. for j := 0; j < 8; j++ {
  414. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  415. }
  416. }
  417. return bitmap
  418. }
  419. // Get returns the value for a given key
  420. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  421. keyPath := make([]byte, t.hashFunction.Len())
  422. copy(keyPath[:], k)
  423. path := getPath(t.maxLevels, keyPath)
  424. // go down to the leaf
  425. var siblings [][]byte
  426. _, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true)
  427. if err != nil {
  428. return nil, nil, err
  429. }
  430. leafK, leafV := readLeafValue(value)
  431. if !bytes.Equal(k, leafK) {
  432. panic(fmt.Errorf("%s != %s", BytesToBigInt(k), BytesToBigInt(leafK)))
  433. }
  434. return leafK, leafV, nil
  435. }
  436. // CheckProof verifies the given proof. The proof verification depends on the
  437. // HashFunction passed as parameter.
  438. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  439. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  440. if err != nil {
  441. return false, err
  442. }
  443. keyPath := make([]byte, hashFunc.Len())
  444. copy(keyPath[:], k)
  445. key, _, err := newLeafValue(hashFunc, k, v)
  446. if err != nil {
  447. return false, err
  448. }
  449. path := getPath(len(siblings), keyPath)
  450. for i := len(siblings) - 1; i >= 0; i-- {
  451. if path[i] {
  452. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  453. if err != nil {
  454. return false, err
  455. }
  456. } else {
  457. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  458. if err != nil {
  459. return false, err
  460. }
  461. }
  462. }
  463. if bytes.Equal(key[:], root) {
  464. return true, nil
  465. }
  466. return false, nil
  467. }
  468. func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) {
  469. v, err := t.db.Get(k)
  470. if err == nil {
  471. return v, nil
  472. }
  473. if tx != nil {
  474. return tx.Get(k)
  475. }
  476. return nil, db.ErrNotFound
  477. }
  478. // Iterate iterates through the full Tree, executing the given function on each
  479. // node of the Tree.
  480. func (t *Tree) Iterate(f func([]byte, []byte)) error {
  481. t.updateAccessTime()
  482. return t.iter(t.root, f)
  483. }
  484. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  485. v, err := t.dbGet(nil, k)
  486. if err != nil {
  487. return err
  488. }
  489. switch v[0] {
  490. case PrefixValueEmpty:
  491. f(k, v)
  492. case PrefixValueLeaf:
  493. f(k, v)
  494. case PrefixValueIntermediate:
  495. f(k, v)
  496. l, r := readIntermediateChilds(v)
  497. if err = t.iter(l, f); err != nil {
  498. return err
  499. }
  500. if err = t.iter(r, f); err != nil {
  501. return err
  502. }
  503. default:
  504. return fmt.Errorf("invalid value")
  505. }
  506. return nil
  507. }
  508. // Dump exports all the Tree leafs in a byte array of length:
  509. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  510. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  511. // [ len(k) | len(v) | key | value ]
  512. // Where S is the size of the output of the hash function used for the Tree.
  513. func (t *Tree) Dump() ([]byte, error) {
  514. t.updateAccessTime()
  515. // WARNING current encoding only supports key & values of 255 bytes each
  516. // (due using only 1 byte for the length headers).
  517. var b []byte
  518. err := t.Iterate(func(k, v []byte) {
  519. if v[0] != PrefixValueLeaf {
  520. return
  521. }
  522. leafK, leafV := readLeafValue(v)
  523. kv := make([]byte, 2+len(leafK)+len(leafV))
  524. kv[0] = byte(len(leafK))
  525. kv[1] = byte(len(leafV))
  526. copy(kv[2:2+len(leafK)], leafK)
  527. copy(kv[2+len(leafK):], leafV)
  528. b = append(b, kv...)
  529. })
  530. return b, err
  531. }
  532. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  533. // method) in the Tree.
  534. func (t *Tree) ImportDump(b []byte) error {
  535. t.updateAccessTime()
  536. r := bytes.NewReader(b)
  537. for {
  538. l := make([]byte, 2)
  539. _, err := io.ReadFull(r, l)
  540. if err == io.EOF {
  541. break
  542. } else if err != nil {
  543. return err
  544. }
  545. k := make([]byte, l[0])
  546. _, err = io.ReadFull(r, k)
  547. if err != nil {
  548. return err
  549. }
  550. v := make([]byte, l[1])
  551. _, err = io.ReadFull(r, v)
  552. if err != nil {
  553. return err
  554. }
  555. err = t.Add(k, v)
  556. if err != nil {
  557. return err
  558. }
  559. }
  560. return nil
  561. }