You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

535 lines
13 KiB

3 years ago
3 years ago
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree (when using the Poseidon hash function),
following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.

The hash function to use is configurable. For example, when working with
zkSNARKs the Poseidon hash function can be used; otherwise the Blake3 hash
function can be used instead, which improves the computation time.
*/
  11. package arbo
  12. import (
  13. "bytes"
  14. "fmt"
  15. "math"
  16. "sync/atomic"
  17. "time"
  18. "github.com/iden3/go-merkletree/db"
  19. )
  20. const (
  21. // PrefixValueLen defines the bytes-prefix length used for the Value
  22. // bytes representation stored in the db
  23. PrefixValueLen = 2
  24. // PrefixValueEmpty is used for the first byte of a Value to indicate
  25. // that is an Empty value
  26. PrefixValueEmpty = 0
  27. // PrefixValueLeaf is used for the first byte of a Value to indicate
  28. // that is a Leaf value
  29. PrefixValueLeaf = 1
  30. // PrefixValueIntermediate is used for the first byte of a Value to
  31. // indicate that is a Intermediate value
  32. PrefixValueIntermediate = 2
  33. )
  34. var (
  35. dbKeyRoot = []byte("root")
  36. emptyValue = []byte{0}
  37. )
  38. // Tree defines the struct that implements the MerkleTree functionalities
  39. type Tree struct {
  40. db db.Storage
  41. lastAccess int64 // in unix time
  42. maxLevels int
  43. root []byte
  44. hashFunction HashFunction
  45. }
  46. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  47. // will load it.
  48. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  49. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  50. t.updateAccessTime()
  51. root, err := t.dbGet(nil, dbKeyRoot)
  52. if err == db.ErrNotFound {
  53. // store new root 0
  54. tx, err := t.db.NewTx()
  55. if err != nil {
  56. return nil, err
  57. }
  58. t.root = make([]byte, t.hashFunction.Len()) // empty
  59. if err = tx.Put(dbKeyRoot, t.root); err != nil {
  60. return nil, err
  61. }
  62. if err = tx.Commit(); err != nil {
  63. return nil, err
  64. }
  65. return &t, err
  66. } else if err != nil {
  67. return nil, err
  68. }
  69. t.root = root
  70. return &t, nil
  71. }
  72. func (t *Tree) updateAccessTime() {
  73. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  74. }
  75. // LastAccess returns the last access timestamp in Unixtime
  76. func (t *Tree) LastAccess() int64 {
  77. return atomic.LoadInt64(&t.lastAccess)
  78. }
  79. // Root returns the root of the Tree
  80. func (t *Tree) Root() []byte {
  81. return t.root
  82. }
  83. // AddBatch adds a batch of key-values to the Tree. This method will be
  84. // optimized to do some internal parallelization. Returns an array containing
  85. // the indexes of the keys failed to add.
  86. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  87. if len(keys) != len(values) {
  88. return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
  89. len(keys), len(values))
  90. }
  91. tx, err := t.db.NewTx()
  92. if err != nil {
  93. return nil, err
  94. }
  95. var indexes []int
  96. for i := 0; i < len(keys); i++ {
  97. tx, err = t.add(tx, keys[i], values[i])
  98. if err != nil {
  99. indexes = append(indexes, i)
  100. }
  101. }
  102. // store root to db
  103. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  104. return indexes, err
  105. }
  106. if err := tx.Commit(); err != nil {
  107. return nil, err
  108. }
  109. return indexes, nil
  110. }
  111. // Add inserts the key-value into the Tree. If the inputs come from a *big.Int,
  112. // is expected that are represented by a Little-Endian byte array (for circom
  113. // compatibility).
  114. func (t *Tree) Add(k, v []byte) error {
  115. tx, err := t.db.NewTx()
  116. if err != nil {
  117. return err
  118. }
  119. tx, err = t.add(tx, k, v)
  120. if err != nil {
  121. return err
  122. }
  123. // store root to db
  124. if err := tx.Put(dbKeyRoot, t.root); err != nil {
  125. return err
  126. }
  127. return tx.Commit()
  128. }
  129. func (t *Tree) add(tx db.Tx, k, v []byte) (db.Tx, error) {
  130. // TODO check validity of key & value (for the Tree.HashFunction type)
  131. keyPath := make([]byte, t.hashFunction.Len())
  132. copy(keyPath[:], k)
  133. path := getPath(t.maxLevels, keyPath)
  134. // go down to the leaf
  135. var siblings [][]byte
  136. _, _, siblings, err := t.down(tx, k, t.root, siblings, path, 0, false)
  137. if err != nil {
  138. return tx, err
  139. }
  140. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  141. if err != nil {
  142. return tx, err
  143. }
  144. if err := tx.Put(leafKey, leafValue); err != nil {
  145. return tx, err
  146. }
  147. // go up to the root
  148. if len(siblings) == 0 {
  149. t.root = leafKey
  150. return tx, nil
  151. }
  152. root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
  153. if err != nil {
  154. return tx, err
  155. }
  156. t.root = root
  157. return tx, nil
  158. }
  159. // down goes down to the leaf recursively
  160. func (t *Tree) down(tx db.Tx, newKey, currKey []byte, siblings [][]byte,
  161. path []bool, l int, getLeaf bool) (
  162. []byte, []byte, [][]byte, error) {
  163. if l > t.maxLevels-1 {
  164. return nil, nil, nil, fmt.Errorf("max level")
  165. }
  166. var err error
  167. var currValue []byte
  168. emptyKey := make([]byte, t.hashFunction.Len())
  169. if bytes.Equal(currKey, emptyKey) {
  170. // empty value
  171. return currKey, emptyValue, siblings, nil
  172. }
  173. currValue, err = t.dbGet(tx, currKey)
  174. if err != nil {
  175. return nil, nil, nil, err
  176. }
  177. switch currValue[0] {
  178. case PrefixValueEmpty: // empty
  179. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  180. // reaching this point
  181. // return currKey, empty, siblings, nil
  182. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  183. case PrefixValueLeaf: // leaf
  184. if bytes.Equal(newKey, currKey) {
  185. return nil, nil, nil, fmt.Errorf("key already exists")
  186. }
  187. if !bytes.Equal(currValue, emptyValue) {
  188. if getLeaf {
  189. return currKey, currValue, siblings, nil
  190. }
  191. oldLeafKey, _ := readLeafValue(currValue)
  192. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  193. copy(oldLeafKeyFull[:], oldLeafKey)
  194. // if currKey is already used, go down until paths diverge
  195. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  196. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
  197. if err != nil {
  198. return nil, nil, nil, err
  199. }
  200. }
  201. return currKey, currValue, siblings, nil
  202. case PrefixValueIntermediate: // intermediate
  203. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  204. return nil, nil, nil,
  205. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  206. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  207. }
  208. // collect siblings while going down
  209. if path[l] {
  210. // right
  211. lChild, rChild := readIntermediateChilds(currValue)
  212. siblings = append(siblings, lChild)
  213. return t.down(tx, newKey, rChild, siblings, path, l+1, getLeaf)
  214. }
  215. // left
  216. lChild, rChild := readIntermediateChilds(currValue)
  217. siblings = append(siblings, rChild)
  218. return t.down(tx, newKey, lChild, siblings, path, l+1, getLeaf)
  219. default:
  220. return nil, nil, nil, fmt.Errorf("invalid value")
  221. }
  222. }
  223. // downVirtually is used when in a leaf already exists, and a new leaf which
  224. // shares the path until the existing leaf is being added
  225. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  226. newPath []bool, l int) ([][]byte, error) {
  227. var err error
  228. if l > t.maxLevels-1 {
  229. return nil, fmt.Errorf("max virtual level %d", l)
  230. }
  231. if oldPath[l] == newPath[l] {
  232. emptyKey := make([]byte, t.hashFunction.Len())
  233. siblings = append(siblings, emptyKey)
  234. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
  235. if err != nil {
  236. return nil, err
  237. }
  238. return siblings, nil
  239. }
  240. // reached the divergence
  241. siblings = append(siblings, oldKey)
  242. return siblings, nil
  243. }
  244. // up goes up recursively updating the intermediate nodes
  245. func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
  246. var k, v []byte
  247. var err error
  248. if path[l] {
  249. k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
  250. if err != nil {
  251. return nil, err
  252. }
  253. } else {
  254. k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
  255. if err != nil {
  256. return nil, err
  257. }
  258. }
  259. // store k-v to db
  260. if err = tx.Put(k, v); err != nil {
  261. return nil, err
  262. }
  263. if l == 0 {
  264. // reached the root
  265. return k, nil
  266. }
  267. return t.up(tx, k, siblings, path, l-1)
  268. }
  269. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  270. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  271. if err != nil {
  272. return nil, nil, err
  273. }
  274. var leafValue []byte
  275. leafValue = append(leafValue, byte(1))
  276. leafValue = append(leafValue, byte(len(k)))
  277. leafValue = append(leafValue, k...)
  278. leafValue = append(leafValue, v...)
  279. return leafKey, leafValue, nil
  280. }
  281. func readLeafValue(b []byte) ([]byte, []byte) {
  282. if len(b) < PrefixValueLen {
  283. return []byte{}, []byte{}
  284. }
  285. kLen := b[1]
  286. if len(b) < PrefixValueLen+int(kLen) {
  287. return []byte{}, []byte{}
  288. }
  289. k := b[PrefixValueLen : PrefixValueLen+kLen]
  290. v := b[PrefixValueLen+kLen:]
  291. return k, v
  292. }
  293. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  294. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  295. b[0] = 2
  296. b[1] = byte(len(l))
  297. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  298. copy(b[PrefixValueLen+hashFunc.Len():], r)
  299. key, err := hashFunc.Hash(l, r)
  300. if err != nil {
  301. return nil, nil, err
  302. }
  303. return key, b, nil
  304. }
  305. func readIntermediateChilds(b []byte) ([]byte, []byte) {
  306. if len(b) < PrefixValueLen {
  307. return []byte{}, []byte{}
  308. }
  309. lLen := b[1]
  310. if len(b) < PrefixValueLen+int(lLen) {
  311. return []byte{}, []byte{}
  312. }
  313. l := b[PrefixValueLen : PrefixValueLen+lLen]
  314. r := b[PrefixValueLen+lLen:]
  315. return l, r
  316. }
  317. func getPath(numLevels int, k []byte) []bool {
  318. path := make([]bool, numLevels)
  319. for n := 0; n < numLevels; n++ {
  320. path[n] = k[n/8]&(1<<(n%8)) != 0
  321. }
  322. return path
  323. }
  324. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  325. // the Tree, the proof will be of existence, if the key does not exist in the
  326. // tree, the proof will be of non-existence.
  327. func (t *Tree) GenProof(k []byte) ([]byte, error) {
  328. keyPath := make([]byte, t.hashFunction.Len())
  329. copy(keyPath[:], k)
  330. path := getPath(t.maxLevels, keyPath)
  331. // go down to the leaf
  332. var siblings [][]byte
  333. _, value, siblings, err := t.down(nil, k, t.root, siblings, path, 0, true)
  334. if err != nil {
  335. return nil, err
  336. }
  337. leafK, leafV := readLeafValue(value)
  338. if !bytes.Equal(k, leafK) {
  339. fmt.Println("key not in Tree")
  340. fmt.Println(leafK)
  341. fmt.Println(leafV)
  342. // TODO proof of non-existence
  343. panic(fmt.Errorf("unimplemented"))
  344. }
  345. s := PackSiblings(t.hashFunction, siblings)
  346. return s, nil
  347. }
  348. // PackSiblings packs the siblings into a byte array.
  349. // [ 1 byte | L bytes | 32 * N bytes ]
  350. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  351. // Where the bitmap indicates if the sibling is 0 or a value from the siblings array.
  352. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  353. var b []byte
  354. var bitmap []bool
  355. emptySibling := make([]byte, hashFunc.Len())
  356. for i := 0; i < len(siblings); i++ {
  357. if bytes.Equal(siblings[i], emptySibling) {
  358. bitmap = append(bitmap, false)
  359. } else {
  360. bitmap = append(bitmap, true)
  361. b = append(b, siblings[i]...)
  362. }
  363. }
  364. bitmapBytes := bitmapToBytes(bitmap)
  365. l := len(bitmapBytes)
  366. res := make([]byte, l+1+len(b))
  367. res[0] = byte(l) // set the bitmapBytes length
  368. copy(res[1:1+l], bitmapBytes)
  369. copy(res[1+l:], b)
  370. return res
  371. }
  372. // UnpackSiblings unpacks the siblings from a byte array.
  373. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  374. l := b[0]
  375. bitmapBytes := b[1 : 1+l]
  376. bitmap := bytesToBitmap(bitmapBytes)
  377. siblingsBytes := b[1+l:]
  378. iSibl := 0
  379. emptySibl := make([]byte, hashFunc.Len())
  380. var siblings [][]byte
  381. for i := 0; i < len(bitmap); i++ {
  382. if iSibl >= len(siblingsBytes) {
  383. break
  384. }
  385. if bitmap[i] {
  386. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  387. iSibl += hashFunc.Len()
  388. } else {
  389. siblings = append(siblings, emptySibl)
  390. }
  391. }
  392. return siblings, nil
  393. }
  394. func bitmapToBytes(bitmap []bool) []byte {
  395. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  396. b := make([]byte, bitmapBytesLen)
  397. for i := 0; i < len(bitmap); i++ {
  398. if bitmap[i] {
  399. b[i/8] |= 1 << (i % 8)
  400. }
  401. }
  402. return b
  403. }
  404. func bytesToBitmap(b []byte) []bool {
  405. var bitmap []bool
  406. for i := 0; i < len(b); i++ {
  407. for j := 0; j < 8; j++ {
  408. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  409. }
  410. }
  411. return bitmap
  412. }
  413. // Get returns the value for a given key
  414. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  415. keyPath := make([]byte, t.hashFunction.Len())
  416. copy(keyPath[:], k)
  417. path := getPath(t.maxLevels, keyPath)
  418. // go down to the leaf
  419. var siblings [][]byte
  420. _, value, _, err := t.down(nil, k, t.root, siblings, path, 0, true)
  421. if err != nil {
  422. return nil, nil, err
  423. }
  424. leafK, leafV := readLeafValue(value)
  425. if !bytes.Equal(k, leafK) {
  426. panic(fmt.Errorf("%s != %s", BytesToBigInt(k), BytesToBigInt(leafK)))
  427. }
  428. return leafK, leafV, nil
  429. }
  430. // CheckProof verifies the given proof. The proof verification depends on the
  431. // HashFunction passed as parameter.
  432. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  433. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  434. if err != nil {
  435. return false, err
  436. }
  437. keyPath := make([]byte, hashFunc.Len())
  438. copy(keyPath[:], k)
  439. key, _, err := newLeafValue(hashFunc, k, v)
  440. if err != nil {
  441. return false, err
  442. }
  443. path := getPath(len(siblings), keyPath)
  444. for i := len(siblings) - 1; i >= 0; i-- {
  445. if path[i] {
  446. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  447. if err != nil {
  448. return false, err
  449. }
  450. } else {
  451. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  452. if err != nil {
  453. return false, err
  454. }
  455. }
  456. }
  457. if bytes.Equal(key[:], root) {
  458. return true, nil
  459. }
  460. return false, nil
  461. }
  462. func (t *Tree) dbGet(tx db.Tx, k []byte) ([]byte, error) {
  463. v, err := t.db.Get(k)
  464. if err == nil {
  465. return v, nil
  466. }
  467. if tx != nil {
  468. return tx.Get(k)
  469. }
  470. return nil, db.ErrNotFound
  471. }