You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1134 lines
31 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "sync"
  10. "github.com/iden3/go-iden3-core/common"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. "github.com/iden3/go-merkletree/db"
  13. )
const (
	// proofFlagsLen is the byte length of the flags at the start of the
	// serialized proof header. The header is ElemBytesLen (32) bytes long:
	// byte 0 holds the flags, byte 1 the proof depth, and the remaining
	// bytes the notempties bitmap.
	proofFlagsLen = 2
	// ElemBytesLen is the length of the Hash byte array
	ElemBytesLen = 32
)
var (
	// ErrNodeKeyAlreadyExists is used when a node key already exists.
	ErrNodeKeyAlreadyExists = errors.New("key already exists")
	// ErrKeyNotFound is used when a key is not found in the MerkleTree.
	ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
	// ErrNodeBytesBadSize is used when the data of a node has an incorrect
	// size and can't be parsed.
	ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
	// ErrReachedMaxLevel is used when a traversal of the MT reaches the
	// maximum level.
	ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
	// ErrInvalidNodeFound is used when an invalid node is found and can't
	// be parsed.
	ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
	// ErrInvalidProofBytes is used when a serialized proof is invalid.
	ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
	// ErrInvalidDBValue is used when a value in the key value DB is
	// invalid (for example, it doesn't contain a byte header and a []byte
	// body of at least len=1).
	ErrInvalidDBValue = errors.New("the value in the DB is invalid")
	// ErrEntryIndexAlreadyExists is used when the entry index already
	// exists in the tree.
	ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
	// ErrNotWritable is used when the MerkleTree is not writable and a
	// write function is called.
	ErrNotWritable = errors.New("Merkle Tree not writable")
	// rootNodeValue is the storage key under which the current root hash
	// of the tree is persisted.
	rootNodeValue = []byte("currentroot")
	// HashZero is used at Empty nodes. It is a package-level variable and
	// &HashZero is handed out as a root pointer, so it must never be written
	// through.
	HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Hash is the generic type stored in the MerkleTree: a 32-byte array.
// BigInt and NewHashFromBigInt reverse its byte order (common.SwapEndianness)
// when converting to and from *big.Int, whose SetBytes/Bytes are big-endian,
// so the array holds the value in little-endian order.
type Hash [32]byte
  53. // String returns decimal representation in string format of the Hash
  54. func (h Hash) String() string {
  55. s := h.BigInt().String()
  56. if len(s) < 8 {
  57. return s
  58. }
  59. return s[0:8] + "..."
  60. }
  61. // Hex returns the hexadecimal representation of the Hash
  62. func (h Hash) Hex() string {
  63. return hex.EncodeToString(h.BigInt().Bytes())
  64. }
  65. // BigInt returns the *big.Int representation of the *Hash
  66. func (h *Hash) BigInt() *big.Int {
  67. if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
  68. return big.NewInt(0)
  69. }
  70. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  71. }
  72. // Bytes returns the []byte representation of the *Hash, which always is 32
  73. // bytes length.
  74. func (h *Hash) Bytes() []byte {
  75. bi := new(big.Int).SetBytes(common.SwapEndianness(h[:])).Bytes()
  76. b := [32]byte{}
  77. copy(b[:], bi[:])
  78. return b[:]
  79. }
  80. // NewBigIntFromBytes returns a *big.Int from a byte array, swapping the
  81. // endianness in the process. This is the intended method to get a *big.Int
  82. // from a byte array that previously has ben generated by the Hash.Bytes()
  83. // method.
  84. func NewBigIntFromBytes(b []byte) (*big.Int, error) {
  85. if len(b) != 32 {
  86. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  87. }
  88. return new(big.Int).SetBytes(common.SwapEndianness(b[:32])), nil
  89. }
  90. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  91. func NewHashFromBigInt(b *big.Int) *Hash {
  92. r := &Hash{}
  93. copy(r[:], common.SwapEndianness(b.Bytes()))
  94. return r
  95. }
  96. // NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
  97. // in the process. This is the intended method to get a *Hash from a byte array
  98. // that previously has ben generated by the Hash.Bytes() method.
  99. func NewHashFromBytes(b []byte) (*Hash, error) {
  100. if len(b) != 32 {
  101. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  102. }
  103. var h Hash
  104. copy(h[:], common.SwapEndianness(b))
  105. return &h, nil
  106. }
