You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1095 lines
30 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "sync"
  10. "github.com/iden3/go-iden3-core/common"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. "github.com/iden3/go-merkletree/db"
  13. )
  14. const (
  15. // proofFlagsLen is the number of leading bytes of a serialized proof that
  16. // hold the flags (byte 0) and the depth (byte 1).
  17. proofFlagsLen = 2
  18. // ElemBytesLen is the byte length of a Hash and of each element of a serialized proof.
  19. ElemBytesLen = 32
  20. )
  21. var (
  22. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  23. ErrNodeKeyAlreadyExists = errors.New("key already exists")
  24. // ErrKeyNotFound is used when a key is not found in the MerkleTree.
  25. ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
  26. // ErrNodeBytesBadSize is used when the data of a node has an incorrect
  27. // size and can't be parsed.
  28. ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
  29. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  30. // maximum level.
  31. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  32. // ErrInvalidNodeFound is used when an invalid node is found and can't
  33. // be parsed.
  34. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  35. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  36. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  37. // ErrInvalidDBValue is used when a value in the key value DB is
  38. // invalid (for example, it doesn't contain a byte header and a []byte
  39. // body of at least len=1).
  40. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  41. // ErrEntryIndexAlreadyExists is used when the entry index already
  42. // exists in the tree.
  43. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  44. // ErrNotWritable is used when the MerkleTree is not writable and a write function is called.
  45. ErrNotWritable = errors.New("Merkle Tree not writable")
  46. rootNodeValue = []byte("currentroot") // DB key under which the current root is stored
  47. // HashZero is the all-zero Hash; it is the key of every Empty node.
  48. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  49. )
  50. // Hash is the 32-byte element type used for node keys, leaf keys and leaf values in the MerkleTree.
  51. type Hash [32]byte
  52. // String returns decimal representation in string format of the Hash
  53. func (h Hash) String() string {
  54. s := h.BigInt().String()
  55. if len(s) < 8 {
  56. return s
  57. }
  58. return s[0:8] + "..."
  59. }
  60. // Hex returns the hexadecimal representation of the Hash
  61. func (h Hash) Hex() string {
  62. return hex.EncodeToString(h.BigInt().Bytes())
  63. }
  64. // BigInt returns the *big.Int representation of the *Hash
  65. func (h *Hash) BigInt() *big.Int {
  66. if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
  67. return big.NewInt(0)
  68. }
  69. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  70. }
  71. // Bytes returns the []byte representation of the *Hash
  72. func (h *Hash) Bytes() []byte {
  73. bi := new(big.Int).SetBytes(common.SwapEndianness(h[:])).Bytes()
  74. b := [32]byte{}
  75. copy(b[:], bi[:])
  76. return b[:]
  77. }
  78. // NewBigIntFromBytes returns a *big.Int from a byte array, swapping the endianness in the process. This is the intended method to get a *big.Int from a byte array that previously has ben generated by the Hash.Bytes() method.
  79. func NewBigIntFromBytes(b []byte) (*big.Int, error) {
  80. if len(b) != 32 {
  81. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  82. }
  83. return new(big.Int).SetBytes(common.SwapEndianness(b[:32])), nil
  84. }
  85. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  86. func NewHashFromBigInt(b *big.Int) *Hash {
  87. r := &Hash{}
  88. copy(r[:], common.SwapEndianness(b.Bytes()))
  89. return r
  90. }
  91. // MerkleTree holds the state needed to operate on a Merkle tree stored in a key-value storage.
  92. type MerkleTree struct {
  93. sync.RWMutex // write-locked by Add, Update and Delete; read-locked by Snapshot
  94. db db.Storage // storage holding the nodes and the current root
  95. rootKey *Hash // key of the current root node
  96. writable bool // false for trees returned by Snapshot; write methods then return ErrNotWritable
  97. maxLevels int // maximum number of levels traversed by tree operations
  98. }
  99. // NewMerkleTree loads a new Merkletree. If in the sotrage already exists one will open that one, if not, will create a new one.
  100. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  101. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  102. v, err := mt.db.Get(rootNodeValue)
  103. if err != nil {
  104. tx, err := mt.db.NewTx()
  105. if err != nil {
  106. return nil, err
  107. }
  108. mt.rootKey = &HashZero
  109. tx.Put(rootNodeValue, mt.rootKey[:])
  110. err = tx.Commit()
  111. if err != nil {
  112. return nil, err
  113. }
  114. return &mt, nil
  115. }
  116. mt.rootKey = &Hash{}
  117. copy(mt.rootKey[:], v)
  118. return &mt, nil
  119. }
  120. // DB returns the storage that backs the MerkleTree.
  121. func (mt *MerkleTree) DB() db.Storage {
  122. return mt.db
  123. }
  124. // Root returns the key of the current root node. The pointer is the one held by the tree, not a copy.
  125. func (mt *MerkleTree) Root() *Hash {
  126. return mt.rootKey
  127. }
  128. // MaxLevels returns the maximum number of levels the MerkleTree was created with.
  129. func (mt *MerkleTree) MaxLevels() int {
  130. return mt.maxLevels
  131. }
  132. // Snapshot returns a read-only copy of the MerkleTree
  133. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  134. mt.RLock()
  135. defer mt.RUnlock()
  136. _, err := mt.GetNode(rootKey)
  137. if err != nil {
  138. return nil, err
  139. }
  140. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  141. }
  142. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the path from the Root to the Leaf.
  143. func (mt *MerkleTree) Add(k, v *big.Int) error {
  144. // verify that the MerkleTree is writable
  145. if !mt.writable {
  146. return ErrNotWritable
  147. }
  148. // verfy that k & v are valid and fit inside the Finite Field.
  149. if !cryptoUtils.CheckBigIntInField(k) {
  150. return errors.New("Key not inside the Finite Field")
  151. }
  152. if !cryptoUtils.CheckBigIntInField(v) {
  153. return errors.New("Value not inside the Finite Field")
  154. }
  155. tx, err := mt.db.NewTx()
  156. if err != nil {
  157. return err
  158. }
  159. mt.Lock()
  160. defer mt.Unlock()
  161. kHash := NewHashFromBigInt(k)
  162. vHash := NewHashFromBigInt(v)
  163. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  164. path := getPath(mt.maxLevels, kHash[:])
  165. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  166. if err != nil {
  167. return err
  168. }
  169. mt.rootKey = newRootKey
  170. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  171. if err := tx.Commit(); err != nil {
  172. return err
  173. }
  174. return nil
  175. }
  176. // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
  177. func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
  178. var cp CircomProcessorProof
  179. cp.Fnc = 2
  180. cp.OldRoot = mt.rootKey
  181. gettedK, gettedV, siblings, err := mt.Get(k)
  182. if err != nil && err != ErrKeyNotFound {
  183. return nil, err
  184. }
  185. cp.OldKey = NewHashFromBigInt(gettedK)
  186. cp.OldValue = NewHashFromBigInt(gettedV)
  187. if bytes.Equal(cp.OldKey[:], HashZero[:]) {
  188. cp.IsOld0 = true
  189. }
  190. _, _, siblings, err = mt.Get(k)
  191. if err != nil && err != ErrKeyNotFound {
  192. return nil, err
  193. }
  194. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  195. err = mt.Add(k, v)
  196. if err != nil {
  197. return nil, err
  198. }
  199. cp.NewKey = NewHashFromBigInt(k)
  200. cp.NewValue = NewHashFromBigInt(v)
  201. cp.NewRoot = mt.rootKey
  202. return &cp, nil
  203. }
  204. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  205. // from newLeaf, at which point both leafs are stored, all while updating the
  206. // path.
  207. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  208. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  209. if lvl > mt.maxLevels-2 {
  210. return nil, ErrReachedMaxLevel
  211. }
  212. var newNodeMiddle *Node
  213. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  214. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  215. if err != nil {
  216. return nil, err
  217. }
  218. if pathNewLeaf[lvl] {
  219. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  220. } else {
  221. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  222. }
  223. return mt.addNode(tx, newNodeMiddle)
  224. } else {
  225. oldLeafKey, err := oldLeaf.Key()
  226. if err != nil {
  227. return nil, err
  228. }
  229. newLeafKey, err := newLeaf.Key()
  230. if err != nil {
  231. return nil, err
  232. }
  233. if pathNewLeaf[lvl] {
  234. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  235. } else {
  236. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  237. }
  238. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  239. _, err = mt.addNode(tx, newLeaf)
  240. if err != nil {
  241. return nil, err
  242. }
  243. return mt.addNode(tx, newNodeMiddle)
  244. }
  245. }
  246. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  247. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  248. lvl int, path []bool) (*Hash, error) {
  249. var err error
  250. var nextKey *Hash
  251. if lvl > mt.maxLevels-1 {
  252. return nil, ErrReachedMaxLevel
  253. }
  254. n, err := mt.GetNode(key)
  255. if err != nil {
  256. return nil, err
  257. }
  258. switch n.Type {
  259. case NodeTypeEmpty:
  260. // We can add newLeaf now
  261. return mt.addNode(tx, newLeaf)
  262. case NodeTypeLeaf:
  263. nKey := n.Entry[0]
  264. // Check if leaf node found contains the leaf node we are trying to add
  265. newLeafKey := newLeaf.Entry[0]
  266. if bytes.Equal(nKey[:], newLeafKey[:]) {
  267. return nil, ErrEntryIndexAlreadyExists
  268. }
  269. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  270. // We need to push newLeaf down until its path diverges from n's path
  271. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  272. case NodeTypeMiddle:
  273. // We need to go deeper, continue traversing the tree, left or
  274. // right depending on path
  275. var newNodeMiddle *Node
  276. if path[lvl] {
  277. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  278. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  279. } else {
  280. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  281. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  282. }
  283. if err != nil {
  284. return nil, err
  285. }
  286. // Update the node to reflect the modified child
  287. return mt.addNode(tx, newNodeMiddle)
  288. default:
  289. return nil, ErrInvalidNodeFound
  290. }
  291. }
  292. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  293. // they are all the same and assumed to always exist.
  294. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  295. // verify that the MerkleTree is writable
  296. if !mt.writable {
  297. return nil, ErrNotWritable
  298. }
  299. if n.Type == NodeTypeEmpty {
  300. return n.Key()
  301. }
  302. k, err := n.Key()
  303. if err != nil {
  304. return nil, err
  305. }
  306. v := n.Value()
  307. // Check that the node key doesn't already exist
  308. if _, err := tx.Get(k[:]); err == nil {
  309. return nil, ErrNodeKeyAlreadyExists
  310. }
  311. tx.Put(k[:], v)
  312. return k, nil
  313. }
  314. // Get returns the value of the leaf for the given key
  315. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  316. // verfy that k is valid and fit inside the Finite Field.
  317. if !cryptoUtils.CheckBigIntInField(k) {
  318. return nil, nil, nil, errors.New("Key not inside the Finite Field")
  319. }
  320. kHash := NewHashFromBigInt(k)
  321. path := getPath(mt.maxLevels, kHash[:])
  322. nextKey := mt.rootKey
  323. var siblings []*Hash
  324. for i := 0; i < mt.maxLevels; i++ {
  325. n, err := mt.GetNode(nextKey)
  326. if err != nil {
  327. return nil, nil, nil, err
  328. }
  329. switch n.Type {
  330. case NodeTypeEmpty:
  331. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  332. case NodeTypeLeaf:
  333. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  334. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  335. } else {
  336. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  337. }
  338. case NodeTypeMiddle:
  339. if path[i] {
  340. nextKey = n.ChildR
  341. siblings = append(siblings, n.ChildL)
  342. } else {
  343. nextKey = n.ChildL
  344. siblings = append(siblings, n.ChildR)
  345. }
  346. default:
  347. return nil, nil, nil, ErrInvalidNodeFound
  348. }
  349. }
  350. return nil, nil, nil, ErrKeyNotFound
  351. }
  352. // Update updates the value of a specified key in the MerkleTree, and updates
  353. // the path from the leaf to the Root with the new values. Returns the
  354. // CircomProcessorProof.
  355. func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
  356. // verify that the MerkleTree is writable
  357. if !mt.writable {
  358. return nil, ErrNotWritable
  359. }
  360. // verfy that k & are valid and fit inside the Finite Field.
  361. if !cryptoUtils.CheckBigIntInField(k) {
  362. return nil, errors.New("Key not inside the Finite Field")
  363. }
  364. if !cryptoUtils.CheckBigIntInField(v) {
  365. return nil, errors.New("Key not inside the Finite Field")
  366. }
  367. tx, err := mt.db.NewTx()
  368. if err != nil {
  369. return nil, err
  370. }
  371. mt.Lock()
  372. defer mt.Unlock()
  373. kHash := NewHashFromBigInt(k)
  374. vHash := NewHashFromBigInt(v)
  375. path := getPath(mt.maxLevels, kHash[:])
  376. var cp CircomProcessorProof
  377. cp.Fnc = 1
  378. cp.OldRoot = mt.rootKey
  379. cp.OldKey = kHash
  380. cp.NewKey = kHash
  381. cp.NewValue = vHash
  382. nextKey := mt.rootKey
  383. var siblings []*Hash
  384. for i := 0; i < mt.maxLevels; i++ {
  385. n, err := mt.GetNode(nextKey)
  386. if err != nil {
  387. return nil, err
  388. }
  389. switch n.Type {
  390. case NodeTypeEmpty:
  391. return nil, ErrKeyNotFound
  392. case NodeTypeLeaf:
  393. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  394. cp.OldValue = n.Entry[1]
  395. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  396. // update leaf and upload to the root
  397. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  398. _, err := mt.addNode(tx, newNodeLeaf)
  399. if err != nil {
  400. return nil, err
  401. }
  402. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
  403. if err != nil {
  404. return nil, err
  405. }
  406. mt.rootKey = newRootKey
  407. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  408. cp.NewRoot = newRootKey
  409. if err := tx.Commit(); err != nil {
  410. return nil, err
  411. }
  412. return &cp, nil
  413. } else {
  414. return nil, ErrKeyNotFound
  415. }
  416. case NodeTypeMiddle:
  417. if path[i] {
  418. nextKey = n.ChildR
  419. siblings = append(siblings, n.ChildL)
  420. } else {
  421. nextKey = n.ChildL
  422. siblings = append(siblings, n.ChildR)
  423. }
  424. default:
  425. return nil, ErrInvalidNodeFound
  426. }
  427. }
  428. return nil, ErrKeyNotFound
  429. }
  430. // Delete removes the specified Key from the MerkleTree and updates the path
  431. // from the deleted key to the Root with the new values. This method removes
  432. // the key from the MerkleTree, but does not remove the old nodes from the
  433. // key-value database; this means that if the tree is accessed by an old Root
  434. // where the key was not deleted yet, the key will still exist. If is desired
  435. // to remove the key-values from the database that are not under the current
  436. // Root, an option could be to dump all the leafs (using mt.DumpLeafs) and
  437. // import them in a new MerkleTree in a new database (using
  438. // mt.ImportDumpedLeafs), but this will loose all the Root history of the
  439. // MerkleTree
  440. func (mt *MerkleTree) Delete(k *big.Int) error {
  441. // verify that the MerkleTree is writable
  442. if !mt.writable {
  443. return ErrNotWritable
  444. }
  445. // verfy that k is valid and fit inside the Finite Field.
  446. if !cryptoUtils.CheckBigIntInField(k) {
  447. return errors.New("Key not inside the Finite Field")
  448. }
  449. tx, err := mt.db.NewTx()
  450. if err != nil {
  451. return err
  452. }
  453. mt.Lock()
  454. defer mt.Unlock()
  455. kHash := NewHashFromBigInt(k)
  456. path := getPath(mt.maxLevels, kHash[:])
  457. nextKey := mt.rootKey
  458. var siblings []*Hash
  459. for i := 0; i < mt.maxLevels; i++ {
  460. n, err := mt.GetNode(nextKey)
  461. if err != nil {
  462. return err
  463. }
  464. switch n.Type {
  465. case NodeTypeEmpty:
  466. return ErrKeyNotFound
  467. case NodeTypeLeaf:
  468. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  469. // remove and go up with the sibling
  470. err = mt.rmAndUpload(tx, path, kHash, siblings)
  471. return err
  472. } else {
  473. return ErrKeyNotFound
  474. }
  475. case NodeTypeMiddle:
  476. if path[i] {
  477. nextKey = n.ChildR
  478. siblings = append(siblings, n.ChildL)
  479. } else {
  480. nextKey = n.ChildL
  481. siblings = append(siblings, n.ChildR)
  482. }
  483. default:
  484. return ErrInvalidNodeFound
  485. }
  486. }
  487. return ErrKeyNotFound
  488. }
  489. // rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values.
  490. func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
  491. if len(siblings) == 0 {
  492. mt.rootKey = &HashZero
  493. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  494. return tx.Commit()
  495. }
  496. toUpload := siblings[len(siblings)-1]
  497. if len(siblings) < 2 {
  498. mt.rootKey = siblings[0]
  499. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  500. return tx.Commit()
  501. }
  502. for i := len(siblings) - 2; i >= 0; i-- {
  503. if !bytes.Equal(siblings[i][:], HashZero[:]) {
  504. var newNode *Node
  505. if path[i] {
  506. newNode = NewNodeMiddle(siblings[i], toUpload)
  507. } else {
  508. newNode = NewNodeMiddle(toUpload, siblings[i])
  509. }
  510. _, err := mt.addNode(tx, newNode)
  511. if err != ErrNodeKeyAlreadyExists && err != nil {
  512. return err
  513. }
  514. // go up until the root
  515. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i])
  516. if err != nil {
  517. return err
  518. }
  519. mt.rootKey = newRootKey
  520. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  521. break
  522. }
  523. // if i==0 (root position), stop and store the sibling of the deleted leaf as root
  524. if i == 0 {
  525. mt.rootKey = toUpload
  526. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  527. break
  528. }
  529. }
  530. if err := tx.Commit(); err != nil {
  531. return err
  532. }
  533. return nil
  534. }
  535. // recalculatePathUntilRoot recalculates the nodes until the Root
  536. func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node, siblings []*Hash) (*Hash, error) {
  537. for i := len(siblings) - 1; i >= 0; i-- {
  538. nodeKey, err := node.Key()
  539. if err != nil {
  540. return nil, err
  541. }
  542. if path[i] {
  543. node = NewNodeMiddle(siblings[i], nodeKey)
  544. } else {
  545. node = NewNodeMiddle(nodeKey, siblings[i])
  546. }
  547. _, err = mt.addNode(tx, node)
  548. if err != ErrNodeKeyAlreadyExists && err != nil {
  549. return nil, err
  550. }
  551. }
  552. // return last node added, which is the root
  553. nodeKey, err := node.Key()
  554. return nodeKey, err
  555. }
  556. // dbGet is a helper function to get the node of a key from the internal
  557. // storage.
  558. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  559. if bytes.Equal(k, HashZero[:]) {
  560. return 0, nil, nil
  561. }
  562. value, err := mt.db.Get(k)
  563. if err != nil {
  564. return 0, nil, err
  565. }
  566. if len(value) < 2 {
  567. return 0, nil, ErrInvalidDBValue
  568. }
  569. nodeType := value[0]
  570. nodeBytes := value[1:]
  571. return NodeType(nodeType), nodeBytes, nil
  572. }
  573. // dbInsert is a helper function to insert a node into a key in an open db
  574. // transaction.
  575. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  576. v := append([]byte{byte(t)}, data...)
  577. tx.Put(k, v)
  578. }
  579. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  580. // tree; they are all the same and assumed to always exist.
  581. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  582. if bytes.Equal(key[:], HashZero[:]) {
  583. return NewNodeEmpty(), nil
  584. }
  585. nBytes, err := mt.db.Get(key[:])
  586. if err != nil {
  587. return nil, err
  588. }
  589. return NewNodeFromBytes(nBytes)
  590. }
  591. // getPath returns the binary path, from the root to the leaf.
  592. func getPath(numLevels int, k []byte) []bool {
  593. path := make([]bool, numLevels)
  594. for n := 0; n < numLevels; n++ {
  595. path[n] = common.TestBit(k[:], uint(n))
  596. }
  597. return path
  598. }
  599. // NodeAux contains the auxiliary node used in a non-existence proof: the leaf found on the path of the queried key.
  600. type NodeAux struct {
  601. Key *Hash // entry key of that leaf
  602. Value *Hash // entry value of that leaf
  603. }
  604. // Proof defines the required elements for a MT proof of existence or non-existence.
  605. type Proof struct {
  606. // Existence indicates whether this is a proof of existence or non-existence.
  607. Existence bool
  608. // depth indicates how deep in the tree the proof goes.
  609. depth uint
  610. // notempties is a bitmap, one bit per level, marking the levels whose sibling is non-empty.
  611. notempties [ElemBytesLen - proofFlagsLen]byte
  612. // Siblings is a list of non-empty sibling keys.
  613. Siblings []*Hash
  614. NodeAux *NodeAux // set for non-existence proofs that end in a leaf with a different key
  615. }
  616. // NewProofFromBytes parses a byte array into a Proof.
  617. func NewProofFromBytes(bs []byte) (*Proof, error) {
  618. if len(bs) < ElemBytesLen {
  619. return nil, ErrInvalidProofBytes
  620. }
  621. p := &Proof{}
  622. if (bs[0] & 0x01) == 0 {
  623. p.Existence = true
  624. }
  625. p.depth = uint(bs[1])
  626. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  627. siblingBytes := bs[ElemBytesLen:]
  628. sibIdx := 0
  629. for i := uint(0); i < p.depth; i++ {
  630. if common.TestBitBigEndian(p.notempties[:], i) {
  631. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  632. return nil, ErrInvalidProofBytes
  633. }
  634. var sib Hash
  635. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  636. p.Siblings = append(p.Siblings, &sib)
  637. sibIdx++
  638. }
  639. }
  640. if !p.Existence && ((bs[0] & 0x02) != 0) {
  641. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  642. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  643. if len(nodeAuxBytes) != 2*ElemBytesLen {
  644. return nil, ErrInvalidProofBytes
  645. }
  646. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  647. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  648. }
  649. return p, nil
  650. }
  651. // Bytes serializes a Proof into a byte array.
  652. func (p *Proof) Bytes() []byte {
  653. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  654. if p.NodeAux != nil {
  655. bsLen += 2 * ElemBytesLen
  656. }
  657. bs := make([]byte, bsLen)
  658. if !p.Existence {
  659. bs[0] |= 0x01
  660. }
  661. bs[1] = byte(p.depth)
  662. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  663. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  664. for i, k := range p.Siblings {
  665. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  666. }
  667. if p.NodeAux != nil {
  668. bs[0] |= 0x02
  669. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  670. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  671. }
  672. return bs
  673. }
  674. // SiblingsFromProof returns all the siblings of the proof.
  675. func SiblingsFromProof(proof *Proof) []*Hash {
  676. sibIdx := 0
  677. var siblings []*Hash
  678. for lvl := 0; lvl < int(proof.depth); lvl++ {
  679. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  680. siblings = append(siblings, proof.Siblings[sibIdx])
  681. sibIdx++
  682. } else {
  683. siblings = append(siblings, &HashZero)
  684. }
  685. }
  686. return siblings
  687. }
  688. // AllSiblings returns all the siblings of the proof, one per level, as computed by SiblingsFromProof.
  689. func (p *Proof) AllSiblings() []*Hash {
  690. return SiblingsFromProof(p)
  691. }
  692. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  693. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  694. // Add the rest of empty levels to the siblings
  695. for i := len(siblings); i < levels; i++ {
  696. siblings = append(siblings, &HashZero)
  697. }
  698. siblings = append(siblings, &HashZero) // add extra level for circom compatibility
  699. return siblings
  700. }
  701. // AllSiblingsCircom returns all the siblings of the proof. This function is
  702. // used to generate the siblings input for the circom circuits.
  703. func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
  704. siblings := p.AllSiblings()
  705. // Add the rest of empty levels to the siblings
  706. for i := len(siblings); i < levels; i++ {
  707. siblings = append(siblings, &HashZero)
  708. }
  709. siblings = append(siblings, &HashZero) // add extra level for circom compatibility
  710. siblingsBigInt := make([]*big.Int, len(siblings))
  711. for i, sibling := range siblings {
  712. siblingsBigInt[i] = sibling.BigInt()
  713. }
  714. return siblingsBigInt
  715. }
  716. // CircomProcessorProof defines the ProcessorProof compatible with circom. It is
  717. // the data of the proof between the transition from one state to another.
  718. type CircomProcessorProof struct {
  719. OldRoot *Hash // root before the operation
  720. NewRoot *Hash // root after the operation
  721. Siblings []*Hash // siblings on the path of the key, padded by CircomSiblingsFromSiblings
  722. OldKey *Hash // key of the leaf found on the path before the operation
  723. OldValue *Hash // value of the leaf found on the path before the operation
  724. NewKey *Hash // key written by the operation
  725. NewValue *Hash // value written by the operation
  726. IsOld0 bool // true when OldKey is all zeros
  727. Fnc int // 0: NOP, 1: Update, 2: Insert, 3: Delete
  728. }
  729. // String returns a human readable string representation of the
  730. // CircomProcessorProof
  731. func (p CircomProcessorProof) String() string {
  732. buf := bytes.NewBufferString("{")
  733. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  734. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  735. fmt.Fprintf(buf, " Siblings: [\n ")
  736. for _, s := range p.Siblings {
  737. fmt.Fprintf(buf, "%v, ", s)
  738. }
  739. fmt.Fprintf(buf, "\n ],\n")
  740. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  741. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  742. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  743. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  744. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  745. fmt.Fprintf(buf, "}\n")
  746. return buf.String()
  747. }
  748. // CircomVerifierProof defines the VerifierProof compatible with circom. It is the
  749. // data of the proof that a certain leaf exists (or does not exist) in the MerkleTree.
  750. type CircomVerifierProof struct {
  751. Root *Hash // root the proof refers to
  752. Siblings []*big.Int // siblings as produced by Proof.AllSiblingsCircom
  753. OldKey *Hash
  754. OldValue *Hash
  755. IsOld0 bool
  756. Key *Hash // key the proof was generated for
  757. Value *Hash // value returned by GenerateProof for that key
  758. Fnc int // 0: inclusion, 1: non inclusion
  759. }
  760. // GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
  761. // key in the MerkleTree. If the rootKey is nil, the current merkletree root
  762. // is used.
  763. func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int, rootKey *Hash) (*CircomVerifierProof, error) {
  764. p, v, err := mt.GenerateProof(k, rootKey)
  765. if err != nil || err != ErrKeyNotFound {
  766. return nil, err
  767. }
  768. var cp CircomVerifierProof
  769. cp.Root = mt.rootKey
  770. cp.Siblings = p.AllSiblingsCircom(mt.maxLevels)
  771. cp.OldKey = &HashZero
  772. cp.OldValue = &HashZero
  773. cp.Key = NewHashFromBigInt(k)
  774. cp.Value = NewHashFromBigInt(v)
  775. if p.Existence {
  776. cp.Fnc = 0 // inclusion
  777. } else {
  778. cp.Fnc = 1 // non inclusion
  779. }
  780. return &cp, nil
  781. }
  782. // GenerateProof generates the proof of existence (or non-existence) of an
  783. // Entry's hash Index for a Merkle Tree given the root.
  784. // If the rootKey is nil, the current merkletree root is used
  785. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, *big.Int, error) {
  786. p := &Proof{}
  787. var siblingKey *Hash
  788. kHash := NewHashFromBigInt(k)
  789. path := getPath(mt.maxLevels, kHash[:])
  790. if rootKey == nil {
  791. rootKey = mt.Root()
  792. }
  793. nextKey := rootKey
  794. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  795. n, err := mt.GetNode(nextKey)
  796. if err != nil {
  797. return nil, nil, err
  798. }
  799. switch n.Type {
  800. case NodeTypeEmpty:
  801. return p, big.NewInt(0), nil
  802. case NodeTypeLeaf:
  803. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  804. p.Existence = true
  805. return p, n.Entry[1].BigInt(), nil
  806. } else {
  807. // We found a leaf whose entry didn't match hIndex
  808. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  809. return p, n.Entry[1].BigInt(), nil
  810. }
  811. case NodeTypeMiddle:
  812. if path[p.depth] {
  813. nextKey = n.ChildR
  814. siblingKey = n.ChildL
  815. } else {
  816. nextKey = n.ChildL
  817. siblingKey = n.ChildR
  818. }
  819. default:
  820. return nil, nil, ErrInvalidNodeFound
  821. }
  822. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  823. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  824. p.Siblings = append(p.Siblings, siblingKey)
  825. }
  826. }
  827. return nil, nil, ErrKeyNotFound
  828. }
  829. // VerifyProof verifies the Merkle Proof for the entry and root.
  830. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  831. rootFromProof, err := RootFromProof(proof, k, v)
  832. if err != nil {
  833. return false
  834. }
  835. return bytes.Equal(rootKey[:], rootFromProof[:])
  836. }
  837. // RootFromProof calculates the root that would correspond to a tree whose
  838. // siblings are the ones in the proof with the leaf hashing to hIndex and
  839. // hValue.
  840. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  841. kHash := NewHashFromBigInt(k)
  842. vHash := NewHashFromBigInt(v)
  843. sibIdx := len(proof.Siblings) - 1
  844. var err error
  845. var midKey *Hash
  846. if proof.Existence {
  847. midKey, err = LeafKey(kHash, vHash)
  848. if err != nil {
  849. return nil, err
  850. }
  851. } else {
  852. if proof.NodeAux == nil {
  853. midKey = &HashZero
  854. } else {
  855. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  856. return nil, fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  857. }
  858. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  859. if err != nil {
  860. return nil, err
  861. }
  862. }
  863. }
  864. path := getPath(int(proof.depth), kHash[:])
  865. var siblingKey *Hash
  866. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  867. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  868. siblingKey = proof.Siblings[sibIdx]
  869. sibIdx--
  870. } else {
  871. siblingKey = &HashZero
  872. }
  873. if path[lvl] {
  874. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  875. if err != nil {
  876. return nil, err
  877. }
  878. } else {
  879. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  880. if err != nil {
  881. return nil, err
  882. }
  883. }
  884. }
  885. return midKey, nil
  886. }
  887. // walk is a helper recursive function to iterate over all tree branches
  888. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  889. n, err := mt.GetNode(key)
  890. if err != nil {
  891. return err
  892. }
  893. switch n.Type {
  894. case NodeTypeEmpty:
  895. f(n)
  896. case NodeTypeLeaf:
  897. f(n)
  898. case NodeTypeMiddle:
  899. f(n)
  900. if err := mt.walk(n.ChildL, f); err != nil {
  901. return err
  902. }
  903. if err := mt.walk(n.ChildR, f); err != nil {
  904. return err
  905. }
  906. default:
  907. return ErrInvalidNodeFound
  908. }
  909. return nil
  910. }
  911. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  912. // if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree.
  913. // For each node, it calls the f function given in the parameters.
  914. // See some examples of the Walk function usage in the merkletree.go and
  915. // merkletree_test.go
  916. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  917. if rootKey == nil {
  918. rootKey = mt.Root()
  919. }
  920. err := mt.walk(rootKey, f)
  921. return err
  922. }
  923. // GraphViz uses Walk function to generate a string GraphViz representation of the
  924. // tree and writes it to w
  925. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  926. fmt.Fprintf(w, `digraph hierarchy {
  927. node [fontname=Monospace,fontsize=10,shape=box]
  928. `)
  929. cnt := 0
  930. var errIn error
  931. err := mt.Walk(rootKey, func(n *Node) {
  932. k, err := n.Key()
  933. if err != nil {
  934. errIn = err
  935. }
  936. switch n.Type {
  937. case NodeTypeEmpty:
  938. case NodeTypeLeaf:
  939. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  940. case NodeTypeMiddle:
  941. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  942. emptyNodes := ""
  943. for i := range lr {
  944. if lr[i] == "0" {
  945. lr[i] = fmt.Sprintf("empty%v", cnt)
  946. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  947. cnt++
  948. }
  949. }
  950. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  951. fmt.Fprint(w, emptyNodes)
  952. default:
  953. }
  954. })
  955. fmt.Fprintf(w, "}\n")
  956. if errIn != nil {
  957. return errIn
  958. }
  959. return err
  960. }
  961. // PrintGraphViz prints directly the GraphViz() output
  962. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  963. if rootKey == nil {
  964. rootKey = mt.Root()
  965. }
  966. w := bytes.NewBufferString("")
  967. fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  968. err := mt.GraphViz(w, nil)
  969. if err != nil {
  970. return err
  971. }
  972. fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  973. fmt.Println(w)
  974. return nil
  975. }
  976. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  977. // is given (nil), it uses the current Root of the MerkleTree.
  978. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  979. var b []byte
  980. err := mt.Walk(rootKey, func(n *Node) {
  981. if n.Type == NodeTypeLeaf {
  982. l := n.Entry[0].Bytes()
  983. r := n.Entry[1].Bytes()
  984. b = append(b, append(l[:], r[:]...)...)
  985. }
  986. })
  987. return b, err
  988. }
  989. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  990. // from the DumpLeafs function.
  991. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  992. for i := 0; i < len(b); i += 64 {
  993. lr := b[i : i+64]
  994. lB, err := NewBigIntFromBytes(lr[:32])
  995. if err != nil {
  996. return err
  997. }
  998. rB, err := NewBigIntFromBytes(lr[32:])
  999. if err != nil {
  1000. return err
  1001. }
  1002. err = mt.Add(lB, rB)
  1003. if err != nil {
  1004. return err
  1005. }
  1006. }
  1007. return nil
  1008. }