You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1201 lines
32 KiB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "strings"
  10. "sync"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. )
  13. const (
  14. // proofFlagsLen is the byte length of the flags in the proof header
  15. // (first 32 bytes).
  16. proofFlagsLen = 2
  17. // ElemBytesLen is the length of the Hash byte array
  18. ElemBytesLen = 32
  19. numCharPrint = 8
  20. )
  21. var (
  22. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  23. ErrNodeKeyAlreadyExists = errors.New("key already exists")
  24. // ErrKeyNotFound is used when a key is not found in the MerkleTree.
  25. ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
  26. // ErrNodeBytesBadSize is used when the data of a node has an incorrect
  27. // size and can't be parsed.
  28. ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
  29. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  30. // maximum level.
  31. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  32. // ErrInvalidNodeFound is used when an invalid node is found and can't
  33. // be parsed.
  34. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  35. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  36. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  37. // ErrInvalidDBValue is used when a value in the key value DB is
  38. // invalid (for example, it doen't contain a byte header and a []byte
  39. // body of at least len=1.
  40. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  41. // ErrEntryIndexAlreadyExists is used when the entry index already
  42. // exists in the tree.
  43. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  44. // ErrNotWritable is used when the MerkleTree is not writable and a
  45. // write function is called
  46. ErrNotWritable = errors.New("Merkle Tree not writable")
  47. dbKeyRootNode = []byte("currentroot")
  48. // HashZero is used at Empty nodes
  49. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  50. 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  51. )
  52. // Hash is the generic type stored in the MerkleTree
  53. type Hash [32]byte
  54. // MarshalText implements the marshaler for the Hash type
  55. func (h Hash) MarshalText() ([]byte, error) {
  56. return []byte(h.BigInt().String()), nil
  57. }
  58. // UnmarshalText implements the unmarshaler for the Hash type
  59. func (h *Hash) UnmarshalText(b []byte) error {
  60. ha, err := NewHashFromString(string(b))
  61. copy(h[:], ha[:])
  62. return err
  63. }
  64. // String returns decimal representation in string format of the Hash
  65. func (h Hash) String() string {
  66. s := h.BigInt().String()
  67. if len(s) < numCharPrint {
  68. return s
  69. }
  70. return s[0:numCharPrint] + "..."
  71. }
  72. // Hex returns the hexadecimal representation of the Hash
  73. func (h Hash) Hex() string {
  74. return hex.EncodeToString(h[:])
  75. // alternatively equivalent, but with too extra steps:
  76. // bRaw := h.BigInt().Bytes()
  77. // b := [32]byte{}
  78. // copy(b[:], SwapEndianness(bRaw[:]))
  79. // return hex.EncodeToString(b[:])
  80. }
  81. // BigInt returns the *big.Int representation of the *Hash
  82. func (h *Hash) BigInt() *big.Int {
  83. if new(big.Int).SetBytes(SwapEndianness(h[:])) == nil {
  84. return big.NewInt(0)
  85. }
  86. return new(big.Int).SetBytes(SwapEndianness(h[:]))
  87. }
  88. // Bytes returns the []byte representation of the *Hash, which always is 32
  89. // bytes length.
  90. func (h *Hash) Bytes() []byte {
  91. bi := new(big.Int).SetBytes(h[:]).Bytes()
  92. b := [32]byte{}
  93. copy(b[:], SwapEndianness(bi[:]))
  94. return b[:]
  95. }
  96. // NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
  97. // endianness in the process. This is the intended method to get a *big.Int
  98. // from a byte array that previously has ben generated by the Hash.Bytes()
  99. // method.
  100. func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
  101. if len(b) != ElemBytesLen {
  102. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  103. }
  104. bi := new(big.Int).SetBytes(b[:ElemBytesLen])
  105. if !cryptoUtils.CheckBigIntInField(bi) {
  106. return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
  107. }
  108. return bi, nil
  109. }
  110. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  111. func NewHashFromBigInt(b *big.Int) *Hash {
  112. r := &Hash{}
  113. copy(r[:], SwapEndianness(b.Bytes()))
  114. return r
  115. }
  116. // NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
  117. // in the process. This is the intended method to get a *Hash from a byte array
  118. // that previously has ben generated by the Hash.Bytes() method.
  119. func NewHashFromBytes(b []byte) (*Hash, error) {
  120. if len(b) != ElemBytesLen {
  121. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  122. }
  123. var h Hash
  124. copy(h[:], SwapEndianness(b))
  125. return &h, nil
  126. }
  127. // NewHashFromHex returns a *Hash representation of the given hex string
  128. func NewHashFromHex(h string) (*Hash, error) {
  129. h = strings.TrimPrefix(h, "0x")
  130. b, err := hex.DecodeString(h)
  131. if err != nil {
  132. return nil, err
  133. }
  134. return NewHashFromBytes(SwapEndianness(b[:]))
  135. }
  136. // NewHashFromString returns a *Hash representation of the given decimal string
  137. func NewHashFromString(s string) (*Hash, error) {
  138. bi, ok := new(big.Int).SetString(s, 10)
  139. if !ok {
  140. return nil, fmt.Errorf("Can not parse string to Hash")
  141. }
  142. return NewHashFromBigInt(bi), nil
  143. }
  144. // MerkleTree is the struct with the main elements of the MerkleTree
  145. type MerkleTree struct {
  146. sync.RWMutex
  147. db Storage
  148. rootKey *Hash
  149. writable bool
  150. maxLevels int
  151. }
  152. // NewMerkleTree loads a new MerkleTree. If in the storage already exists one
  153. // will open that one, if not, will create a new one.
  154. func NewMerkleTree(storage Storage, maxLevels int) (*MerkleTree, error) {
  155. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  156. root, err := mt.dbGetRoot()
  157. if err == ErrNotFound {
  158. tx, err := mt.db.NewTx()
  159. if err != nil {
  160. return nil, err
  161. }
  162. mt.rootKey = &HashZero
  163. err = tx.SetRoot(mt.rootKey)
  164. if err != nil {
  165. return nil, err
  166. }
  167. err = tx.Commit()
  168. if err != nil {
  169. return nil, err
  170. }
  171. return &mt, nil
  172. } else if err != nil {
  173. return nil, err
  174. }
  175. mt.rootKey = root
  176. return &mt, nil
  177. }
  178. func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
  179. hash, err := mt.db.GetRoot()
  180. if err != nil {
  181. return nil, err
  182. }
  183. return hash, nil
  184. }
  185. // DB returns the MerkleTree.DB()
  186. func (mt *MerkleTree) DB() Storage {
  187. return mt.db
  188. }
  189. // Root returns the MerkleRoot
  190. func (mt *MerkleTree) Root() *Hash {
  191. return mt.rootKey
  192. }
  193. // MaxLevels returns the MT maximum level
  194. func (mt *MerkleTree) MaxLevels() int {
  195. return mt.maxLevels
  196. }
  197. // Snapshot returns a read-only copy of the MerkleTree
  198. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  199. mt.RLock()
  200. defer mt.RUnlock()
  201. _, err := mt.GetNode(rootKey)
  202. if err != nil {
  203. return nil, err
  204. }
  205. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  206. }
  207. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the
  208. // path from the Root to the Leaf.
  209. func (mt *MerkleTree) Add(k, v *big.Int) error {
  210. // verify that the MerkleTree is writable
  211. if !mt.writable {
  212. return ErrNotWritable
  213. }
  214. // verify that k & v are valid and fit inside the Finite Field.
  215. if !cryptoUtils.CheckBigIntInField(k) {
  216. return errors.New("Key not inside the Finite Field")
  217. }
  218. if !cryptoUtils.CheckBigIntInField(v) {
  219. return errors.New("Value not inside the Finite Field")
  220. }
  221. tx, err := mt.db.NewTx()
  222. if err != nil {
  223. return err
  224. }
  225. mt.Lock()
  226. defer mt.Unlock()
  227. kHash := NewHashFromBigInt(k)
  228. vHash := NewHashFromBigInt(v)
  229. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  230. path := getPath(mt.maxLevels, kHash[:])
  231. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  232. if err != nil {
  233. return err
  234. }
  235. mt.rootKey = newRootKey
  236. err = mt.setCurrentRoot(tx, mt.rootKey)
  237. if err != nil {
  238. return err
  239. }
  240. if err := tx.Commit(); err != nil {
  241. return err
  242. }
  243. return nil
  244. }
  245. // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
  246. func (mt *MerkleTree) AddAndGetCircomProof(k,
  247. v *big.Int) (*CircomProcessorProof, error) {
  248. var cp CircomProcessorProof
  249. cp.Fnc = 2
  250. cp.OldRoot = mt.rootKey
  251. gotK, gotV, _, err := mt.Get(k)
  252. if err != nil && err != ErrKeyNotFound {
  253. return nil, err
  254. }
  255. cp.OldKey = NewHashFromBigInt(gotK)
  256. cp.OldValue = NewHashFromBigInt(gotV)
  257. if bytes.Equal(cp.OldKey[:], HashZero[:]) {
  258. cp.IsOld0 = true
  259. }
  260. _, _, siblings, err := mt.Get(k)
  261. if err != nil && err != ErrKeyNotFound {
  262. return nil, err
  263. }
  264. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  265. err = mt.Add(k, v)
  266. if err != nil {
  267. return nil, err
  268. }
  269. cp.NewKey = NewHashFromBigInt(k)
  270. cp.NewValue = NewHashFromBigInt(v)
  271. cp.NewRoot = mt.rootKey
  272. return &cp, nil
  273. }
  274. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  275. // from newLeaf, at which point both leafs are stored, all while updating the
  276. // path.
  277. func (mt *MerkleTree) pushLeaf(tx Tx, newLeaf *Node, oldLeaf *Node, lvl int,
  278. pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  279. if lvl > mt.maxLevels-2 {
  280. return nil, ErrReachedMaxLevel
  281. }
  282. var newNodeMiddle *Node
  283. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  284. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  285. if err != nil {
  286. return nil, err
  287. }
  288. if pathNewLeaf[lvl] { // go right
  289. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey)
  290. } else { // go left
  291. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero)
  292. }
  293. return mt.addNode(tx, newNodeMiddle)
  294. }
  295. oldLeafKey, err := oldLeaf.Key()
  296. if err != nil {
  297. return nil, err
  298. }
  299. newLeafKey, err := newLeaf.Key()
  300. if err != nil {
  301. return nil, err
  302. }
  303. if pathNewLeaf[lvl] {
  304. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  305. } else {
  306. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  307. }
  308. // We can add newLeaf now. We don't need to add oldLeaf because it's
  309. // already in the tree.
  310. _, err = mt.addNode(tx, newLeaf)
  311. if err != nil {
  312. return nil, err
  313. }
  314. return mt.addNode(tx, newNodeMiddle)
  315. }
  316. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  317. func (mt *MerkleTree) addLeaf(tx Tx, newLeaf *Node, key *Hash,
  318. lvl int, path []bool) (*Hash, error) {
  319. var err error
  320. var nextKey *Hash
  321. if lvl > mt.maxLevels-1 {
  322. return nil, ErrReachedMaxLevel
  323. }
  324. n, err := mt.GetNode(key)
  325. if err != nil {
  326. return nil, err
  327. }
  328. switch n.Type {
  329. case NodeTypeEmpty:
  330. // We can add newLeaf now
  331. return mt.addNode(tx, newLeaf)
  332. case NodeTypeLeaf:
  333. nKey := n.Entry[0]
  334. // Check if leaf node found contains the leaf node we are
  335. // trying to add
  336. newLeafKey := newLeaf.Entry[0]
  337. if bytes.Equal(nKey[:], newLeafKey[:]) {
  338. return nil, ErrEntryIndexAlreadyExists
  339. }
  340. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  341. // We need to push newLeaf down until its path diverges from
  342. // n's path
  343. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  344. case NodeTypeMiddle:
  345. // We need to go deeper, continue traversing the tree, left or
  346. // right depending on path
  347. var newNodeMiddle *Node
  348. if path[lvl] { // go right
  349. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path)
  350. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  351. } else { // go left
  352. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path)
  353. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  354. }
  355. if err != nil {
  356. return nil, err
  357. }
  358. // Update the node to reflect the modified child
  359. return mt.addNode(tx, newNodeMiddle)
  360. default:
  361. return nil, ErrInvalidNodeFound
  362. }
  363. }
  364. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  365. // they are all the same and assumed to always exist.
  366. func (mt *MerkleTree) addNode(tx Tx, n *Node) (*Hash, error) {
  367. // verify that the MerkleTree is writable
  368. if !mt.writable {
  369. return nil, ErrNotWritable
  370. }
  371. if n.Type == NodeTypeEmpty {
  372. return n.Key()
  373. }
  374. k, err := n.Key()
  375. if err != nil {
  376. return nil, err
  377. }
  378. //v := n.Value()
  379. // Check that the node key doesn't already exist
  380. if _, err := tx.Get(k[:]); err == nil {
  381. return nil, ErrNodeKeyAlreadyExists
  382. }
  383. err = tx.Put(k[:], n)
  384. return k, err
  385. }
  386. // updateNode updates an existing node in the MT. Empty nodes are not stored
  387. // in the tree; they are all the same and assumed to always exist.
  388. func (mt *MerkleTree) updateNode(tx Tx, n *Node) (*Hash, error) {
  389. // verify that the MerkleTree is writable
  390. if !mt.writable {
  391. return nil, ErrNotWritable
  392. }
  393. if n.Type == NodeTypeEmpty {
  394. return n.Key()
  395. }
  396. k, err := n.Key()
  397. if err != nil {
  398. return nil, err
  399. }
  400. //v := n.Value()
  401. err = tx.Put(k[:], n)
  402. return k, err
  403. }
  404. // Get returns the value of the leaf for the given key
  405. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  406. // verify that k is valid and fits inside the Finite Field.
  407. if !cryptoUtils.CheckBigIntInField(k) {
  408. return nil, nil, nil, errors.New("Key not inside the Finite Field")
  409. }
  410. kHash := NewHashFromBigInt(k)
  411. path := getPath(mt.maxLevels, kHash[:])
  412. nextKey := mt.rootKey
  413. siblings := []*Hash{}
  414. for i := 0; i < mt.maxLevels; i++ {
  415. n, err := mt.GetNode(nextKey)
  416. if err != nil {
  417. return nil, nil, nil, err
  418. }
  419. switch n.Type {
  420. case NodeTypeEmpty:
  421. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  422. case NodeTypeLeaf:
  423. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  424. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  425. }
  426. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  427. case NodeTypeMiddle:
  428. if path[i] {
  429. nextKey = n.ChildR
  430. siblings = append(siblings, n.ChildL)
  431. } else {
  432. nextKey = n.ChildL
  433. siblings = append(siblings, n.ChildR)
  434. }
  435. default:
  436. return nil, nil, nil, ErrInvalidNodeFound
  437. }
  438. }
  439. return nil, nil, nil, ErrReachedMaxLevel
  440. }
  441. // Update updates the value of a specified key in the MerkleTree, and updates
  442. // the path from the leaf to the Root with the new values. Returns the
  443. // CircomProcessorProof.
  444. func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
  445. // verify that the MerkleTree is writable
  446. if !mt.writable {
  447. return nil, ErrNotWritable
  448. }
  449. // verify that k & v are valid and fit inside the Finite Field.
  450. if !cryptoUtils.CheckBigIntInField(k) {
  451. return nil, errors.New("Key not inside the Finite Field")
  452. }
  453. if !cryptoUtils.CheckBigIntInField(v) {
  454. return nil, errors.New("Key not inside the Finite Field")
  455. }
  456. tx, err := mt.db.NewTx()
  457. if err != nil {
  458. return nil, err
  459. }
  460. mt.Lock()
  461. defer mt.Unlock()
  462. kHash := NewHashFromBigInt(k)
  463. vHash := NewHashFromBigInt(v)
  464. path := getPath(mt.maxLevels, kHash[:])
  465. var cp CircomProcessorProof
  466. cp.Fnc = 1
  467. cp.OldRoot = mt.rootKey
  468. cp.OldKey = kHash
  469. cp.NewKey = kHash
  470. cp.NewValue = vHash
  471. nextKey := mt.rootKey
  472. siblings := []*Hash{}
  473. for i := 0; i < mt.maxLevels; i++ {
  474. n, err := mt.GetNode(nextKey)
  475. if err != nil {
  476. return nil, err
  477. }
  478. switch n.Type {
  479. case NodeTypeEmpty:
  480. return nil, ErrKeyNotFound
  481. case NodeTypeLeaf:
  482. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  483. cp.OldValue = n.Entry[1]
  484. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  485. // update leaf and upload to the root
  486. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  487. _, err := mt.updateNode(tx, newNodeLeaf)
  488. if err != nil {
  489. return nil, err
  490. }
  491. newRootKey, err :=
  492. mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
  493. if err != nil {
  494. return nil, err
  495. }
  496. mt.rootKey = newRootKey
  497. err = mt.setCurrentRoot(tx, mt.rootKey)
  498. if err != nil {
  499. return nil, err
  500. }
  501. cp.NewRoot = newRootKey
  502. if err := tx.Commit(); err != nil {
  503. return nil, err
  504. }
  505. return &cp, nil
  506. }
  507. return nil, ErrKeyNotFound
  508. case NodeTypeMiddle:
  509. if path[i] {
  510. nextKey = n.ChildR
  511. siblings = append(siblings, n.ChildL)
  512. } else {
  513. nextKey = n.ChildL
  514. siblings = append(siblings, n.ChildR)
  515. }
  516. default:
  517. return nil, ErrInvalidNodeFound
  518. }
  519. }
  520. return nil, ErrKeyNotFound
  521. }
  522. // Delete removes the specified Key from the MerkleTree and updates the path
  523. // from the deleted key to the Root with the new values. This method removes
  524. // the key from the MerkleTree, but does not remove the old nodes from the
  525. // key-value database; this means that if the tree is accessed by an old Root
  526. // where the key was not deleted yet, the key will still exist. If is desired
  527. // to remove the key-values from the database that are not under the current
  528. // Root, an option could be to dump all the leaves (using mt.DumpLeafs) and
  529. // import them in a new MerkleTree in a new database (using
  530. // mt.ImportDumpedLeafs), but this will loose all the Root history of the
  531. // MerkleTree
  532. func (mt *MerkleTree) Delete(k *big.Int) error {
  533. // verify that the MerkleTree is writable
  534. if !mt.writable {
  535. return ErrNotWritable
  536. }
  537. // verify that k is valid and fits inside the Finite Field.
  538. if !cryptoUtils.CheckBigIntInField(k) {
  539. return errors.New("Key not inside the Finite Field")
  540. }
  541. tx, err := mt.db.NewTx()
  542. if err != nil {
  543. return err
  544. }
  545. mt.Lock()
  546. defer mt.Unlock()
  547. kHash := NewHashFromBigInt(k)
  548. path := getPath(mt.maxLevels, kHash[:])
  549. nextKey := mt.rootKey
  550. siblings := []*Hash{}
  551. for i := 0; i < mt.maxLevels; i++ {
  552. n, err := mt.GetNode(nextKey)
  553. if err != nil {
  554. return err
  555. }
  556. switch n.Type {
  557. case NodeTypeEmpty:
  558. return ErrKeyNotFound
  559. case NodeTypeLeaf:
  560. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  561. // remove and go up with the sibling
  562. err = mt.rmAndUpload(tx, path, kHash, siblings)
  563. return err
  564. }
  565. return ErrKeyNotFound
  566. case NodeTypeMiddle:
  567. if path[i] {
  568. nextKey = n.ChildR
  569. siblings = append(siblings, n.ChildL)
  570. } else {
  571. nextKey = n.ChildL
  572. siblings = append(siblings, n.ChildR)
  573. }
  574. default:
  575. return ErrInvalidNodeFound
  576. }
  577. }
  578. return ErrKeyNotFound
  579. }
  580. // rmAndUpload removes the key, and goes up until the root updating all the
  581. // nodes with the new values.
  582. func (mt *MerkleTree) rmAndUpload(tx Tx, path []bool, kHash *Hash, siblings []*Hash) error {
  583. if len(siblings) == 0 {
  584. mt.rootKey = &HashZero
  585. err := mt.setCurrentRoot(tx, mt.rootKey)
  586. if err != nil {
  587. return err
  588. }
  589. return tx.Commit()
  590. }
  591. toUpload := siblings[len(siblings)-1]
  592. if len(siblings) < 2 { //nolint:gomnd
  593. mt.rootKey = siblings[0]
  594. err := mt.setCurrentRoot(tx, mt.rootKey)
  595. if err != nil {
  596. return err
  597. }
  598. return tx.Commit()
  599. }
  600. for i := len(siblings) - 2; i >= 0; i-- { //nolint:gomnd
  601. if !bytes.Equal(siblings[i][:], HashZero[:]) {
  602. var newNode *Node
  603. if path[i] {
  604. newNode = NewNodeMiddle(siblings[i], toUpload)
  605. } else {
  606. newNode = NewNodeMiddle(toUpload, siblings[i])
  607. }
  608. _, err := mt.addNode(tx, newNode)
  609. if err != ErrNodeKeyAlreadyExists && err != nil {
  610. return err
  611. }
  612. // go up until the root
  613. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode,
  614. siblings[:i])
  615. if err != nil {
  616. return err
  617. }
  618. mt.rootKey = newRootKey
  619. err = mt.setCurrentRoot(tx, mt.rootKey)
  620. if err != nil {
  621. return err
  622. }
  623. break
  624. }
  625. // if i==0 (root position), stop and store the sibling of the
  626. // deleted leaf as root
  627. if i == 0 {
  628. mt.rootKey = toUpload
  629. err := mt.setCurrentRoot(tx, mt.rootKey)
  630. if err != nil {
  631. return err
  632. }
  633. break
  634. }
  635. }
  636. if err := tx.Commit(); err != nil {
  637. return err
  638. }
  639. return nil
  640. }
  641. // recalculatePathUntilRoot recalculates the nodes until the Root
  642. func (mt *MerkleTree) recalculatePathUntilRoot(tx Tx, path []bool, node *Node,
  643. siblings []*Hash) (*Hash, error) {
  644. for i := len(siblings) - 1; i >= 0; i-- {
  645. nodeKey, err := node.Key()
  646. if err != nil {
  647. return nil, err
  648. }
  649. if path[i] {
  650. node = NewNodeMiddle(siblings[i], nodeKey)
  651. } else {
  652. node = NewNodeMiddle(nodeKey, siblings[i])
  653. }
  654. _, err = mt.addNode(tx, node)
  655. if err != ErrNodeKeyAlreadyExists && err != nil {
  656. return nil, err
  657. }
  658. }
  659. // return last node added, which is the root
  660. nodeKey, err := node.Key()
  661. return nodeKey, err
  662. }
  663. // setCurrentRoot is a helper function to update current root in an open db
  664. // transaction.
  665. func (mt *MerkleTree) setCurrentRoot(tx Tx, hash *Hash) error {
  666. return tx.SetRoot(hash)
  667. }
  668. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  669. // tree; they are all the same and assumed to always exist.
  670. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  671. if bytes.Equal(key[:], HashZero[:]) {
  672. return NewNodeEmpty(), nil
  673. }
  674. n, err := mt.db.Get(key[:])
  675. if err != nil {
  676. return nil, err
  677. }
  678. return n, nil
  679. }
  680. // getPath returns the binary path, from the root to the leaf.
  681. func getPath(numLevels int, k []byte) []bool {
  682. path := make([]bool, numLevels)
  683. for n := 0; n < numLevels; n++ {
  684. path[n] = TestBit(k[:], uint(n))
  685. }
  686. return path
  687. }
  688. // NodeAux contains the auxiliary node used in a non-existence proof.
  689. type NodeAux struct {
  690. Key *Hash
  691. Value *Hash
  692. }
  693. // Proof defines the required elements for a MT proof of existence or
  694. // non-existence.
  695. type Proof struct {
  696. // existence indicates wether this is a proof of existence or
  697. // non-existence.
  698. Existence bool
  699. // depth indicates how deep in the tree the proof goes.
  700. depth uint
  701. // notempties is a bitmap of non-empty Siblings found in Siblings.
  702. notempties [ElemBytesLen - proofFlagsLen]byte
  703. // Siblings is a list of non-empty sibling keys.
  704. Siblings []*Hash
  705. NodeAux *NodeAux
  706. }
  707. // NewProofFromBytes parses a byte array into a Proof.
  708. func NewProofFromBytes(bs []byte) (*Proof, error) {
  709. if len(bs) < ElemBytesLen {
  710. return nil, ErrInvalidProofBytes
  711. }
  712. p := &Proof{}
  713. if (bs[0] & 0x01) == 0 {
  714. p.Existence = true
  715. }
  716. p.depth = uint(bs[1])
  717. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  718. siblingBytes := bs[ElemBytesLen:]
  719. sibIdx := 0
  720. for i := uint(0); i < p.depth; i++ {
  721. if TestBitBigEndian(p.notempties[:], i) {
  722. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  723. return nil, ErrInvalidProofBytes
  724. }
  725. var sib Hash
  726. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  727. p.Siblings = append(p.Siblings, &sib)
  728. sibIdx++
  729. }
  730. }
  731. if !p.Existence && ((bs[0] & 0x02) != 0) {
  732. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  733. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  734. if len(nodeAuxBytes) != 2*ElemBytesLen {
  735. return nil, ErrInvalidProofBytes
  736. }
  737. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  738. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  739. }
  740. return p, nil
  741. }
  742. // Bytes serializes a Proof into a byte array.
  743. func (p *Proof) Bytes() []byte {
  744. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  745. if p.NodeAux != nil {
  746. bsLen += 2 * ElemBytesLen //nolint:gomnd
  747. }
  748. bs := make([]byte, bsLen)
  749. if !p.Existence {
  750. bs[0] |= 0x01
  751. }
  752. bs[1] = byte(p.depth)
  753. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  754. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  755. for i, k := range p.Siblings {
  756. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  757. }
  758. if p.NodeAux != nil {
  759. bs[0] |= 0x02
  760. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  761. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  762. }
  763. return bs
  764. }
  765. // SiblingsFromProof returns all the siblings of the proof.
  766. func SiblingsFromProof(proof *Proof) []*Hash {
  767. sibIdx := 0
  768. siblings := []*Hash{}
  769. for lvl := 0; lvl < int(proof.depth); lvl++ {
  770. if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  771. siblings = append(siblings, proof.Siblings[sibIdx])
  772. sibIdx++
  773. } else {
  774. siblings = append(siblings, &HashZero)
  775. }
  776. }
  777. return siblings
  778. }
  779. // AllSiblings returns all the siblings of the proof.
  780. func (p *Proof) AllSiblings() []*Hash {
  781. return SiblingsFromProof(p)
  782. }
  783. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  784. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  785. // Add the rest of empty levels to the siblings
  786. for i := len(siblings); i < levels+1; i++ {
  787. siblings = append(siblings, &HashZero)
  788. }
  789. return siblings
  790. }
  791. // CircomProcessorProof defines the ProcessorProof compatible with circom. Is
  792. // the data of the proof between the transition from one state to another.
  793. type CircomProcessorProof struct {
  794. OldRoot *Hash `json:"oldRoot"`
  795. NewRoot *Hash `json:"newRoot"`
  796. Siblings []*Hash `json:"siblings"`
  797. OldKey *Hash `json:"oldKey"`
  798. OldValue *Hash `json:"oldValue"`
  799. NewKey *Hash `json:"newKey"`
  800. NewValue *Hash `json:"newValue"`
  801. IsOld0 bool `json:"isOld0"`
  802. // 0: NOP, 1: Update, 2: Insert, 3: Delete
  803. Fnc int `json:"fnc"`
  804. }
  805. // String returns a human readable string representation of the
  806. // CircomProcessorProof
  807. func (p CircomProcessorProof) String() string {
  808. buf := bytes.NewBufferString("{")
  809. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  810. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  811. fmt.Fprintf(buf, " Siblings: [\n ")
  812. for _, s := range p.Siblings {
  813. fmt.Fprintf(buf, "%v, ", s)
  814. }
  815. fmt.Fprintf(buf, "\n ],\n")
  816. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  817. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  818. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  819. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  820. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  821. fmt.Fprintf(buf, "}\n")
  822. return buf.String()
  823. }
  824. // CircomVerifierProof defines the VerifierProof compatible with circom. Is the
  825. // data of the proof that a certain leaf exists in the MerkleTree.
  826. type CircomVerifierProof struct {
  827. Root *Hash `json:"root"`
  828. Siblings []*Hash `json:"siblings"`
  829. OldKey *Hash `json:"oldKey"`
  830. OldValue *Hash `json:"oldValue"`
  831. IsOld0 bool `json:"isOld0"`
  832. Key *Hash `json:"key"`
  833. Value *Hash `json:"value"`
  834. Fnc int `json:"fnc"` // 0: inclusion, 1: non inclusion
  835. }
  836. // GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
  837. // key in the MerkleTree. If the rootKey is nil, the current merkletree root
  838. // is used.
  839. func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int,
  840. rootKey *Hash) (*CircomVerifierProof, error) {
  841. cp, err := mt.GenerateSCVerifierProof(k, rootKey)
  842. if err != nil {
  843. return nil, err
  844. }
  845. cp.Siblings = CircomSiblingsFromSiblings(cp.Siblings, mt.maxLevels)
  846. return cp, nil
  847. }
  848. // GenerateSCVerifierProof returns the CircomVerifierProof for a certain key in
  849. // the MerkleTree with the Siblings without the extra 0 needed at the circom
  850. // circuits, which makes it straight forward to verifiy inside a Smart
  851. // Contract. If the rootKey is nil, the current merkletree root is used.
  852. func (mt *MerkleTree) GenerateSCVerifierProof(k *big.Int,
  853. rootKey *Hash) (*CircomVerifierProof, error) {
  854. if rootKey == nil {
  855. rootKey = mt.Root()
  856. }
  857. p, v, err := mt.GenerateProof(k, rootKey)
  858. if err != nil && err != ErrKeyNotFound {
  859. return nil, err
  860. }
  861. var cp CircomVerifierProof
  862. cp.Root = rootKey
  863. cp.Siblings = p.AllSiblings()
  864. if p.NodeAux != nil {
  865. cp.OldKey = p.NodeAux.Key
  866. cp.OldValue = p.NodeAux.Value
  867. } else {
  868. cp.OldKey = &HashZero
  869. cp.OldValue = &HashZero
  870. }
  871. cp.Key = NewHashFromBigInt(k)
  872. cp.Value = NewHashFromBigInt(v)
  873. if p.Existence {
  874. cp.Fnc = 0 // inclusion
  875. } else {
  876. cp.Fnc = 1 // non inclusion
  877. }
  878. return &cp, nil
  879. }
  880. // GenerateProof generates the proof of existence (or non-existence) of an
  881. // Entry's hash Index for a Merkle Tree given the root.
  882. // If the rootKey is nil, the current merkletree root is used
  883. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof,
  884. *big.Int, error) {
  885. p := &Proof{}
  886. var siblingKey *Hash
  887. kHash := NewHashFromBigInt(k)
  888. path := getPath(mt.maxLevels, kHash[:])
  889. if rootKey == nil {
  890. rootKey = mt.Root()
  891. }
  892. nextKey := rootKey
  893. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  894. n, err := mt.GetNode(nextKey)
  895. if err != nil {
  896. return nil, nil, err
  897. }
  898. switch n.Type {
  899. case NodeTypeEmpty:
  900. return p, big.NewInt(0), nil
  901. case NodeTypeLeaf:
  902. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  903. p.Existence = true
  904. return p, n.Entry[1].BigInt(), nil
  905. }
  906. // We found a leaf whose entry didn't match hIndex
  907. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  908. return p, n.Entry[1].BigInt(), nil
  909. case NodeTypeMiddle:
  910. if path[p.depth] {
  911. nextKey = n.ChildR
  912. siblingKey = n.ChildL
  913. } else {
  914. nextKey = n.ChildL
  915. siblingKey = n.ChildR
  916. }
  917. default:
  918. return nil, nil, ErrInvalidNodeFound
  919. }
  920. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  921. SetBitBigEndian(p.notempties[:], uint(p.depth))
  922. p.Siblings = append(p.Siblings, siblingKey)
  923. }
  924. }
  925. return nil, nil, ErrKeyNotFound
  926. }
  927. // VerifyProof verifies the Merkle Proof for the entry and root.
  928. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  929. rootFromProof, err := RootFromProof(proof, k, v)
  930. if err != nil {
  931. return false
  932. }
  933. return bytes.Equal(rootKey[:], rootFromProof[:])
  934. }
  935. // RootFromProof calculates the root that would correspond to a tree whose
  936. // siblings are the ones in the proof with the leaf hashing to hIndex and
  937. // hValue.
  938. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  939. kHash := NewHashFromBigInt(k)
  940. vHash := NewHashFromBigInt(v)
  941. sibIdx := len(proof.Siblings) - 1
  942. var err error
  943. var midKey *Hash
  944. if proof.Existence {
  945. midKey, err = LeafKey(kHash, vHash)
  946. if err != nil {
  947. return nil, err
  948. }
  949. } else {
  950. if proof.NodeAux == nil {
  951. midKey = &HashZero
  952. } else {
  953. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  954. return nil,
  955. fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  956. }
  957. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  958. if err != nil {
  959. return nil, err
  960. }
  961. }
  962. }
  963. path := getPath(int(proof.depth), kHash[:])
  964. var siblingKey *Hash
  965. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  966. if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  967. siblingKey = proof.Siblings[sibIdx]
  968. sibIdx--
  969. } else {
  970. siblingKey = &HashZero
  971. }
  972. if path[lvl] {
  973. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  974. if err != nil {
  975. return nil, err
  976. }
  977. } else {
  978. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  979. if err != nil {
  980. return nil, err
  981. }
  982. }
  983. }
  984. return midKey, nil
  985. }
  986. // walk is a helper recursive function to iterate over all tree branches
  987. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  988. n, err := mt.GetNode(key)
  989. if err != nil {
  990. return err
  991. }
  992. switch n.Type {
  993. case NodeTypeEmpty:
  994. f(n)
  995. case NodeTypeLeaf:
  996. f(n)
  997. case NodeTypeMiddle:
  998. f(n)
  999. if err := mt.walk(n.ChildL, f); err != nil {
  1000. return err
  1001. }
  1002. if err := mt.walk(n.ChildR, f); err != nil {
  1003. return err
  1004. }
  1005. default:
  1006. return ErrInvalidNodeFound
  1007. }
  1008. return nil
  1009. }
  1010. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  1011. // if rootKey is nil, it will get the current RootKey of the current state of
  1012. // the MerkleTree. For each node, it calls the f function given in the
  1013. // parameters. See some examples of the Walk function usage in the
  1014. // merkletree.go and merkletree_test.go
  1015. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  1016. if rootKey == nil {
  1017. rootKey = mt.Root()
  1018. }
  1019. err := mt.walk(rootKey, f)
  1020. return err
  1021. }
  1022. // GraphViz uses Walk function to generate a string GraphViz representation of
  1023. // the tree and writes it to w
  1024. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  1025. fmt.Fprintf(w, `digraph hierarchy {
  1026. node [fontname=Monospace,fontsize=10,shape=box]
  1027. `)
  1028. cnt := 0
  1029. var errIn error
  1030. err := mt.Walk(rootKey, func(n *Node) {
  1031. k, err := n.Key()
  1032. if err != nil {
  1033. errIn = err
  1034. }
  1035. switch n.Type {
  1036. case NodeTypeEmpty:
  1037. case NodeTypeLeaf:
  1038. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  1039. case NodeTypeMiddle:
  1040. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  1041. emptyNodes := ""
  1042. for i := range lr {
  1043. if lr[i] == "0" {
  1044. lr[i] = fmt.Sprintf("empty%v", cnt)
  1045. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  1046. cnt++
  1047. }
  1048. }
  1049. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  1050. fmt.Fprint(w, emptyNodes)
  1051. default:
  1052. }
  1053. })
  1054. fmt.Fprintf(w, "}\n")
  1055. if errIn != nil {
  1056. return errIn
  1057. }
  1058. return err
  1059. }
  1060. // PrintGraphViz prints directly the GraphViz() output
  1061. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  1062. if rootKey == nil {
  1063. rootKey = mt.Root()
  1064. }
  1065. w := bytes.NewBufferString("")
  1066. fmt.Fprintf(w,
  1067. "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  1068. err := mt.GraphViz(w, nil)
  1069. if err != nil {
  1070. return err
  1071. }
  1072. fmt.Fprintf(w,
  1073. "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  1074. fmt.Println(w)
  1075. return nil
  1076. }
  1077. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  1078. // is given (nil), it uses the current Root of the MerkleTree.
  1079. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  1080. var b []byte
  1081. err := mt.Walk(rootKey, func(n *Node) {
  1082. if n.Type == NodeTypeLeaf {
  1083. l := n.Entry[0].Bytes()
  1084. r := n.Entry[1].Bytes()
  1085. b = append(b, append(l[:], r[:]...)...)
  1086. }
  1087. })
  1088. return b, err
  1089. }
  1090. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  1091. // from the DumpLeafs function.
  1092. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  1093. for i := 0; i < len(b); i += 64 {
  1094. lr := b[i : i+64]
  1095. lB, err := NewBigIntFromHashBytes(lr[:32])
  1096. if err != nil {
  1097. return err
  1098. }
  1099. rB, err := NewBigIntFromHashBytes(lr[32:])
  1100. if err != nil {
  1101. return err
  1102. }
  1103. err = mt.Add(lB, rB)
  1104. if err != nil {
  1105. return err
  1106. }
  1107. }
  1108. return nil
  1109. }