// MerkleTree is the struct with the main elements of the MerkleTree.
type MerkleTree struct {
	// RWMutex is taken for writing by Add, Update and Delete and for reading
	// by Snapshot.
	sync.RWMutex
	// db is the key-value storage holding the nodes and the current root.
	db db.Storage
	// rootKey is the hash of the current root node (HashZero for an empty tree).
	rootKey *Hash
	// writable is false for read-only copies returned by Snapshot; mutating
	// methods then return ErrNotWritable.
	writable bool
	// maxLevels is the maximum depth of the tree and the length of the key
	// paths computed by getPath.
	maxLevels int
}
  115. // NewMerkleTree loads a new Merkletree. If in the sotrage already exists one
  116. // will open that one, if not, will create a new one.
  117. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  118. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  119. v, err := mt.db.Get(rootNodeValue)
  120. if err != nil {
  121. tx, err := mt.db.NewTx()
  122. if err != nil {
  123. return nil, err
  124. }
  125. mt.rootKey = &HashZero
  126. tx.Put(rootNodeValue, mt.rootKey[:])
  127. err = tx.Commit()
  128. if err != nil {
  129. return nil, err
  130. }
  131. return &mt, nil
  132. }
  133. mt.rootKey = &Hash{}
  134. copy(mt.rootKey[:], v)
  135. return &mt, nil
  136. }
  137. // DB returns the MerkleTree.DB()
  138. func (mt *MerkleTree) DB() db.Storage {
  139. return mt.db
  140. }
  141. // Root returns the MerkleRoot
  142. func (mt *MerkleTree) Root() *Hash {
  143. return mt.rootKey
  144. }
  145. // MaxLevels returns the MT maximum level
  146. func (mt *MerkleTree) MaxLevels() int {
  147. return mt.maxLevels
  148. }
  149. // Snapshot returns a read-only copy of the MerkleTree
  150. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  151. mt.RLock()
  152. defer mt.RUnlock()
  153. _, err := mt.GetNode(rootKey)
  154. if err != nil {
  155. return nil, err
  156. }
  157. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  158. }
  159. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the
  160. // path from the Root to the Leaf.
  161. func (mt *MerkleTree) Add(k, v *big.Int) error {
  162. // verify that the MerkleTree is writable
  163. if !mt.writable {
  164. return ErrNotWritable
  165. }
  166. // verfy that k & v are valid and fit inside the Finite Field.
  167. if !cryptoUtils.CheckBigIntInField(k) {
  168. return errors.New("Key not inside the Finite Field")
  169. }
  170. if !cryptoUtils.CheckBigIntInField(v) {
  171. return errors.New("Value not inside the Finite Field")
  172. }
  173. tx, err := mt.db.NewTx()
  174. if err != nil {
  175. return err
  176. }
  177. mt.Lock()
  178. defer mt.Unlock()
  179. kHash := NewHashFromBigInt(k)
  180. vHash := NewHashFromBigInt(v)
  181. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  182. path := getPath(mt.maxLevels, kHash[:])
  183. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  184. if err != nil {
  185. return err
  186. }
  187. mt.rootKey = newRootKey
  188. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  189. if err := tx.Commit(); err != nil {
  190. return err
  191. }
  192. return nil
  193. }
  194. // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
  195. func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
  196. var cp CircomProcessorProof
  197. cp.Fnc = 2
  198. cp.OldRoot = mt.rootKey
  199. gettedK, gettedV, siblings, err := mt.Get(k)
  200. if err != nil && err != ErrKeyNotFound {
  201. return nil, err
  202. }
  203. cp.OldKey = NewHashFromBigInt(gettedK)
  204. cp.OldValue = NewHashFromBigInt(gettedV)
  205. if bytes.Equal(cp.OldKey[:], HashZero[:]) {
  206. cp.IsOld0 = true
  207. }
  208. _, _, siblings, err = mt.Get(k)
  209. if err != nil && err != ErrKeyNotFound {
  210. return nil, err
  211. }
  212. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  213. err = mt.Add(k, v)
  214. if err != nil {
  215. return nil, err
  216. }
  217. cp.NewKey = NewHashFromBigInt(k)
  218. cp.NewValue = NewHashFromBigInt(v)
  219. cp.NewRoot = mt.rootKey
  220. return &cp, nil
  221. }
  222. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  223. // from newLeaf, at which point both leafs are stored, all while updating the
  224. // path.
  225. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  226. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  227. if lvl > mt.maxLevels-2 {
  228. return nil, ErrReachedMaxLevel
  229. }
  230. var newNodeMiddle *Node
  231. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  232. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  233. if err != nil {
  234. return nil, err
  235. }
  236. if pathNewLeaf[lvl] {
  237. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  238. } else {
  239. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  240. }
  241. return mt.addNode(tx, newNodeMiddle)
  242. } else {
  243. oldLeafKey, err := oldLeaf.Key()
  244. if err != nil {
  245. return nil, err
  246. }
  247. newLeafKey, err := newLeaf.Key()
  248. if err != nil {
  249. return nil, err
  250. }
  251. if pathNewLeaf[lvl] {
  252. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  253. } else {
  254. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  255. }
  256. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  257. _, err = mt.addNode(tx, newLeaf)
  258. if err != nil {
  259. return nil, err
  260. }
  261. return mt.addNode(tx, newNodeMiddle)
  262. }
  263. }
  264. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  265. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  266. lvl int, path []bool) (*Hash, error) {
  267. var err error
  268. var nextKey *Hash
  269. if lvl > mt.maxLevels-1 {
  270. return nil, ErrReachedMaxLevel
  271. }
  272. n, err := mt.GetNode(key)
  273. if err != nil {
  274. return nil, err
  275. }
  276. switch n.Type {
  277. case NodeTypeEmpty:
  278. // We can add newLeaf now
  279. return mt.addNode(tx, newLeaf)
  280. case NodeTypeLeaf:
  281. nKey := n.Entry[0]
  282. // Check if leaf node found contains the leaf node we are trying to add
  283. newLeafKey := newLeaf.Entry[0]
  284. if bytes.Equal(nKey[:], newLeafKey[:]) {
  285. return nil, ErrEntryIndexAlreadyExists
  286. }
  287. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  288. // We need to push newLeaf down until its path diverges from n's path
  289. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  290. case NodeTypeMiddle:
  291. // We need to go deeper, continue traversing the tree, left or
  292. // right depending on path
  293. var newNodeMiddle *Node
  294. if path[lvl] {
  295. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  296. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  297. } else {
  298. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  299. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  300. }
  301. if err != nil {
  302. return nil, err
  303. }
  304. // Update the node to reflect the modified child
  305. return mt.addNode(tx, newNodeMiddle)
  306. default:
  307. return nil, ErrInvalidNodeFound
  308. }
  309. }
  310. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  311. // they are all the same and assumed to always exist.
  312. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  313. // verify that the MerkleTree is writable
  314. if !mt.writable {
  315. return nil, ErrNotWritable
  316. }
  317. if n.Type == NodeTypeEmpty {
  318. return n.Key()
  319. }
  320. k, err := n.Key()
  321. if err != nil {
  322. return nil, err
  323. }
  324. v := n.Value()
  325. // Check that the node key doesn't already exist
  326. if _, err := tx.Get(k[:]); err == nil {
  327. return nil, ErrNodeKeyAlreadyExists
  328. }
  329. tx.Put(k[:], v)
  330. return k, nil
  331. }
  332. // updateNode updates an existing node in the MT. Empty nodes are not stored
  333. // in the tree; they are all the same and assumed to always exist.
  334. func (mt *MerkleTree) updateNode(tx db.Tx, n *Node) (*Hash, error) {
  335. // verify that the MerkleTree is writable
  336. if !mt.writable {
  337. return nil, ErrNotWritable
  338. }
  339. if n.Type == NodeTypeEmpty {
  340. return n.Key()
  341. }
  342. k, err := n.Key()
  343. if err != nil {
  344. return nil, err
  345. }
  346. v := n.Value()
  347. tx.Put(k[:], v)
  348. return k, nil
  349. }
  350. // Get returns the value of the leaf for the given key
  351. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  352. // verfy that k is valid and fit inside the Finite Field.
  353. if !cryptoUtils.CheckBigIntInField(k) {
  354. return nil, nil, nil, errors.New("Key not inside the Finite Field")
  355. }
  356. kHash := NewHashFromBigInt(k)
  357. path := getPath(mt.maxLevels, kHash[:])
  358. nextKey := mt.rootKey
  359. var siblings []*Hash
  360. for i := 0; i < mt.maxLevels; i++ {
  361. n, err := mt.GetNode(nextKey)
  362. if err != nil {
  363. return nil, nil, nil, err
  364. }
  365. switch n.Type {
  366. case NodeTypeEmpty:
  367. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  368. case NodeTypeLeaf:
  369. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  370. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  371. } else {
  372. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  373. }
  374. case NodeTypeMiddle:
  375. if path[i] {
  376. nextKey = n.ChildR
  377. siblings = append(siblings, n.ChildL)
  378. } else {
  379. nextKey = n.ChildL
  380. siblings = append(siblings, n.ChildR)
  381. }
  382. default:
  383. return nil, nil, nil, ErrInvalidNodeFound
  384. }
  385. }
  386. return nil, nil, nil, ErrKeyNotFound
  387. }
  388. // Update updates the value of a specified key in the MerkleTree, and updates
  389. // the path from the leaf to the Root with the new values. Returns the
  390. // CircomProcessorProof.
  391. func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
  392. // verify that the MerkleTree is writable
  393. if !mt.writable {
  394. return nil, ErrNotWritable
  395. }
  396. // verfy that k & are valid and fit inside the Finite Field.
  397. if !cryptoUtils.CheckBigIntInField(k) {
  398. return nil, errors.New("Key not inside the Finite Field")
  399. }
  400. if !cryptoUtils.CheckBigIntInField(v) {
  401. return nil, errors.New("Key not inside the Finite Field")
  402. }
  403. tx, err := mt.db.NewTx()
  404. if err != nil {
  405. return nil, err
  406. }
  407. mt.Lock()
  408. defer mt.Unlock()
  409. kHash := NewHashFromBigInt(k)
  410. vHash := NewHashFromBigInt(v)
  411. path := getPath(mt.maxLevels, kHash[:])
  412. var cp CircomProcessorProof
  413. cp.Fnc = 1
  414. cp.OldRoot = mt.rootKey
  415. cp.OldKey = kHash
  416. cp.NewKey = kHash
  417. cp.NewValue = vHash
  418. nextKey := mt.rootKey
  419. var siblings []*Hash
  420. for i := 0; i < mt.maxLevels; i++ {
  421. n, err := mt.GetNode(nextKey)
  422. if err != nil {
  423. return nil, err
  424. }
  425. switch n.Type {
  426. case NodeTypeEmpty:
  427. return nil, ErrKeyNotFound
  428. case NodeTypeLeaf:
  429. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  430. cp.OldValue = n.Entry[1]
  431. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  432. // update leaf and upload to the root
  433. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  434. _, err := mt.updateNode(tx, newNodeLeaf)
  435. if err != nil {
  436. return nil, err
  437. }
  438. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
  439. if err != nil {
  440. return nil, err
  441. }
  442. mt.rootKey = newRootKey
  443. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  444. cp.NewRoot = newRootKey
  445. if err := tx.Commit(); err != nil {
  446. return nil, err
  447. }
  448. return &cp, nil
  449. } else {
  450. return nil, ErrKeyNotFound
  451. }
  452. case NodeTypeMiddle:
  453. if path[i] {
  454. nextKey = n.ChildR
  455. siblings = append(siblings, n.ChildL)
  456. } else {
  457. nextKey = n.ChildL
  458. siblings = append(siblings, n.ChildR)
  459. }
  460. default:
  461. return nil, ErrInvalidNodeFound
  462. }
  463. }
  464. return nil, ErrKeyNotFound
  465. }
  466. // Delete removes the specified Key from the MerkleTree and updates the path
  467. // from the deleted key to the Root with the new values. This method removes
  468. // the key from the MerkleTree, but does not remove the old nodes from the
  469. // key-value database; this means that if the tree is accessed by an old Root
  470. // where the key was not deleted yet, the key will still exist. If is desired
  471. // to remove the key-values from the database that are not under the current
  472. // Root, an option could be to dump all the leafs (using mt.DumpLeafs) and
  473. // import them in a new MerkleTree in a new database (using
  474. // mt.ImportDumpedLeafs), but this will loose all the Root history of the
  475. // MerkleTree
  476. func (mt *MerkleTree) Delete(k *big.Int) error {
  477. // verify that the MerkleTree is writable
  478. if !mt.writable {
  479. return ErrNotWritable
  480. }
  481. // verfy that k is valid and fit inside the Finite Field.
  482. if !cryptoUtils.CheckBigIntInField(k) {
  483. return errors.New("Key not inside the Finite Field")
  484. }
  485. tx, err := mt.db.NewTx()
  486. if err != nil {
  487. return err
  488. }
  489. mt.Lock()
  490. defer mt.Unlock()
  491. kHash := NewHashFromBigInt(k)
  492. path := getPath(mt.maxLevels, kHash[:])
  493. nextKey := mt.rootKey
  494. var siblings []*Hash
  495. for i := 0; i < mt.maxLevels; i++ {
  496. n, err := mt.GetNode(nextKey)
  497. if err != nil {
  498. return err
  499. }
  500. switch n.Type {
  501. case NodeTypeEmpty:
  502. return ErrKeyNotFound
  503. case NodeTypeLeaf:
  504. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  505. // remove and go up with the sibling
  506. err = mt.rmAndUpload(tx, path, kHash, siblings)
  507. return err
  508. } else {
  509. return ErrKeyNotFound
  510. }
  511. case NodeTypeMiddle:
  512. if path[i] {
  513. nextKey = n.ChildR
  514. siblings = append(siblings, n.ChildL)
  515. } else {
  516. nextKey = n.ChildL
  517. siblings = append(siblings, n.ChildR)
  518. }
  519. default:
  520. return ErrInvalidNodeFound
  521. }
  522. }
  523. return ErrKeyNotFound
  524. }
  525. // rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values.
  526. func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
  527. if len(siblings) == 0 {
  528. mt.rootKey = &HashZero
  529. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  530. return tx.Commit()
  531. }
  532. toUpload := siblings[len(siblings)-1]
  533. if len(siblings) < 2 {
  534. mt.rootKey = siblings[0]
  535. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  536. return tx.Commit()
  537. }
  538. for i := len(siblings) - 2; i >= 0; i-- {
  539. if !bytes.Equal(siblings[i][:], HashZero[:]) {
  540. var newNode *Node
  541. if path[i] {
  542. newNode = NewNodeMiddle(siblings[i], toUpload)
  543. } else {
  544. newNode = NewNodeMiddle(toUpload, siblings[i])
  545. }
  546. _, err := mt.addNode(tx, newNode)
  547. if err != ErrNodeKeyAlreadyExists && err != nil {
  548. return err
  549. }
  550. // go up until the root
  551. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i])
  552. if err != nil {
  553. return err
  554. }
  555. mt.rootKey = newRootKey
  556. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  557. break
  558. }
  559. // if i==0 (root position), stop and store the sibling of the deleted leaf as root
  560. if i == 0 {
  561. mt.rootKey = toUpload
  562. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  563. break
  564. }
  565. }
  566. if err := tx.Commit(); err != nil {
  567. return err
  568. }
  569. return nil
  570. }
  571. // recalculatePathUntilRoot recalculates the nodes until the Root
  572. func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node, siblings []*Hash) (*Hash, error) {
  573. for i := len(siblings) - 1; i >= 0; i-- {
  574. nodeKey, err := node.Key()
  575. if err != nil {
  576. return nil, err
  577. }
  578. if path[i] {
  579. node = NewNodeMiddle(siblings[i], nodeKey)
  580. } else {
  581. node = NewNodeMiddle(nodeKey, siblings[i])
  582. }
  583. _, err = mt.addNode(tx, node)
  584. if err != ErrNodeKeyAlreadyExists && err != nil {
  585. return nil, err
  586. }
  587. }
  588. // return last node added, which is the root
  589. nodeKey, err := node.Key()
  590. return nodeKey, err
  591. }
  592. // dbGet is a helper function to get the node of a key from the internal
  593. // storage.
  594. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  595. if bytes.Equal(k, HashZero[:]) {
  596. return 0, nil, nil
  597. }
  598. value, err := mt.db.Get(k)
  599. if err != nil {
  600. return 0, nil, err
  601. }
  602. if len(value) < 2 {
  603. return 0, nil, ErrInvalidDBValue
  604. }
  605. nodeType := value[0]
  606. nodeBytes := value[1:]
  607. return NodeType(nodeType), nodeBytes, nil
  608. }
  609. // dbInsert is a helper function to insert a node into a key in an open db
  610. // transaction.
  611. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  612. v := append([]byte{byte(t)}, data...)
  613. tx.Put(k, v)
  614. }
  615. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  616. // tree; they are all the same and assumed to always exist.
  617. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  618. if bytes.Equal(key[:], HashZero[:]) {
  619. return NewNodeEmpty(), nil
  620. }
  621. nBytes, err := mt.db.Get(key[:])
  622. if err != nil {
  623. return nil, err
  624. }
  625. return NewNodeFromBytes(nBytes)
  626. }
  627. // getPath returns the binary path, from the root to the leaf.
  628. func getPath(numLevels int, k []byte) []bool {
  629. path := make([]bool, numLevels)
  630. for n := 0; n < numLevels; n++ {
  631. path[n] = common.TestBit(k[:], uint(n))
  632. }
  633. return path
  634. }
