You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

624 lines
16 KiB

3 years ago
3 years ago
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree (when using the Poseidon hash function),
following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.

The hash function to use is configurable. For example, when working with
zkSnarks the Poseidon hash function can be used, while in other settings the
Blake3 hash function can be used instead, which improves the computation time.
*/
  11. package arbo
  12. import (
  13. "bytes"
  14. "fmt"
  15. "io"
  16. "math"
  17. "sync/atomic"
  18. "time"
  19. "github.com/iden3/go-merkletree/db"
  20. )
const (
	// PrefixValueLen defines the length, in bytes, of the prefix used in the
	// byte representation of a Value stored in the db.
	PrefixValueLen = 2
	// PrefixValueEmpty is used as the first byte of a Value to indicate
	// that it is an Empty value.
	PrefixValueEmpty = 0
	// PrefixValueLeaf is used as the first byte of a Value to indicate
	// that it is a Leaf value.
	PrefixValueLeaf = 1
	// PrefixValueIntermediate is used as the first byte of a Value to
	// indicate that it is an Intermediate value.
	PrefixValueIntermediate = 2
)
var (
	// dbKeyRoot is the db key under which the root of the Tree is stored.
	dbKeyRoot = []byte("root")
	// emptyValue is the Value reported for an empty node; its single byte
	// has the same value as PrefixValueEmpty.
	emptyValue = []byte{0}
)
// Tree defines the struct that implements the MerkleTree functionalities.
type Tree struct {
	// db is the storage where the nodes and the root are persisted.
	db db.Storage
	// lastAccess is a unix timestamp (seconds); it is read and written
	// through sync/atomic.
	lastAccess int64 // in unix time
	// maxLevels bounds the depth of the Tree: descending to a level index
	// equal to or above it is reported as an error.
	maxLevels int
	// root is the key of the current root node.
	root []byte
	// hashFunction is the hash used to derive node keys; its Len() also
	// fixes the length of the empty key and of the key paths.
	hashFunction HashFunction
}
  47. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  48. // will load it.
  49. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  50. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  51. t.updateAccessTime()
  52. root, err := t.dbGet(nil, dbKeyRoot)
  53. if err == db.ErrNotFound {
  54. // store new root 0
  55. tx, err := t.db.NewTx()
  56. if err != nil {
  57. return nil, err
  58. }
  59. t.root = make([]byte, t.hashFunction.Len()) // empty
  60. if err = tx.Put(dbKeyRoot, t.root); err != nil {
  61. return nil, err
  62. }
  63. if err = tx.Commit(); err != nil {
  64. return nil, err
  65. }
  66. return &t, err
  67. } else if err != nil {
  68. return nil, err
  69. }
  70. t.root = root
  71. return &t, nil
  72. }
  73. func (t *Tree) updateAccessTime() {
  74. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  75. }
  76. // LastAccess returns the last access timestamp in Unixtime
  77. func (t *Tree) LastAccess() int64 {
  78. return atomic.LoadInt64(&t.lastAccess)
  79. }
  80. // Root returns the root of the Tree
  81. func (t *Tree) Root() []byte {
  82. return t.root
  83. }
  84. // AddBatch adds a batch of key-values to the Tree. This method will be
  85. // optimized to do some internal parallelization. Returns an array containing
  86. // the indexes of the keys failed to add.
  87. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  88. if len(keys) != len(values) {
  89. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  90. len(keys), len(values))
  91. }
  92. tx, err := t.db.NewTx()
  93. if err != nil {
  94. return nil, err
  95. }
  96. var indexes []int
  97. for i := 0; i < len(keys); i++ {
  98. tx, err = t.add(tx, keys[i], values[i])
  99. if err != nil {
  100. indexes = append(indexes, i)
  101. }
  102. }
  103. // store root to db
  104. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  105. return indexes, err
  106. }
  107. if err := tx.Commit(); err != nil {
  108. return nil, err
  109. }
  110. return indexes, nil
  111. }
  112. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  113. // is expected that are represented by a Little-Endian byte array (for circom
  114. // compatibility).
  115. func (t *Tree) Add(k, v []byte) error {
  116. tx, err := t.db.NewTx()
  117. if err != nil {
  118. return err
  119. }
  120. tx, err = t.add(tx, k, v)
  121. if err != nil {
  122. return err
  123. }
  124. // store root to db
  125. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  126. return err
  127. }
  128. return tx.Commit()
  129. }
  130. func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
  131. // TODO check validity of key & value (for the Tree.HashFunction type)
  132. keyPath := make([]byte, t.hashFunction.Len())
  133. copy(keyPath[:], k)
  134. path := getPath(t.maxLevels, keyPath)
  135. // go down to the leaf
  136. var siblings [][]byte
  137. _, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false)
  138. if err != nil {
  139. return tx, err
  140. }
  141. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  142. if err != nil {
  143. return tx, err
  144. }
  145. if err := tx.Put(leafKey, leafValue); err != nil {
  146. return tx, err
  147. }
  148. // go up to the root
  149. if len(siblings) == 0 {
  150. t.root = leafKey
  151. return tx, nil
  152. }
  153. root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
  154. if err != nil {
  155. return tx, err
  156. }
  157. t.root = root
  158. return tx, nil
  159. }
  160. // down goes down to the leaf recursively
  161. func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
  162. path []bool, l int, getLeaf bool) (
  163. []byte, []byte, [][]byte, error) {
  164. if l > t.maxLevels-1 {
  165. return nil, nil, nil, fmt.Errorf("max level")
  166. }
  167. var err error
  168. var currValue []byte
  169. emptyKey := make([]byte, t.hashFunction.Len())
  170. if bytes.Equal(currKey, emptyKey) {
  171. // empty value
  172. return currKey, emptyValue, siblings, nil
  173. }
  174. currValue, err = t.dbGet(tx, currKey)
  175. if err != nil {
  176. return nil, nil, nil, err
  177. }
  178. switch currValue[0] {
  179. case PrefixValueEmpty: // empty
  180. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  181. // reaching this point
  182. // return currKey, empty, siblings, nil
  183. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  184. case PrefixValueLeaf: // leaf
  185. if bytes.Equal(newKey, currKey) {
  186. return nil, nil, nil, fmt.Errorf("key already exists")
  187. }
  188. if !bytes.Equal(currValue, emptyValue) {
  189. if getLeaf {
  190. return currKey, currValue, siblings, nil
  191. }
  192. oldLeafKey, _ := readLeafValue(currValue)
  193. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  194. copy(oldLeafKeyFull[:], oldLeafKey)
  195. // if currKey is already used, go down until paths diverge
  196. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  197. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
  198. if err != nil {
  199. return nil, nil, nil, err
  200. }
  201. }
  202. return currKey, currValue, siblings, nil
  203. case PrefixValueIntermediate: // intermediate
  204. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  205. return nil, nil, nil,
  206. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  207. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  208. }
  209. // collect siblings while going down
  210. if path[l] {
  211. // right
  212. lChild, rChild := readIntermediateChilds(currValue)
  213. siblings = append(siblings, lChild)
  214. return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf)
  215. }
  216. // left
  217. lChild, rChild := readIntermediateChilds(currValue)
  218. siblings = append(siblings, rChild)
  219. return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf)
  220. default:
  221. return nil, nil, nil, fmt.Errorf("invalid value")
  222. }
  223. }
  224. // downVirtually is used when in a leaf already exists, and a new leaf which
  225. // shares the path until the existing leaf is being added
  226. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  227. newPath []bool, l int) ([][]byte, error) {
  228. var err error
  229. if l > t.maxLevels-1 {
  230. return nil, fmt.Errorf("max virtual level %d", l)
  231. }
  232. if oldPath[l] == newPath[l] {
  233. emptyKey := make([]byte, t.hashFunction.Len())
  234. siblings = append(siblings, emptyKey)
  235. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
  236. if err != nil {
  237. return nil, err
  238. }
  239. return siblings, nil
  240. }
  241. // reached the divergence
  242. siblings = append(siblings, oldKey)
  243. return siblings, nil
  244. }
  245. // up goes up recursively updating the intermediate nodes
  246. func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
  247. var k, v []byte
  248. var err error
  249. if path[l] {
  250. k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
  251. if err != nil {
  252. return nil, err
  253. }
  254. } else {
  255. k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
  256. if err != nil {
  257. return nil, err
  258. }
  259. }
  260. // store k-v to db
  261. if err = tx.Put(k, v); err != nil {
  262. return nil, err
  263. }
  264. if l == 0 {
  265. // reached the root
  266. return k, nil
  267. }
  268. return t.up(tx, k, siblings, path, l-1)
  269. }
  270. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  271. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  272. if err != nil {
  273. return nil, nil, err
  274. }
  275. var leafValue []byte
  276. leafValue = append(leafValue, byte(1))
  277. leafValue = append(leafValue, byte(len(k)))
  278. leafValue = append(leafValue, k...)
  279. leafValue = append(leafValue, v...)
  280. return leafKey, leafValue, nil
  281. }
  282. func readLeafValue(b []byte) ([]byte, []byte) {
  283. if len(b) < PrefixValueLen {
  284. return []byte{}, []byte{}
  285. }
  286. kLen := b[1]
  287. if len(b) < PrefixValueLen+int(kLen) {
  288. return []byte{}, []byte{}
  289. }
  290. k := b[PrefixValueLen : PrefixValueLen+kLen]
  291. v := b[PrefixValueLen+kLen:]
  292. return k, v
  293. }
  294. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  295. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  296. b[0] = 2
  297. b[1] = byte(len(l))
  298. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  299. copy(b[PrefixValueLen+hashFunc.Len():], r)
  300. key, err := hashFunc.Hash(l, r)
  301. if err != nil {
  302. return nil, nil, err
  303. }
  304. return key, b, nil
  305. }
  306. func readIntermediateChilds(b []byte) ([]byte, []byte) {
  307. if len(b) < PrefixValueLen {
  308. return []byte{}, []byte{}
  309. }
  310. lLen := b[1]
  311. if len(b) < PrefixValueLen+int(lLen) {
  312. return []byte{}, []byte{}
  313. }
  314. l := b[PrefixValueLen : PrefixValueLen+lLen]
  315. r := b[PrefixValueLen+lLen:]
  316. return l, r
  317. }
  318. func getPath(numLevels int, k []byte) []bool {
  319. path := make([]bool, numLevels)
  320. for n := 0; n < numLevels; n++ {
  321. path[n] = k[n/8]&(1<<(n%8)) != 0
  322. }
  323. return path
  324. }
  325. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  326. // the Tree, the proof will be of existence, if the key does not exist in the
  327. // tree, the proof will be of non-existence.
  328. func (t *Tree) GenProof(k []byte) ([]byte, error) {
  329. keyPath := make([]byte, t.hashFunction.Len())
  330. copy(keyPath[:], k)
  331. path := getPath(t.maxLevels, keyPath)
  332. // go down to the leaf
  333. var siblings [][]byte
  334. _, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true)
  335. if err != nil {
  336. return nil, err
  337. }
  338. leafK, leafV := readLeafValue(value)
  339. if !bytes.Equal(k, leafK) {
  340. fmt.Println("key not in Tree")
  341. fmt.Println(leafK)
  342. fmt.Println(leafV)
  343. // TODO proof of non-existence
  344. panic(fmt.Errorf("unimplemented"))
  345. }
  346. s := PackSiblings(t.hashFunction, siblings)
  347. return s, nil
  348. }
  349. // PackSiblings packs the siblings into a byte array.
  350. // [ 1 byte | L bytes | S * N bytes ]
  351. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  352. // Where the bitmap indicates if the sibling is 0 or a value from the siblings
  353. // array. And S is the size of the output of the hash function used for the
  354. // Tree.
  355. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  356. var b []byte
  357. var bitmap []bool
  358. emptySibling := make([]byte, hashFunc.Len())
  359. for i := 0; i < len(siblings); i++ {
  360. if bytes.Equal(siblings[i], emptySibling) {
  361. bitmap = append(bitmap, false)
  362. } else {
  363. bitmap = append(bitmap, true)
  364. b = append(b, siblings[i]...)
  365. }
  366. }
  367. bitmapBytes := bitmapToBytes(bitmap)
  368. l := len(bitmapBytes)
  369. res := make([]byte, l+1+len(b))
  370. res[0] = byte(l) // set the bitmapBytes length
  371. copy(res[1:1+l], bitmapBytes)
  372. copy(res[1+l:], b)
  373. return res
  374. }
  375. // UnpackSiblings unpacks the siblings from a byte array.
  376. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  377. l := b[0]
  378. bitmapBytes := b[1 : 1+l]
  379. bitmap := bytesToBitmap(bitmapBytes)
  380. siblingsBytes := b[1+l:]
  381. iSibl := 0
  382. emptySibl := make([]byte, hashFunc.Len())
  383. var siblings [][]byte
  384. for i := 0; i < len(bitmap); i++ {
  385. if iSibl >= len(siblingsBytes) {
  386. break
  387. }
  388. if bitmap[i] {
  389. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  390. iSibl += hashFunc.Len()
  391. } else {
  392. siblings = append(siblings, emptySibl)
  393. }
  394. }
  395. return siblings, nil
  396. }
  397. func bitmapToBytes(bitmap []bool) []byte {
  398. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  399. b := make([]byte, bitmapBytesLen)
  400. for i := 0; i < len(bitmap); i++ {
  401. if bitmap[i] {
  402. b[i/8] |= 1 << (i % 8)
  403. }
  404. }
  405. return b
  406. }
  407. func bytesToBitmap(b []byte) []bool {
  408. var bitmap []bool
  409. for i := 0; i < len(b); i++ {
  410. for j := 0; j < 8; j++ {
  411. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  412. }
  413. }
  414. return bitmap
  415. }
  416. // Get returns the value for a given key
  417. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  418. keyPath := make([]byte, t.hashFunction.Len())
  419. copy(keyPath[:], k)
  420. path := getPath(t.maxLevels, keyPath)
  421. // go down to the leaf
  422. var siblings [][]byte
  423. _, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true)
  424. if err != nil {
  425. return nil, nil, err
  426. }
  427. leafK, leafV := readLeafValue(value)
  428. if !bytes.Equal(k, leafK) {
  429. panic(fmt.Errorf("%s != %s", BytesToBigInt(k), BytesToBigInt(leafK)))
  430. }
  431. return leafK, leafV, nil
  432. }
  433. // CheckProof verifies the given proof. The proof verification depends on the
  434. // HashFunction passed as parameter.
  435. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  436. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  437. if err != nil {
  438. return false, err
  439. }
  440. keyPath := make([]byte, hashFunc.Len())
  441. copy(keyPath[:], k)
  442. key, _, err := newLeafValue(hashFunc, k, v)
  443. if err != nil {
  444. return false, err
  445. }
  446. path := getPath(len(siblings), keyPath)
  447. for i := len(siblings) - 1; i >= 0; i-- {
  448. if path[i] {
  449. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  450. if err != nil {
  451. return false, err
  452. }
  453. } else {
  454. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  455. if err != nil {
  456. return false, err
  457. }
  458. }
  459. }
  460. if bytes.Equal(key[:], root) {
  461. return true, nil
  462. }
  463. return false, nil
  464. }
  465. func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) {
  466. v, err := t.db.Get(k)
  467. if err == nil {
  468. return v, nil
  469. }
  470. if tx != nil {
  471. return tx.Get(k)
  472. }
  473. return nil, db.ErrNotFound
  474. }
  475. // Iterate iterates through the full Tree, executing the given function on each
  476. // node of the Tree.
  477. func (t *Tree) Iterate(f func([]byte, []byte)) error {
  478. return t.iter(t.root, f)
  479. }
  480. func (t *Tree) iter(k []byte, f func([]byte, []byte)) error {
  481. v, err := t.dbGet(nil, k)
  482. if err != nil {
  483. return err
  484. }
  485. switch v[0] {
  486. case PrefixValueEmpty:
  487. f(k, v)
  488. case PrefixValueLeaf:
  489. f(k, v)
  490. case PrefixValueIntermediate:
  491. f(k, v)
  492. l, r := readIntermediateChilds(v)
  493. if err = t.iter(l, f); err != nil {
  494. return err
  495. }
  496. if err = t.iter(r, f); err != nil {
  497. return err
  498. }
  499. default:
  500. return fmt.Errorf("invalid value")
  501. }
  502. return nil
  503. }
  504. // Dump exports all the Tree leafs in a byte array of length:
  505. // [ N * (2+len(k+v)) ]. Where N is the number of key-values, and for each k+v:
  506. // [ 1 byte | 1 byte | S bytes | len(v) bytes ]
  507. // [ len(k) | len(v) | key | value ]
  508. // Where S is the size of the output of the hash function used for the Tree.
  509. func (t *Tree) Dump() ([]byte, error) {
  510. // WARNING current encoding only supports key & values of 255 bytes each
  511. // (due using only 1 byte for the length headers).
  512. var b []byte
  513. err := t.Iterate(func(k, v []byte) {
  514. if v[0] != PrefixValueLeaf {
  515. return
  516. }
  517. leafK, leafV := readLeafValue(v)
  518. kv := make([]byte, 2+len(leafK)+len(leafV))
  519. kv[0] = byte(len(leafK))
  520. kv[1] = byte(len(leafV))
  521. copy(kv[2:2+len(leafK)], leafK)
  522. copy(kv[2+len(leafK):], leafV)
  523. b = append(b, kv...)
  524. })
  525. return b, err
  526. }
  527. // ImportDump imports the leafs (that have been exported with the ExportLeafs
  528. // method) in the Tree.
  529. func (t *Tree) ImportDump(b []byte) error {
  530. r := bytes.NewReader(b)
  531. for {
  532. l := make([]byte, 2)
  533. _, err := io.ReadFull(r, l)
  534. if err == io.EOF {
  535. break
  536. } else if err != nil {
  537. return err
  538. }
  539. k := make([]byte, l[0])
  540. _, err = io.ReadFull(r, k)
  541. if err != nil {
  542. return err
  543. }
  544. v := make([]byte, l[1])
  545. _, err = io.ReadFull(r, v)
  546. if err != nil {
  547. return err
  548. }
  549. err = t.Add(k, v)
  550. if err != nil {
  551. return err
  552. }
  553. }
  554. return nil
  555. }