You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1195 lines
32 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "strings"
  10. "sync"
  11. "github.com/iden3/go-iden3-core/common"
  12. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  13. "github.com/iden3/go-merkletree/db"
  14. )
  15. const (
  16. // proofFlagsLen is the byte length of the flags in the proof header
  17. // (first 32 bytes).
  18. proofFlagsLen = 2
  19. // ElemBytesLen is the length of the Hash byte array
  20. ElemBytesLen = 32
  21. numCharPrint = 8
  22. )
  23. var (
  24. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  25. ErrNodeKeyAlreadyExists = errors.New("key already exists")
  26. // ErrKeyNotFound is used when a key is not found in the MerkleTree.
  27. ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
  28. // ErrNodeBytesBadSize is used when the data of a node has an incorrect
  29. // size and can't be parsed.
  30. ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
  31. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  32. // maximum level.
  33. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  34. // ErrInvalidNodeFound is used when an invalid node is found and can't
  35. // be parsed.
  36. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  37. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  38. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  39. // ErrInvalidDBValue is used when a value in the key value DB is
  40. // invalid (for example, it doen't contain a byte header and a []byte
  41. // body of at least len=1.
  42. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  43. // ErrEntryIndexAlreadyExists is used when the entry index already
  44. // exists in the tree.
  45. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  46. // ErrNotWritable is used when the MerkleTree is not writable and a
  47. // write function is called
  48. ErrNotWritable = errors.New("Merkle Tree not writable")
  49. dbKeyRootNode = []byte("currentroot")
  50. // HashZero is used at Empty nodes
  51. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  52. 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  53. )
  54. // Hash is the generic type stored in the MerkleTree
  55. type Hash [32]byte
  56. // MarshalText implements the marshaler for the Hash type
  57. func (h Hash) MarshalText() ([]byte, error) {
  58. return []byte(h.BigInt().String()), nil
  59. }
  60. // UnmarshalText implements the unmarshaler for the Hash type
  61. func (h *Hash) UnmarshalText(b []byte) error {
  62. ha, err := NewHashFromString(string(b))
  63. copy(h[:], ha[:])
  64. return err
  65. }
  66. // String returns decimal representation in string format of the Hash
  67. func (h Hash) String() string {
  68. s := h.BigInt().String()
  69. if len(s) < numCharPrint {
  70. return s
  71. }
  72. return s[0:numCharPrint] + "..."
  73. }
  74. // Hex returns the hexadecimal representation of the Hash
  75. func (h Hash) Hex() string {
  76. return hex.EncodeToString(h[:])
  77. // alternatively equivalent, but with too extra steps:
  78. // bRaw := h.BigInt().Bytes()
  79. // b := [32]byte{}
  80. // copy(b[:], common.SwapEndianness(bRaw[:]))
  81. // return hex.EncodeToString(b[:])
  82. }
  83. // BigInt returns the *big.Int representation of the *Hash
  84. func (h *Hash) BigInt() *big.Int {
  85. if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
  86. return big.NewInt(0)
  87. }
  88. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  89. }
  90. // Bytes returns the []byte representation of the *Hash, which always is 32
  91. // bytes length.
  92. func (h *Hash) Bytes() []byte {
  93. bi := new(big.Int).SetBytes(h[:]).Bytes()
  94. b := [32]byte{}
  95. copy(b[:], common.SwapEndianness(bi[:]))
  96. return b[:]
  97. }
  98. // NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
  99. // endianness in the process. This is the intended method to get a *big.Int
  100. // from a byte array that previously has ben generated by the Hash.Bytes()
  101. // method.
  102. func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
  103. if len(b) != ElemBytesLen {
  104. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes",
  105. len(b))
  106. }
  107. bi := new(big.Int).SetBytes(b[:ElemBytesLen])
  108. if !cryptoUtils.CheckBigIntInField(bi) {
  109. return nil,
  110. fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
  111. }
  112. return bi, nil
  113. }
  114. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  115. func NewHashFromBigInt(b *big.Int) *Hash {
  116. r := &Hash{}
  117. copy(r[:], common.SwapEndianness(b.Bytes()))
  118. return r
  119. }
  120. // NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
  121. // in the process. This is the intended method to get a *Hash from a byte array
  122. // that previously has ben generated by the Hash.Bytes() method.
  123. func NewHashFromBytes(b []byte) (*Hash, error) {
  124. if len(b) != ElemBytesLen {
  125. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes",
  126. len(b))
  127. }
  128. var h Hash
  129. copy(h[:], common.SwapEndianness(b))
  130. return &h, nil
  131. }
  132. // NewHashFromHex returns a *Hash representation of the given hex string
  133. func NewHashFromHex(h string) (*Hash, error) {
  134. h = strings.TrimPrefix(h, "0x")
  135. b, err := hex.DecodeString(h)
  136. if err != nil {
  137. return nil, err
  138. }
  139. return NewHashFromBytes(common.SwapEndianness(b[:]))
  140. }
  141. // NewHashFromString returns a *Hash representation of the given decimal string
  142. func NewHashFromString(s string) (*Hash, error) {
  143. bi, ok := new(big.Int).SetString(s, 10)
  144. if !ok {
  145. return nil, fmt.Errorf("Can not parse string to Hash")
  146. }
  147. return NewHashFromBigInt(bi), nil
  148. }
  149. // MerkleTree is the struct with the main elements of the MerkleTree
  150. type MerkleTree struct {
  151. sync.RWMutex
  152. db db.Storage
  153. rootKey *Hash
  154. writable bool
  155. maxLevels int
  156. }
  157. // NewMerkleTree loads a new Merkletree. If in the sotrage already exists one
  158. // will open that one, if not, will create a new one.
  159. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  160. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  161. root, err := mt.dbGetRoot()
  162. if err == db.ErrNotFound {
  163. tx, err := mt.db.NewTx()
  164. if err != nil {
  165. return nil, err
  166. }
  167. mt.rootKey = &HashZero
  168. err = tx.Put(dbKeyRootNode, mt.rootKey[:])
  169. if err != nil {
  170. return nil, err
  171. }
  172. err = tx.Commit()
  173. if err != nil {
  174. return nil, err
  175. }
  176. return &mt, nil
  177. } else if err != nil {
  178. return nil, err
  179. }
  180. mt.rootKey = root
  181. return &mt, nil
  182. }
  183. func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
  184. v, err := mt.db.Get(dbKeyRootNode)
  185. if err != nil {
  186. return nil, err
  187. }
  188. var root Hash
  189. // Skip first byte which contains the NodeType
  190. copy(root[:], v[1:])
  191. return &root, nil
  192. }
  193. // DB returns the MerkleTree.DB()
  194. func (mt *MerkleTree) DB() db.Storage {
  195. return mt.db
  196. }
  197. // Root returns the MerkleRoot
  198. func (mt *MerkleTree) Root() *Hash {
  199. return mt.rootKey
  200. }
  201. // MaxLevels returns the MT maximum level
  202. func (mt *MerkleTree) MaxLevels() int {
  203. return mt.maxLevels
  204. }
  205. // Snapshot returns a read-only copy of the MerkleTree
  206. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  207. mt.RLock()
  208. defer mt.RUnlock()
  209. _, err := mt.GetNode(rootKey)
  210. if err != nil {
  211. return nil, err
  212. }
  213. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  214. }
  215. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the
  216. // path from the Root to the Leaf.
  217. func (mt *MerkleTree) Add(k, v *big.Int) error {
  218. // verify that the MerkleTree is writable
  219. if !mt.writable {
  220. return ErrNotWritable
  221. }
  222. // verfy that k & v are valid and fit inside the Finite Field.
  223. if !cryptoUtils.CheckBigIntInField(k) {
  224. return errors.New("Key not inside the Finite Field")
  225. }
  226. if !cryptoUtils.CheckBigIntInField(v) {
  227. return errors.New("Value not inside the Finite Field")
  228. }
  229. tx, err := mt.db.NewTx()
  230. if err != nil {
  231. return err
  232. }
  233. mt.Lock()
  234. defer mt.Unlock()
  235. kHash := NewHashFromBigInt(k)
  236. vHash := NewHashFromBigInt(v)
  237. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  238. path := getPath(mt.maxLevels, kHash[:])
  239. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  240. if err != nil {
  241. return err
  242. }
  243. mt.rootKey = newRootKey
  244. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  245. if err != nil {
  246. return err
  247. }
  248. if err := tx.Commit(); err != nil {
  249. return err
  250. }
  251. return nil
  252. }
  253. // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
  254. func (mt *MerkleTree) AddAndGetCircomProof(k,
  255. v *big.Int) (*CircomProcessorProof, error) {
  256. var cp CircomProcessorProof
  257. cp.Fnc = 2
  258. cp.OldRoot = mt.rootKey
  259. gettedK, gettedV, _, err := mt.Get(k)
  260. if err != nil && err != ErrKeyNotFound {
  261. return nil, err
  262. }
  263. cp.OldKey = NewHashFromBigInt(gettedK)
  264. cp.OldValue = NewHashFromBigInt(gettedV)
  265. if bytes.Equal(cp.OldKey[:], HashZero[:]) {
  266. cp.IsOld0 = true
  267. }
  268. _, _, siblings, err := mt.Get(k)
  269. if err != nil && err != ErrKeyNotFound {
  270. return nil, err
  271. }
  272. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  273. err = mt.Add(k, v)
  274. if err != nil {
  275. return nil, err
  276. }
  277. cp.NewKey = NewHashFromBigInt(k)
  278. cp.NewValue = NewHashFromBigInt(v)
  279. cp.NewRoot = mt.rootKey
  280. return &cp, nil
  281. }
  282. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  283. // from newLeaf, at which point both leafs are stored, all while updating the
  284. // path.
  285. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node, lvl int,
  286. pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  287. if lvl > mt.maxLevels-2 {
  288. return nil, ErrReachedMaxLevel
  289. }
  290. var newNodeMiddle *Node
  291. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  292. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1,
  293. pathNewLeaf, pathOldLeaf)
  294. if err != nil {
  295. return nil, err
  296. }
  297. if pathNewLeaf[lvl] { // go right
  298. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey)
  299. } else { // go left
  300. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero)
  301. }
  302. return mt.addNode(tx, newNodeMiddle)
  303. }
  304. oldLeafKey, err := oldLeaf.Key()
  305. if err != nil {
  306. return nil, err
  307. }
  308. newLeafKey, err := newLeaf.Key()
  309. if err != nil {
  310. return nil, err
  311. }
  312. if pathNewLeaf[lvl] {
  313. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  314. } else {
  315. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  316. }
  317. // We can add newLeaf now. We don't need to add oldLeaf because it's
  318. // already in the tree.
  319. _, err = mt.addNode(tx, newLeaf)
  320. if err != nil {
  321. return nil, err
  322. }
  323. return mt.addNode(tx, newNodeMiddle)
  324. }
  325. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  326. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  327. lvl int, path []bool) (*Hash, error) {
  328. var err error
  329. var nextKey *Hash
  330. if lvl > mt.maxLevels-1 {
  331. return nil, ErrReachedMaxLevel
  332. }
  333. n, err := mt.GetNode(key)
  334. if err != nil {
  335. return nil, err
  336. }
  337. switch n.Type {
  338. case NodeTypeEmpty:
  339. // We can add newLeaf now
  340. return mt.addNode(tx, newLeaf)
  341. case NodeTypeLeaf:
  342. nKey := n.Entry[0]
  343. // Check if leaf node found contains the leaf node we are
  344. // trying to add
  345. newLeafKey := newLeaf.Entry[0]
  346. if bytes.Equal(nKey[:], newLeafKey[:]) {
  347. return nil, ErrEntryIndexAlreadyExists
  348. }
  349. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  350. // We need to push newLeaf down until its path diverges from
  351. // n's path
  352. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  353. case NodeTypeMiddle:
  354. // We need to go deeper, continue traversing the tree, left or
  355. // right depending on path
  356. var newNodeMiddle *Node
  357. if path[lvl] { // go right
  358. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1,
  359. path)
  360. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  361. } else { // go left
  362. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1,
  363. path)
  364. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  365. }
  366. if err != nil {
  367. return nil, err
  368. }
  369. // Update the node to reflect the modified child
  370. return mt.addNode(tx, newNodeMiddle)
  371. default:
  372. return nil, ErrInvalidNodeFound
  373. }
  374. }
  375. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  376. // they are all the same and assumed to always exist.
  377. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  378. // verify that the MerkleTree is writable
  379. if !mt.writable {
  380. return nil, ErrNotWritable
  381. }
  382. if n.Type == NodeTypeEmpty {
  383. return n.Key()
  384. }
  385. k, err := n.Key()
  386. if err != nil {
  387. return nil, err
  388. }
  389. v := n.Value()
  390. // Check that the node key doesn't already exist
  391. if _, err := tx.Get(k[:]); err == nil {
  392. return nil, ErrNodeKeyAlreadyExists
  393. }
  394. err = tx.Put(k[:], v)
  395. return k, err
  396. }
  397. // updateNode updates an existing node in the MT. Empty nodes are not stored
  398. // in the tree; they are all the same and assumed to always exist.
  399. func (mt *MerkleTree) updateNode(tx db.Tx, n *Node) (*Hash, error) {
  400. // verify that the MerkleTree is writable
  401. if !mt.writable {
  402. return nil, ErrNotWritable
  403. }
  404. if n.Type == NodeTypeEmpty {
  405. return n.Key()
  406. }
  407. k, err := n.Key()
  408. if err != nil {
  409. return nil, err
  410. }
  411. v := n.Value()
  412. err = tx.Put(k[:], v)
  413. return k, err
  414. }
  415. // Get returns the value of the leaf for the given key
  416. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  417. // verfy that k is valid and fit inside the Finite Field.
  418. if !cryptoUtils.CheckBigIntInField(k) {
  419. return nil, nil, nil,
  420. errors.New("Key not inside the Finite Field")
  421. }
  422. kHash := NewHashFromBigInt(k)
  423. path := getPath(mt.maxLevels, kHash[:])
  424. nextKey := mt.rootKey
  425. var siblings []*Hash
  426. for i := 0; i < mt.maxLevels; i++ {
  427. n, err := mt.GetNode(nextKey)
  428. if err != nil {
  429. return nil, nil, nil, err
  430. }
  431. switch n.Type {
  432. case NodeTypeEmpty:
  433. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  434. case NodeTypeLeaf:
  435. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  436. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  437. }
  438. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  439. case NodeTypeMiddle:
  440. if path[i] {
  441. nextKey = n.ChildR
  442. siblings = append(siblings, n.ChildL)
  443. } else {
  444. nextKey = n.ChildL
  445. siblings = append(siblings, n.ChildR)
  446. }
  447. default:
  448. return nil, nil, nil, ErrInvalidNodeFound
  449. }
  450. }
  451. return nil, nil, nil, ErrReachedMaxLevel
  452. }
  453. // Update updates the value of a specified key in the MerkleTree, and updates
  454. // the path from the leaf to the Root with the new values. Returns the
  455. // CircomProcessorProof.
  456. func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
  457. // verify that the MerkleTree is writable
  458. if !mt.writable {
  459. return nil, ErrNotWritable
  460. }
  461. // verfy that k & are valid and fit inside the Finite Field.
  462. if !cryptoUtils.CheckBigIntInField(k) {
  463. return nil, errors.New("Key not inside the Finite Field")
  464. }
  465. if !cryptoUtils.CheckBigIntInField(v) {
  466. return nil, errors.New("Key not inside the Finite Field")
  467. }
  468. tx, err := mt.db.NewTx()
  469. if err != nil {
  470. return nil, err
  471. }
  472. mt.Lock()
  473. defer mt.Unlock()
  474. kHash := NewHashFromBigInt(k)
  475. vHash := NewHashFromBigInt(v)
  476. path := getPath(mt.maxLevels, kHash[:])
  477. var cp CircomProcessorProof
  478. cp.Fnc = 1
  479. cp.OldRoot = mt.rootKey
  480. cp.OldKey = kHash
  481. cp.NewKey = kHash
  482. cp.NewValue = vHash
  483. nextKey := mt.rootKey
  484. var siblings []*Hash
  485. for i := 0; i < mt.maxLevels; i++ {
  486. n, err := mt.GetNode(nextKey)
  487. if err != nil {
  488. return nil, err
  489. }
  490. switch n.Type {
  491. case NodeTypeEmpty:
  492. return nil, ErrKeyNotFound
  493. case NodeTypeLeaf:
  494. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  495. cp.OldValue = n.Entry[1]
  496. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  497. // update leaf and upload to the root
  498. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  499. _, err := mt.updateNode(tx, newNodeLeaf)
  500. if err != nil {
  501. return nil, err
  502. }
  503. newRootKey, err :=
  504. mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
  505. if err != nil {
  506. return nil, err
  507. }
  508. mt.rootKey = newRootKey
  509. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  510. if err != nil {
  511. return nil, err
  512. }
  513. cp.NewRoot = newRootKey
  514. if err := tx.Commit(); err != nil {
  515. return nil, err
  516. }
  517. return &cp, nil
  518. }
  519. return nil, ErrKeyNotFound
  520. case NodeTypeMiddle:
  521. if path[i] {
  522. nextKey = n.ChildR
  523. siblings = append(siblings, n.ChildL)
  524. } else {
  525. nextKey = n.ChildL
  526. siblings = append(siblings, n.ChildR)
  527. }
  528. default:
  529. return nil, ErrInvalidNodeFound
  530. }
  531. }
  532. return nil, ErrKeyNotFound
  533. }
  534. // Delete removes the specified Key from the MerkleTree and updates the path
  535. // from the deleted key to the Root with the new values. This method removes
  536. // the key from the MerkleTree, but does not remove the old nodes from the
  537. // key-value database; this means that if the tree is accessed by an old Root
  538. // where the key was not deleted yet, the key will still exist. If is desired
  539. // to remove the key-values from the database that are not under the current
  540. // Root, an option could be to dump all the leafs (using mt.DumpLeafs) and
  541. // import them in a new MerkleTree in a new database (using
  542. // mt.ImportDumpedLeafs), but this will loose all the Root history of the
  543. // MerkleTree
  544. func (mt *MerkleTree) Delete(k *big.Int) error {
  545. // verify that the MerkleTree is writable
  546. if !mt.writable {
  547. return ErrNotWritable
  548. }
  549. // verfy that k is valid and fit inside the Finite Field.
  550. if !cryptoUtils.CheckBigIntInField(k) {
  551. return errors.New("Key not inside the Finite Field")
  552. }
  553. tx, err := mt.db.NewTx()
  554. if err != nil {
  555. return err
  556. }
  557. mt.Lock()
  558. defer mt.Unlock()
  559. kHash := NewHashFromBigInt(k)
  560. path := getPath(mt.maxLevels, kHash[:])
  561. nextKey := mt.rootKey
  562. var siblings []*Hash
  563. for i := 0; i < mt.maxLevels; i++ {
  564. n, err := mt.GetNode(nextKey)
  565. if err != nil {
  566. return err
  567. }
  568. switch n.Type {
  569. case NodeTypeEmpty:
  570. return ErrKeyNotFound
  571. case NodeTypeLeaf:
  572. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  573. // remove and go up with the sibling
  574. err = mt.rmAndUpload(tx, path, kHash, siblings)
  575. return err
  576. }
  577. return ErrKeyNotFound
  578. case NodeTypeMiddle:
  579. if path[i] {
  580. nextKey = n.ChildR
  581. siblings = append(siblings, n.ChildL)
  582. } else {
  583. nextKey = n.ChildL
  584. siblings = append(siblings, n.ChildR)
  585. }
  586. default:
  587. return ErrInvalidNodeFound
  588. }
  589. }
  590. return ErrKeyNotFound
  591. }
  592. // rmAndUpload removes the key, and goes up until the root updating all the
  593. // nodes with the new values.
  594. func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
  595. if len(siblings) == 0 {
  596. mt.rootKey = &HashZero
  597. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  598. if err != nil {
  599. return err
  600. }
  601. return tx.Commit()
  602. }
  603. toUpload := siblings[len(siblings)-1]
  604. if len(siblings) < 2 { //nolint:gomnd
  605. mt.rootKey = siblings[0]
  606. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  607. if err != nil {
  608. return err
  609. }
  610. return tx.Commit()
  611. }
  612. for i := len(siblings) - 2; i >= 0; i-- { //nolint:gomnd
  613. if !bytes.Equal(siblings[i][:], HashZero[:]) {
  614. var newNode *Node
  615. if path[i] {
  616. newNode = NewNodeMiddle(siblings[i], toUpload)
  617. } else {
  618. newNode = NewNodeMiddle(toUpload, siblings[i])
  619. }
  620. _, err := mt.addNode(tx, newNode)
  621. if err != ErrNodeKeyAlreadyExists && err != nil {
  622. return err
  623. }
  624. // go up until the root
  625. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode,
  626. siblings[:i])
  627. if err != nil {
  628. return err
  629. }
  630. mt.rootKey = newRootKey
  631. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  632. if err != nil {
  633. return err
  634. }
  635. break
  636. }
  637. // if i==0 (root position), stop and store the sibling of the
  638. // deleted leaf as root
  639. if i == 0 {
  640. mt.rootKey = toUpload
  641. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  642. if err != nil {
  643. return err
  644. }
  645. break
  646. }
  647. }
  648. if err := tx.Commit(); err != nil {
  649. return err
  650. }
  651. return nil
  652. }
  653. // recalculatePathUntilRoot recalculates the nodes until the Root
  654. func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node,
  655. siblings []*Hash) (*Hash, error) {
  656. for i := len(siblings) - 1; i >= 0; i-- {
  657. nodeKey, err := node.Key()
  658. if err != nil {
  659. return nil, err
  660. }
  661. if path[i] {
  662. node = NewNodeMiddle(siblings[i], nodeKey)
  663. } else {
  664. node = NewNodeMiddle(nodeKey, siblings[i])
  665. }
  666. _, err = mt.addNode(tx, node)
  667. if err != ErrNodeKeyAlreadyExists && err != nil {
  668. return nil, err
  669. }
  670. }
  671. // return last node added, which is the root
  672. nodeKey, err := node.Key()
  673. return nodeKey, err
  674. }
  675. // dbInsert is a helper function to insert a node into a key in an open db
  676. // transaction.
  677. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) error {
  678. v := append([]byte{byte(t)}, data...)
  679. return tx.Put(k, v)
  680. }
  681. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  682. // tree; they are all the same and assumed to always exist.
  683. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  684. if bytes.Equal(key[:], HashZero[:]) {
  685. return NewNodeEmpty(), nil
  686. }
  687. nBytes, err := mt.db.Get(key[:])
  688. if err != nil {
  689. return nil, err
  690. }
  691. return NewNodeFromBytes(nBytes)
  692. }
  693. // getPath returns the binary path, from the root to the leaf.
  694. func getPath(numLevels int, k []byte) []bool {
  695. path := make([]bool, numLevels)
  696. for n := 0; n < numLevels; n++ {
  697. path[n] = common.TestBit(k[:], uint(n))
  698. }
  699. return path
  700. }
  701. // NodeAux contains the auxiliary node used in a non-existence proof.
  702. type NodeAux struct {
  703. Key *Hash
  704. Value *Hash
  705. }
  706. // Proof defines the required elements for a MT proof of existence or
  707. // non-existence.
  708. type Proof struct {
  709. // existence indicates wether this is a proof of existence or
  710. // non-existence.
  711. Existence bool
  712. // depth indicates how deep in the tree the proof goes.
  713. depth uint
  714. // notempties is a bitmap of non-empty Siblings found in Siblings.
  715. notempties [ElemBytesLen - proofFlagsLen]byte
  716. // Siblings is a list of non-empty sibling keys.
  717. Siblings []*Hash
  718. NodeAux *NodeAux
  719. }
  720. // NewProofFromBytes parses a byte array into a Proof.
  721. func NewProofFromBytes(bs []byte) (*Proof, error) {
  722. if len(bs) < ElemBytesLen {
  723. return nil, ErrInvalidProofBytes
  724. }
  725. p := &Proof{}
  726. if (bs[0] & 0x01) == 0 {
  727. p.Existence = true
  728. }
  729. p.depth = uint(bs[1])
  730. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  731. siblingBytes := bs[ElemBytesLen:]
  732. sibIdx := 0
  733. for i := uint(0); i < p.depth; i++ {
  734. if common.TestBitBigEndian(p.notempties[:], i) {
  735. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  736. return nil, ErrInvalidProofBytes
  737. }
  738. var sib Hash
  739. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  740. p.Siblings = append(p.Siblings, &sib)
  741. sibIdx++
  742. }
  743. }
  744. if !p.Existence && ((bs[0] & 0x02) != 0) {
  745. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  746. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  747. if len(nodeAuxBytes) != 2*ElemBytesLen {
  748. return nil, ErrInvalidProofBytes
  749. }
  750. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  751. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  752. }
  753. return p, nil
  754. }
  755. // Bytes serializes a Proof into a byte array.
  756. func (p *Proof) Bytes() []byte {
  757. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  758. if p.NodeAux != nil {
  759. bsLen += 2 * ElemBytesLen //nolint:gomnd
  760. }
  761. bs := make([]byte, bsLen)
  762. if !p.Existence {
  763. bs[0] |= 0x01
  764. }
  765. bs[1] = byte(p.depth)
  766. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  767. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  768. for i, k := range p.Siblings {
  769. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  770. }
  771. if p.NodeAux != nil {
  772. bs[0] |= 0x02
  773. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  774. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  775. }
  776. return bs
  777. }
  778. // SiblingsFromProof returns all the siblings of the proof.
  779. func SiblingsFromProof(proof *Proof) []*Hash {
  780. sibIdx := 0
  781. var siblings []*Hash
  782. for lvl := 0; lvl < int(proof.depth); lvl++ {
  783. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  784. siblings = append(siblings, proof.Siblings[sibIdx])
  785. sibIdx++
  786. } else {
  787. siblings = append(siblings, &HashZero)
  788. }
  789. }
  790. return siblings
  791. }
  792. // AllSiblings returns all the siblings of the proof.
  793. func (p *Proof) AllSiblings() []*Hash {
  794. return SiblingsFromProof(p)
  795. }
  796. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  797. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  798. // Add the rest of empty levels to the siblings
  799. for i := len(siblings); i < levels+1; i++ {
  800. siblings = append(siblings, &HashZero)
  801. }
  802. return siblings
  803. }
  804. // CircomProcessorProof defines the ProcessorProof compatible with circom. Is
  805. // the data of the proof between the transition from one state to another.
  806. type CircomProcessorProof struct {
  807. OldRoot *Hash `json:"oldRoot"`
  808. NewRoot *Hash `json:"newRoot"`
  809. Siblings []*Hash `json:"siblings"`
  810. OldKey *Hash `json:"oldKey"`
  811. OldValue *Hash `json:"oldValue"`
  812. NewKey *Hash `json:"newKey"`
  813. NewValue *Hash `json:"newValue"`
  814. IsOld0 bool `json:"isOld0"`
  815. // 0: NOP, 1: Update, 2: Insert, 3: Delete
  816. Fnc int `json:"fnc"`
  817. }
  818. // String returns a human readable string representation of the
  819. // CircomProcessorProof
  820. func (p CircomProcessorProof) String() string {
  821. buf := bytes.NewBufferString("{")
  822. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  823. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  824. fmt.Fprintf(buf, " Siblings: [\n ")
  825. for _, s := range p.Siblings {
  826. fmt.Fprintf(buf, "%v, ", s)
  827. }
  828. fmt.Fprintf(buf, "\n ],\n")
  829. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  830. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  831. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  832. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  833. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  834. fmt.Fprintf(buf, "}\n")
  835. return buf.String()
  836. }
  837. // CircomVerifierProof defines the VerifierProof compatible with circom. Is the
  838. // data of the proof that a certain leaf exists in the MerkleTree.
  839. type CircomVerifierProof struct {
  840. Root *Hash `json:"root"`
  841. Siblings []*Hash `json:"siblings"`
  842. OldKey *Hash `json:"oldKey"`
  843. OldValue *Hash `json:"oldValue"`
  844. IsOld0 bool `json:"isOld0"`
  845. Key *Hash `json:"key"`
  846. Value *Hash `json:"value"`
  847. Fnc int `json:"fnc"` // 0: inclusion, 1: non inclusion
  848. }
  849. // GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
  850. // key in the MerkleTree. If the rootKey is nil, the current merkletree root
  851. // is used.
  852. func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int,
  853. rootKey *Hash) (*CircomVerifierProof, error) {
  854. if rootKey == nil {
  855. rootKey = mt.Root()
  856. }
  857. p, v, err := mt.GenerateProof(k, rootKey)
  858. if err != nil && err != ErrKeyNotFound {
  859. return nil, err
  860. }
  861. var cp CircomVerifierProof
  862. cp.Root = rootKey
  863. cp.Siblings = CircomSiblingsFromSiblings(p.AllSiblings(), mt.maxLevels)
  864. cp.OldKey = &HashZero
  865. cp.OldValue = &HashZero
  866. cp.Key = NewHashFromBigInt(k)
  867. cp.Value = NewHashFromBigInt(v)
  868. if p.Existence {
  869. cp.Fnc = 0 // inclusion
  870. } else {
  871. cp.Fnc = 1 // non inclusion
  872. }
  873. return &cp, nil
  874. }
  875. // GenerateProof generates the proof of existence (or non-existence) of an
  876. // Entry's hash Index for a Merkle Tree given the root.
  877. // If the rootKey is nil, the current merkletree root is used
  878. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof,
  879. *big.Int, error) {
  880. p := &Proof{}
  881. var siblingKey *Hash
  882. kHash := NewHashFromBigInt(k)
  883. path := getPath(mt.maxLevels, kHash[:])
  884. if rootKey == nil {
  885. rootKey = mt.Root()
  886. }
  887. nextKey := rootKey
  888. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  889. n, err := mt.GetNode(nextKey)
  890. if err != nil {
  891. return nil, nil, err
  892. }
  893. switch n.Type {
  894. case NodeTypeEmpty:
  895. return p, big.NewInt(0), nil
  896. case NodeTypeLeaf:
  897. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  898. p.Existence = true
  899. return p, n.Entry[1].BigInt(), nil
  900. }
  901. // We found a leaf whose entry didn't match hIndex
  902. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  903. return p, n.Entry[1].BigInt(), nil
  904. case NodeTypeMiddle:
  905. if path[p.depth] {
  906. nextKey = n.ChildR
  907. siblingKey = n.ChildL
  908. } else {
  909. nextKey = n.ChildL
  910. siblingKey = n.ChildR
  911. }
  912. default:
  913. return nil, nil, ErrInvalidNodeFound
  914. }
  915. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  916. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  917. p.Siblings = append(p.Siblings, siblingKey)
  918. }
  919. }
  920. return nil, nil, ErrKeyNotFound
  921. }
  922. // VerifyProof verifies the Merkle Proof for the entry and root.
  923. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  924. rootFromProof, err := RootFromProof(proof, k, v)
  925. if err != nil {
  926. return false
  927. }
  928. return bytes.Equal(rootKey[:], rootFromProof[:])
  929. }
  930. // RootFromProof calculates the root that would correspond to a tree whose
  931. // siblings are the ones in the proof with the leaf hashing to hIndex and
  932. // hValue.
  933. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  934. kHash := NewHashFromBigInt(k)
  935. vHash := NewHashFromBigInt(v)
  936. sibIdx := len(proof.Siblings) - 1
  937. var err error
  938. var midKey *Hash
  939. if proof.Existence {
  940. midKey, err = LeafKey(kHash, vHash)
  941. if err != nil {
  942. return nil, err
  943. }
  944. } else {
  945. if proof.NodeAux == nil {
  946. midKey = &HashZero
  947. } else {
  948. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  949. return nil,
  950. fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  951. }
  952. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  953. if err != nil {
  954. return nil, err
  955. }
  956. }
  957. }
  958. path := getPath(int(proof.depth), kHash[:])
  959. var siblingKey *Hash
  960. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  961. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  962. siblingKey = proof.Siblings[sibIdx]
  963. sibIdx--
  964. } else {
  965. siblingKey = &HashZero
  966. }
  967. if path[lvl] {
  968. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  969. if err != nil {
  970. return nil, err
  971. }
  972. } else {
  973. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  974. if err != nil {
  975. return nil, err
  976. }
  977. }
  978. }
  979. return midKey, nil
  980. }
  981. // walk is a helper recursive function to iterate over all tree branches
  982. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  983. n, err := mt.GetNode(key)
  984. if err != nil {
  985. return err
  986. }
  987. switch n.Type {
  988. case NodeTypeEmpty:
  989. f(n)
  990. case NodeTypeLeaf:
  991. f(n)
  992. case NodeTypeMiddle:
  993. f(n)
  994. if err := mt.walk(n.ChildL, f); err != nil {
  995. return err
  996. }
  997. if err := mt.walk(n.ChildR, f); err != nil {
  998. return err
  999. }
  1000. default:
  1001. return ErrInvalidNodeFound
  1002. }
  1003. return nil
  1004. }
  1005. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  1006. // if rootKey is nil, it will get the current RootKey of the current state of
  1007. // the MerkleTree. For each node, it calls the f function given in the
  1008. // parameters. See some examples of the Walk function usage in the
  1009. // merkletree.go and merkletree_test.go
  1010. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  1011. if rootKey == nil {
  1012. rootKey = mt.Root()
  1013. }
  1014. err := mt.walk(rootKey, f)
  1015. return err
  1016. }
  1017. // GraphViz uses Walk function to generate a string GraphViz representation of
  1018. // the tree and writes it to w
  1019. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  1020. fmt.Fprintf(w, `digraph hierarchy {
  1021. node [fontname=Monospace,fontsize=10,shape=box]
  1022. `)
  1023. cnt := 0
  1024. var errIn error
  1025. err := mt.Walk(rootKey, func(n *Node) {
  1026. k, err := n.Key()
  1027. if err != nil {
  1028. errIn = err
  1029. }
  1030. switch n.Type {
  1031. case NodeTypeEmpty:
  1032. case NodeTypeLeaf:
  1033. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  1034. case NodeTypeMiddle:
  1035. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  1036. emptyNodes := ""
  1037. for i := range lr {
  1038. if lr[i] == "0" {
  1039. lr[i] = fmt.Sprintf("empty%v", cnt)
  1040. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  1041. cnt++
  1042. }
  1043. }
  1044. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  1045. fmt.Fprint(w, emptyNodes)
  1046. default:
  1047. }
  1048. })
  1049. fmt.Fprintf(w, "}\n")
  1050. if errIn != nil {
  1051. return errIn
  1052. }
  1053. return err
  1054. }
  1055. // PrintGraphViz prints directly the GraphViz() output
  1056. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  1057. if rootKey == nil {
  1058. rootKey = mt.Root()
  1059. }
  1060. w := bytes.NewBufferString("")
  1061. fmt.Fprintf(w,
  1062. "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  1063. err := mt.GraphViz(w, nil)
  1064. if err != nil {
  1065. return err
  1066. }
  1067. fmt.Fprintf(w,
  1068. "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  1069. fmt.Println(w)
  1070. return nil
  1071. }
  1072. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  1073. // is given (nil), it uses the current Root of the MerkleTree.
  1074. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  1075. var b []byte
  1076. err := mt.Walk(rootKey, func(n *Node) {
  1077. if n.Type == NodeTypeLeaf {
  1078. l := n.Entry[0].Bytes()
  1079. r := n.Entry[1].Bytes()
  1080. b = append(b, append(l[:], r[:]...)...)
  1081. }
  1082. })
  1083. return b, err
  1084. }
  1085. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  1086. // from the DumpLeafs function.
  1087. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  1088. for i := 0; i < len(b); i += 64 {
  1089. lr := b[i : i+64]
  1090. lB, err := NewBigIntFromHashBytes(lr[:32])
  1091. if err != nil {
  1092. return err
  1093. }
  1094. rB, err := NewBigIntFromHashBytes(lr[32:])
  1095. if err != nil {
  1096. return err
  1097. }
  1098. err = mt.Add(lB, rB)
  1099. if err != nil {
  1100. return err
  1101. }
  1102. }
  1103. return nil
  1104. }