// NodeAux contains the auxiliary node used in a non-existence proof: the leaf
// found on the queried key's path in place of the queried key.
type NodeAux struct {
	// Key is the index (Entry[0]) of the leaf that was found.
	Key *Hash
	// Value is the value (Entry[1]) of the leaf that was found.
	Value *Hash
}
// Proof defines the required elements for a MT proof of existence or non-existence.
type Proof struct {
	// Existence indicates whether this is a proof of existence or non-existence.
	Existence bool
	// depth indicates how deep in the tree the proof goes.
	depth uint
	// notempties is a bitmap of non-empty Siblings found in Siblings: bit lvl
	// (tested with common.TestBitBigEndian) is set when the sibling at level
	// lvl is not HashZero.
	notempties [ElemBytesLen - proofFlagsLen]byte
	// Siblings is a list of non-empty sibling keys, ordered from the root
	// downwards.
	Siblings []*Hash
	// NodeAux is the other leaf found on the queried key's path in a
	// non-existence proof; nil when the path ended in an empty node.
	NodeAux *NodeAux
}
  652. // NewProofFromBytes parses a byte array into a Proof.
  653. func NewProofFromBytes(bs []byte) (*Proof, error) {
  654. if len(bs) < ElemBytesLen {
  655. return nil, ErrInvalidProofBytes
  656. }
  657. p := &Proof{}
  658. if (bs[0] & 0x01) == 0 {
  659. p.Existence = true
  660. }
  661. p.depth = uint(bs[1])
  662. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  663. siblingBytes := bs[ElemBytesLen:]
  664. sibIdx := 0
  665. for i := uint(0); i < p.depth; i++ {
  666. if common.TestBitBigEndian(p.notempties[:], i) {
  667. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  668. return nil, ErrInvalidProofBytes
  669. }
  670. var sib Hash
  671. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  672. p.Siblings = append(p.Siblings, &sib)
  673. sibIdx++
  674. }
  675. }
  676. if !p.Existence && ((bs[0] & 0x02) != 0) {
  677. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  678. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  679. if len(nodeAuxBytes) != 2*ElemBytesLen {
  680. return nil, ErrInvalidProofBytes
  681. }
  682. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  683. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  684. }
  685. return p, nil
  686. }
  687. // Bytes serializes a Proof into a byte array.
  688. func (p *Proof) Bytes() []byte {
  689. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  690. if p.NodeAux != nil {
  691. bsLen += 2 * ElemBytesLen
  692. }
  693. bs := make([]byte, bsLen)
  694. if !p.Existence {
  695. bs[0] |= 0x01
  696. }
  697. bs[1] = byte(p.depth)
  698. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  699. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  700. for i, k := range p.Siblings {
  701. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  702. }
  703. if p.NodeAux != nil {
  704. bs[0] |= 0x02
  705. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  706. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  707. }
  708. return bs
  709. }
  710. // SiblingsFromProof returns all the siblings of the proof.
  711. func SiblingsFromProof(proof *Proof) []*Hash {
  712. sibIdx := 0
  713. var siblings []*Hash
  714. for lvl := 0; lvl < int(proof.depth); lvl++ {
  715. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  716. siblings = append(siblings, proof.Siblings[sibIdx])
  717. sibIdx++
  718. } else {
  719. siblings = append(siblings, &HashZero)
  720. }
  721. }
  722. return siblings
  723. }
  724. // AllSiblings returns all the siblings of the proof.
  725. func (p *Proof) AllSiblings() []*Hash {
  726. return SiblingsFromProof(p)
  727. }
  728. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  729. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  730. // Add the rest of empty levels to the siblings
  731. for i := len(siblings); i < levels; i++ {
  732. siblings = append(siblings, &HashZero)
  733. }
  734. siblings = append(siblings, &HashZero) // add extra level for circom compatibility
  735. return siblings
  736. }
  737. // AllSiblingsCircom returns all the siblings of the proof. This function is
  738. // used to generate the siblings input for the circom circuits.
  739. func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
  740. siblings := p.AllSiblings()
  741. // Add the rest of empty levels to the siblings
  742. for i := len(siblings); i < levels; i++ {
  743. siblings = append(siblings, &HashZero)
  744. }
  745. siblings = append(siblings, &HashZero) // add extra level for circom compatibility
  746. siblingsBigInt := make([]*big.Int, len(siblings))
  747. for i, sibling := range siblings {
  748. siblingsBigInt[i] = sibling.BigInt()
  749. }
  750. return siblingsBigInt
  751. }
