You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

489 lines
13 KiB

3 years ago
  1. /*
  2. Package arbo implements a Merkle Tree compatible with the circomlib
  3. implementation of the MerkleTree (when using the Poseidon hash function),
  4. following the specification from
  5. https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
  6. https://eprint.iacr.org/2018/955.
It also allows choosing which hash function to use. For example, when working
with zkSNARKs the Poseidon hash function can be used; otherwise the Blake3 hash
function can be used, which improves the computation time.
  10. */
  11. package arbo
  12. import (
  13. "bytes"
  14. "fmt"
  15. "math"
  16. "sync/atomic"
  17. "time"
  18. "github.com/iden3/go-merkletree/db"
  19. )
  20. const (
  21. // PrefixValueLen defines the bytes-prefix length used for the Value
  22. // bytes representation stored in the db
  23. PrefixValueLen = 2
  24. // PrefixValueEmpty is used for the first byte of a Value to indicate
  25. // that is an Empty value
  26. PrefixValueEmpty = 0
  27. // PrefixValueLeaf is used for the first byte of a Value to indicate
  28. // that is a Leaf value
  29. PrefixValueLeaf = 1
  30. // PrefixValueIntermediate is used for the first byte of a Value to
  31. // indicate that is a Intermediate value
  32. PrefixValueIntermediate = 2
  33. )
  34. var (
  35. dbKeyRoot = []byte("root")
  36. emptyValue = []byte{0}
  37. )
  38. // Tree defines the struct that implements the MerkleTree functionalities
  39. type Tree struct {
  40. db db.Storage
  41. lastAccess int64 // in unix time
  42. maxLevels int
  43. root []byte
  44. hashFunction HashFunction
  45. }
  46. // NewTree returns a new Tree, if there is a Tree still in the given storage, it
  47. // will load it.
  48. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
  49. t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
  50. t.updateAccessTime()
  51. root, err := t.db.Get(dbKeyRoot)
  52. if err == db.ErrNotFound {
  53. // store new root 0
  54. tx, err := t.db.NewTx()
  55. if err != nil {
  56. return nil, err
  57. }
  58. t.root = make([]byte, t.hashFunction.Len()) // empty
  59. err = tx.Put(dbKeyRoot, t.root)
  60. if err != nil {
  61. return nil, err
  62. }
  63. err = tx.Commit()
  64. if err != nil {
  65. return nil, err
  66. }
  67. return &t, err
  68. } else if err != nil {
  69. return nil, err
  70. }
  71. t.root = root
  72. return &t, nil
  73. }
  74. func (t *Tree) updateAccessTime() {
  75. atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
  76. }
  77. // LastAccess returns the last access timestamp in Unixtime
  78. func (t *Tree) LastAccess() int64 {
  79. return atomic.LoadInt64(&t.lastAccess)
  80. }
  81. // Root returns the root of the Tree
  82. func (t *Tree) Root() []byte {
  83. return t.root
  84. }
  85. // AddBatch adds a batch of key-values to the Tree. This method is optimized to
  86. // do some internal parallelization. Returns an array containing the indexes of
  87. // the keys failed to add.
  88. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
  89. return nil, fmt.Errorf("unimplemented")
  90. }
  91. // Add inserts the key-value into the Tree.
  92. // If the inputs come from a *big.Int, is expected that are represented by a
  93. // Little-Endian byte array (for circom compatibility).
  94. func (t *Tree) Add(k, v []byte) error {
  95. // TODO check validity of key & value (for the Tree.HashFunction type)
  96. keyPath := make([]byte, t.hashFunction.Len())
  97. copy(keyPath[:], k)
  98. path := getPath(t.maxLevels, keyPath)
  99. // go down to the leaf
  100. var siblings [][]byte
  101. _, _, siblings, err := t.down(k, t.root, siblings, path, 0, false)
  102. if err != nil {
  103. return err
  104. }
  105. leafKey, leafValue, err := newLeafValue(t.hashFunction, k, v)
  106. if err != nil {
  107. return err
  108. }
  109. tx, err := t.db.NewTx()
  110. if err != nil {
  111. return err
  112. }
  113. if err := tx.Put(leafKey, leafValue); err != nil {
  114. return err
  115. }
  116. // go up to the root
  117. if len(siblings) == 0 {
  118. t.root = leafKey
  119. return tx.Commit()
  120. }
  121. root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
  122. if err != nil {
  123. return err
  124. }
  125. t.root = root
  126. // store root to db
  127. return tx.Commit()
  128. }
  129. // down goes down to the leaf recursively
  130. func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int, getLeaf bool) (
  131. []byte, []byte, [][]byte, error) {
  132. if l > t.maxLevels-1 {
  133. return nil, nil, nil, fmt.Errorf("max level")
  134. }
  135. var err error
  136. var currValue []byte
  137. emptyKey := make([]byte, t.hashFunction.Len())
  138. if bytes.Equal(currKey, emptyKey) {
  139. // empty value
  140. return currKey, emptyValue, siblings, nil
  141. }
  142. currValue, err = t.db.Get(currKey)
  143. if err != nil {
  144. return nil, nil, nil, err
  145. }
  146. switch currValue[0] {
  147. case PrefixValueEmpty: // empty
  148. // TODO WIP WARNING should not be reached, as the 'if' above should avoid
  149. // reaching this point
  150. // return currKey, empty, siblings, nil
  151. panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
  152. case PrefixValueLeaf: // leaf
  153. if bytes.Equal(newKey, currKey) {
  154. return nil, nil, nil, fmt.Errorf("key already exists")
  155. }
  156. if !bytes.Equal(currValue, emptyValue) {
  157. if getLeaf {
  158. return currKey, currValue, siblings, nil
  159. }
  160. oldLeafKey, _ := readLeafValue(currValue)
  161. oldLeafKeyFull := make([]byte, t.hashFunction.Len())
  162. copy(oldLeafKeyFull[:], oldLeafKey)
  163. // if currKey is already used, go down until paths diverge
  164. oldPath := getPath(t.maxLevels, oldLeafKeyFull)
  165. siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
  166. if err != nil {
  167. return nil, nil, nil, err
  168. }
  169. }
  170. return currKey, currValue, siblings, nil
  171. case PrefixValueIntermediate: // intermediate
  172. if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
  173. return nil, nil, nil,
  174. fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
  175. PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
  176. }
  177. // collect siblings while going down
  178. if path[l] {
  179. // right
  180. lChild, rChild := readIntermediateChilds(currValue)
  181. siblings = append(siblings, lChild)
  182. return t.down(newKey, rChild, siblings, path, l+1, getLeaf)
  183. }
  184. // left
  185. lChild, rChild := readIntermediateChilds(currValue)
  186. siblings = append(siblings, rChild)
  187. return t.down(newKey, lChild, siblings, path, l+1, getLeaf)
  188. default:
  189. return nil, nil, nil, fmt.Errorf("invalid value")
  190. }
  191. }
  192. // downVirtually is used when in a leaf already exists, and a new leaf which
  193. // shares the path until the existing leaf is being added
  194. func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
  195. newPath []bool, l int) ([][]byte, error) {
  196. var err error
  197. if l > t.maxLevels-1 {
  198. return nil, fmt.Errorf("max virtual level %d", l)
  199. }
  200. if oldPath[l] == newPath[l] {
  201. emptyKey := make([]byte, t.hashFunction.Len()) // empty
  202. siblings = append(siblings, emptyKey)
  203. siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
  204. if err != nil {
  205. return nil, err
  206. }
  207. return siblings, nil
  208. }
  209. // reached the divergence
  210. siblings = append(siblings, oldKey)
  211. return siblings, nil
  212. }
  213. // up goes up recursively updating the intermediate nodes
  214. func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
  215. var k, v []byte
  216. var err error
  217. if path[l] {
  218. k, v, err = newIntermediate(t.hashFunction, siblings[l], key)
  219. if err != nil {
  220. return nil, err
  221. }
  222. } else {
  223. k, v, err = newIntermediate(t.hashFunction, key, siblings[l])
  224. if err != nil {
  225. return nil, err
  226. }
  227. }
  228. // store k-v to db
  229. err = tx.Put(k, v)
  230. if err != nil {
  231. return nil, err
  232. }
  233. if l == 0 {
  234. // reached the root
  235. return k, nil
  236. }
  237. return t.up(tx, k, siblings, path, l-1)
  238. }
  239. func newLeafValue(hashFunc HashFunction, k, v []byte) ([]byte, []byte, error) {
  240. leafKey, err := hashFunc.Hash(k, v, []byte{1})
  241. if err != nil {
  242. return nil, nil, err
  243. }
  244. var leafValue []byte
  245. leafValue = append(leafValue, byte(1))
  246. leafValue = append(leafValue, byte(len(k)))
  247. leafValue = append(leafValue, k...)
  248. leafValue = append(leafValue, v...)
  249. return leafKey, leafValue, nil
  250. }
  251. func readLeafValue(b []byte) ([]byte, []byte) {
  252. if len(b) < PrefixValueLen {
  253. return []byte{}, []byte{}
  254. }
  255. kLen := b[1]
  256. if len(b) < PrefixValueLen+int(kLen) {
  257. return []byte{}, []byte{}
  258. }
  259. k := b[PrefixValueLen : PrefixValueLen+kLen]
  260. v := b[PrefixValueLen+kLen:]
  261. return k, v
  262. }
  263. func newIntermediate(hashFunc HashFunction, l, r []byte) ([]byte, []byte, error) {
  264. b := make([]byte, PrefixValueLen+hashFunc.Len()*2)
  265. b[0] = 2
  266. b[1] = byte(len(l))
  267. copy(b[PrefixValueLen:PrefixValueLen+hashFunc.Len()], l)
  268. copy(b[PrefixValueLen+hashFunc.Len():], r)
  269. key, err := hashFunc.Hash(l, r)
  270. if err != nil {
  271. return nil, nil, err
  272. }
  273. return key, b, nil
  274. }
  275. func readIntermediateChilds(b []byte) ([]byte, []byte) {
  276. if len(b) < PrefixValueLen {
  277. return []byte{}, []byte{}
  278. }
  279. lLen := b[1]
  280. if len(b) < PrefixValueLen+int(lLen) {
  281. return []byte{}, []byte{}
  282. }
  283. l := b[PrefixValueLen : PrefixValueLen+lLen]
  284. r := b[PrefixValueLen+lLen:]
  285. return l, r
  286. }
  287. func getPath(numLevels int, k []byte) []bool {
  288. path := make([]bool, numLevels)
  289. for n := 0; n < numLevels; n++ {
  290. path[n] = k[n/8]&(1<<(n%8)) != 0
  291. }
  292. return path
  293. }
  294. // GenProof generates a MerkleTree proof for the given key. If the key exists in
  295. // the Tree, the proof will be of existence, if the key does not exist in the
  296. // tree, the proof will be of non-existence.
  297. func (t *Tree) GenProof(k []byte) ([]byte, error) {
  298. keyPath := make([]byte, t.hashFunction.Len())
  299. copy(keyPath[:], k)
  300. path := getPath(t.maxLevels, keyPath)
  301. // go down to the leaf
  302. var siblings [][]byte
  303. _, value, siblings, err := t.down(k, t.root, siblings, path, 0, true)
  304. if err != nil {
  305. return nil, err
  306. }
  307. leafK, leafV := readLeafValue(value)
  308. if !bytes.Equal(k, leafK) {
  309. fmt.Println("key not in Tree")
  310. fmt.Println(leafK)
  311. fmt.Println(leafV)
  312. // TODO proof of non-existence
  313. panic(fmt.Errorf("unimplemented"))
  314. }
  315. s := PackSiblings(t.hashFunction, siblings)
  316. return s, nil
  317. }
  318. // PackSiblings packs the siblings into a byte array.
  319. // [ 1 byte | L bytes | 32 * N bytes ]
  320. // [ bitmap length (L) | bitmap | N non-zero siblings ]
  321. // Where the bitmap indicates if the sibling is 0 or a value from the siblings array.
  322. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
  323. var b []byte
  324. var bitmap []bool
  325. emptySibling := make([]byte, hashFunc.Len())
  326. for i := 0; i < len(siblings); i++ {
  327. if bytes.Equal(siblings[i], emptySibling) {
  328. bitmap = append(bitmap, false)
  329. } else {
  330. bitmap = append(bitmap, true)
  331. b = append(b, siblings[i]...)
  332. }
  333. }
  334. bitmapBytes := bitmapToBytes(bitmap)
  335. l := len(bitmapBytes)
  336. res := make([]byte, l+1+len(b))
  337. res[0] = byte(l) // set the bitmapBytes length
  338. copy(res[1:1+l], bitmapBytes)
  339. copy(res[1+l:], b)
  340. return res
  341. }
  342. // UnpackSiblings unpacks the siblings from a byte array.
  343. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
  344. l := b[0]
  345. bitmapBytes := b[1 : 1+l]
  346. bitmap := bytesToBitmap(bitmapBytes)
  347. siblingsBytes := b[1+l:]
  348. iSibl := 0
  349. emptySibl := make([]byte, hashFunc.Len())
  350. var siblings [][]byte
  351. for i := 0; i < len(bitmap); i++ {
  352. if iSibl >= len(siblingsBytes) {
  353. break
  354. }
  355. if bitmap[i] {
  356. siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()])
  357. iSibl += hashFunc.Len()
  358. } else {
  359. siblings = append(siblings, emptySibl)
  360. }
  361. }
  362. return siblings, nil
  363. }
  364. func bitmapToBytes(bitmap []bool) []byte {
  365. bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd
  366. b := make([]byte, bitmapBytesLen)
  367. for i := 0; i < len(bitmap); i++ {
  368. if bitmap[i] {
  369. b[i/8] |= 1 << (i % 8)
  370. }
  371. }
  372. return b
  373. }
  374. func bytesToBitmap(b []byte) []bool {
  375. var bitmap []bool
  376. for i := 0; i < len(b); i++ {
  377. for j := 0; j < 8; j++ {
  378. bitmap = append(bitmap, b[i]&(1<<j) > 0)
  379. }
  380. }
  381. return bitmap
  382. }
  383. // Get returns the value for a given key
  384. func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
  385. keyPath := make([]byte, t.hashFunction.Len())
  386. copy(keyPath[:], k)
  387. path := getPath(t.maxLevels, keyPath)
  388. // go down to the leaf
  389. var siblings [][]byte
  390. _, value, _, err := t.down(k, t.root, siblings, path, 0, true)
  391. if err != nil {
  392. return nil, nil, err
  393. }
  394. leafK, leafV := readLeafValue(value)
  395. if !bytes.Equal(k, leafK) {
  396. panic(fmt.Errorf("%s != %s", BytesToBigInt(k), BytesToBigInt(leafK)))
  397. }
  398. return leafK, leafV, nil
  399. }
  400. // CheckProof verifies the given proof. The proof verification depends on the
  401. // HashFunction passed as parameter.
  402. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
  403. siblings, err := UnpackSiblings(hashFunc, packedSiblings)
  404. if err != nil {
  405. return false, err
  406. }
  407. keyPath := make([]byte, hashFunc.Len())
  408. copy(keyPath[:], k)
  409. key, _, err := newLeafValue(hashFunc, k, v)
  410. if err != nil {
  411. return false, err
  412. }
  413. path := getPath(len(siblings), keyPath)
  414. for i := len(siblings) - 1; i >= 0; i-- {
  415. if path[i] {
  416. key, _, err = newIntermediate(hashFunc, siblings[i], key)
  417. if err != nil {
  418. return false, err
  419. }
  420. } else {
  421. key, _, err = newIntermediate(hashFunc, key, siblings[i])
  422. if err != nil {
  423. return false, err
  424. }
  425. }
  426. }
  427. if bytes.Equal(key[:], root) {
  428. return true, nil
  429. }
  430. return false, nil
  431. }