// CircomProcessorProof defines the ProcessorProof compatible with circom. It
// is the data of the proof of the transition from one state (OldRoot) to
// another (NewRoot).
type CircomProcessorProof struct {
	// OldRoot and NewRoot are the tree roots before and after the operation.
	OldRoot *Hash
	NewRoot *Hash
	// Siblings along the key's path, padded with HashZero by
	// CircomSiblingsFromSiblings.
	Siblings []*Hash
	// OldKey/OldValue describe the leaf found on the path before the operation.
	OldKey   *Hash
	OldValue *Hash
	// NewKey/NewValue are the key and value written by the operation.
	NewKey   *Hash
	NewValue *Hash
	// IsOld0 is set when the old leaf was empty (OldKey is zero).
	IsOld0 bool
	Fnc    int // 0: NOP, 1: Update, 2: Insert, 3: Delete
}
  765. // String returns a human readable string representation of the
  766. // CircomProcessorProof
  767. func (p CircomProcessorProof) String() string {
  768. buf := bytes.NewBufferString("{")
  769. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  770. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  771. fmt.Fprintf(buf, " Siblings: [\n ")
  772. for _, s := range p.Siblings {
  773. fmt.Fprintf(buf, "%v, ", s)
  774. }
  775. fmt.Fprintf(buf, "\n ],\n")
  776. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  777. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  778. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  779. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  780. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  781. fmt.Fprintf(buf, "}\n")
  782. return buf.String()
  783. }
// CircomVerifierProof defines the VerifierProof compatible with circom. It is
// the data of the proof that a certain leaf exists (or does not exist) in the
// MerkleTree.
type CircomVerifierProof struct {
	// Root is the root the proof is checked against.
	Root *Hash
	// Siblings along the key's path as big integers, padded for circom
	// (see Proof.AllSiblingsCircom).
	Siblings []*big.Int
	// OldKey, OldValue and IsOld0 are the old-leaf inputs of the circom
	// verifier. NOTE(review): presumably only meaningful for non-inclusion
	// proofs — confirm against the circuit.
	OldKey   *Hash
	OldValue *Hash
	IsOld0   bool
	// Key and Value are the queried key and the value returned with the proof.
	Key   *Hash
	Value *Hash
	Fnc   int // 0: inclusion, 1: non inclusion
}
  796. // GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
  797. // key in the MerkleTree. If the rootKey is nil, the current merkletree root
  798. // is used.
  799. func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int, rootKey *Hash) (*CircomVerifierProof, error) {
  800. p, v, err := mt.GenerateProof(k, rootKey)
  801. if err != nil && err != ErrKeyNotFound {
  802. return nil, err
  803. }
  804. var cp CircomVerifierProof
  805. cp.Root = mt.rootKey
  806. cp.Siblings = p.AllSiblingsCircom(mt.maxLevels)
  807. cp.OldKey = &HashZero
  808. cp.OldValue = &HashZero
  809. cp.Key = NewHashFromBigInt(k)
  810. cp.Value = NewHashFromBigInt(v)
  811. if p.Existence {
  812. cp.Fnc = 0 // inclusion
  813. } else {
  814. cp.Fnc = 1 // non inclusion
  815. }
  816. return &cp, nil
  817. }
  818. // GenerateProof generates the proof of existence (or non-existence) of an
  819. // Entry's hash Index for a Merkle Tree given the root.
  820. // If the rootKey is nil, the current merkletree root is used
  821. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, *big.Int, error) {
  822. p := &Proof{}
  823. var siblingKey *Hash
  824. kHash := NewHashFromBigInt(k)
  825. path := getPath(mt.maxLevels, kHash[:])
  826. if rootKey == nil {
  827. rootKey = mt.Root()
  828. }
  829. nextKey := rootKey
  830. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  831. n, err := mt.GetNode(nextKey)
  832. if err != nil {
  833. return nil, nil, err
  834. }
  835. switch n.Type {
  836. case NodeTypeEmpty:
  837. return p, big.NewInt(0), nil
  838. case NodeTypeLeaf:
  839. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  840. p.Existence = true
  841. return p, n.Entry[1].BigInt(), nil
  842. } else {
  843. // We found a leaf whose entry didn't match hIndex
  844. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  845. return p, n.Entry[1].BigInt(), nil
  846. }
  847. case NodeTypeMiddle:
  848. if path[p.depth] {
  849. nextKey = n.ChildR
  850. siblingKey = n.ChildL
  851. } else {
  852. nextKey = n.ChildL
  853. siblingKey = n.ChildR
  854. }
  855. default:
  856. return nil, nil, ErrInvalidNodeFound
  857. }
  858. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  859. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  860. p.Siblings = append(p.Siblings, siblingKey)
  861. }
  862. }
  863. return nil, nil, ErrKeyNotFound
  864. }
  865. // VerifyProof verifies the Merkle Proof for the entry and root.
  866. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  867. rootFromProof, err := RootFromProof(proof, k, v)
  868. if err != nil {
  869. return false
  870. }
  871. return bytes.Equal(rootKey[:], rootFromProof[:])
  872. }
  873. // RootFromProof calculates the root that would correspond to a tree whose
  874. // siblings are the ones in the proof with the leaf hashing to hIndex and
  875. // hValue.
  876. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  877. kHash := NewHashFromBigInt(k)
  878. vHash := NewHashFromBigInt(v)
  879. sibIdx := len(proof.Siblings) - 1
  880. var err error
  881. var midKey *Hash
  882. if proof.Existence {
  883. midKey, err = LeafKey(kHash, vHash)
  884. if err != nil {
  885. return nil, err
  886. }
  887. } else {
  888. if proof.NodeAux == nil {
  889. midKey = &HashZero
  890. } else {
  891. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  892. return nil, fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  893. }
  894. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  895. if err != nil {
  896. return nil, err
  897. }
  898. }
  899. }
  900. path := getPath(int(proof.depth), kHash[:])
  901. var siblingKey *Hash
  902. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  903. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  904. siblingKey = proof.Siblings[sibIdx]
  905. sibIdx--
  906. } else {
  907. siblingKey = &HashZero
  908. }
  909. if path[lvl] {
  910. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  911. if err != nil {
  912. return nil, err
  913. }
  914. } else {
  915. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  916. if err != nil {
  917. return nil, err
  918. }
  919. }
  920. }
  921. return midKey, nil
  922. }
  923. // walk is a helper recursive function to iterate over all tree branches
  924. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  925. n, err := mt.GetNode(key)
  926. if err != nil {
  927. return err
  928. }
  929. switch n.Type {
  930. case NodeTypeEmpty:
  931. f(n)
  932. case NodeTypeLeaf:
  933. f(n)
  934. case NodeTypeMiddle:
  935. f(n)
  936. if err := mt.walk(n.ChildL, f); err != nil {
  937. return err
  938. }
  939. if err := mt.walk(n.ChildR, f); err != nil {
  940. return err
  941. }
  942. default:
  943. return ErrInvalidNodeFound
  944. }
  945. return nil
  946. }
  947. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  948. // if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree.
  949. // For each node, it calls the f function given in the parameters.
  950. // See some examples of the Walk function usage in the merkletree.go and
  951. // merkletree_test.go
  952. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  953. if rootKey == nil {
  954. rootKey = mt.Root()
  955. }
  956. err := mt.walk(rootKey, f)
  957. return err
  958. }
  959. // GraphViz uses Walk function to generate a string GraphViz representation of the
  960. // tree and writes it to w
  961. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  962. fmt.Fprintf(w, `digraph hierarchy {
  963. node [fontname=Monospace,fontsize=10,shape=box]
  964. `)
  965. cnt := 0
  966. var errIn error
  967. err := mt.Walk(rootKey, func(n *Node) {
  968. k, err := n.Key()
  969. if err != nil {
  970. errIn = err
  971. }
  972. switch n.Type {
  973. case NodeTypeEmpty:
  974. case NodeTypeLeaf:
  975. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  976. case NodeTypeMiddle:
  977. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  978. emptyNodes := ""
  979. for i := range lr {
  980. if lr[i] == "0" {
  981. lr[i] = fmt.Sprintf("empty%v", cnt)
  982. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  983. cnt++
  984. }
  985. }
  986. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  987. fmt.Fprint(w, emptyNodes)
  988. default:
  989. }
  990. })
  991. fmt.Fprintf(w, "}\n")
  992. if errIn != nil {
  993. return errIn
  994. }
  995. return err
  996. }
  997. // PrintGraphViz prints directly the GraphViz() output
  998. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  999. if rootKey == nil {
  1000. rootKey = mt.Root()
  1001. }
  1002. w := bytes.NewBufferString("")
  1003. fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  1004. err := mt.GraphViz(w, nil)
  1005. if err != nil {
  1006. return err
  1007. }
  1008. fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  1009. fmt.Println(w)
  1010. return nil
  1011. }
  1012. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  1013. // is given (nil), it uses the current Root of the MerkleTree.
  1014. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  1015. var b []byte
  1016. err := mt.Walk(rootKey, func(n *Node) {
  1017. if n.Type == NodeTypeLeaf {
  1018. l := n.Entry[0].Bytes()
  1019. r := n.Entry[1].Bytes()
  1020. b = append(b, append(l[:], r[:]...)...)
  1021. }
  1022. })
  1023. return b, err
  1024. }
  1025. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  1026. // from the DumpLeafs function.
  1027. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  1028. for i := 0; i < len(b); i += 64 {
  1029. lr := b[i : i+64]
  1030. lB, err := NewBigIntFromBytes(lr[:32])
  1031. if err != nil {
  1032. return err
  1033. }
  1034. rB, err := NewBigIntFromBytes(lr[32:])
  1035. if err != nil {
  1036. return err
  1037. }
  1038. err = mt.Add(lB, rB)
  1039. if err != nil {
  1040. return err
  1041. }
  1042. }
  1043. return nil
  1044. }