You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1175 lines
32 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "strings"
  10. "sync"
  11. "github.com/iden3/go-iden3-core/common"
  12. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  13. "github.com/iden3/go-merkletree/db"
  14. )
const (
	// proofFlagsLen is the byte length of the flags at the start of a
	// serialized proof. The flags plus the notempties bitmap together fill
	// the first ElemBytesLen bytes of the serialized proof.
	proofFlagsLen = 2
	// ElemBytesLen is the length in bytes of the Hash byte array.
	ElemBytesLen = 32
	// numCharPrint is the number of leading decimal digits kept by
	// Hash.String before the output is truncated.
	numCharPrint = 8
)
var (
	// ErrNodeKeyAlreadyExists is used when a node key already exists.
	ErrNodeKeyAlreadyExists = errors.New("key already exists")
	// ErrKeyNotFound is used when a key is not found in the MerkleTree.
	ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
	// ErrNodeBytesBadSize is used when the data of a node has an incorrect
	// size and can't be parsed.
	ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
	// ErrReachedMaxLevel is used when a traversal of the MT reaches the
	// maximum level.
	ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
	// ErrInvalidNodeFound is used when an invalid node is found and can't
	// be parsed.
	ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
	// ErrInvalidProofBytes is used when a serialized proof is invalid.
	ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
	// ErrInvalidDBValue is used when a value in the key value DB is
	// invalid (for example, it doesn't contain a byte header and a []byte
	// body of at least len=1).
	ErrInvalidDBValue = errors.New("the value in the DB is invalid")
	// ErrEntryIndexAlreadyExists is used when the entry index already
	// exists in the tree.
	ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
	// ErrNotWritable is used when the MerkleTree is not writable and a
	// write function is called.
	ErrNotWritable = errors.New("Merkle Tree not writable")

	// dbKeyRootNode is the storage key under which the current root of the
	// tree is persisted.
	dbKeyRootNode = []byte("currentroot")

	// HashZero is the all-zero Hash. It is the key of Empty nodes, which
	// are never stored in the DB.
	HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Hash is the generic 32-byte element stored in the MerkleTree (node keys,
// entry keys and entry values). BigInt reverses the bytes with
// common.SwapEndianness before reading them as a big-endian integer, so the
// array holds the value in the opposite byte order.
type Hash [32]byte
  55. // MarshalText implements the marshaler for the Hash type
  56. func (h Hash) MarshalText() ([]byte, error) {
  57. return []byte(h.BigInt().String()), nil
  58. }
  59. // UnmarshalText implements the unmarshaler for the Hash type
  60. func (h *Hash) UnmarshalText(b []byte) error {
  61. ha, err := NewHashFromString(string(b))
  62. copy(h[:], ha[:])
  63. return err
  64. }
  65. // String returns decimal representation in string format of the Hash
  66. func (h Hash) String() string {
  67. s := h.BigInt().String()
  68. if len(s) < numCharPrint {
  69. return s
  70. }
  71. return s[0:numCharPrint] + "..."
  72. }
  73. // Hex returns the hexadecimal representation of the Hash
  74. func (h Hash) Hex() string {
  75. return hex.EncodeToString(h[:])
  76. // alternatively equivalent, but with too extra steps:
  77. // bRaw := h.BigInt().Bytes()
  78. // b := [32]byte{}
  79. // copy(b[:], common.SwapEndianness(bRaw[:]))
  80. // return hex.EncodeToString(b[:])
  81. }
  82. // BigInt returns the *big.Int representation of the *Hash
  83. func (h *Hash) BigInt() *big.Int {
  84. if new(big.Int).SetBytes(common.SwapEndianness(h[:])) == nil {
  85. return big.NewInt(0)
  86. }
  87. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  88. }
  89. // Bytes returns the []byte representation of the *Hash, which always is 32
  90. // bytes length.
  91. func (h *Hash) Bytes() []byte {
  92. bi := new(big.Int).SetBytes(h[:]).Bytes()
  93. b := [32]byte{}
  94. copy(b[:], common.SwapEndianness(bi[:]))
  95. return b[:]
  96. }
  97. // NewBigIntFromHashBytes returns a *big.Int from a byte array, swapping the
  98. // endianness in the process. This is the intended method to get a *big.Int
  99. // from a byte array that previously has ben generated by the Hash.Bytes()
  100. // method.
  101. func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
  102. if len(b) != ElemBytesLen {
  103. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  104. }
  105. bi := new(big.Int).SetBytes(b[:ElemBytesLen])
  106. if !cryptoUtils.CheckBigIntInField(bi) {
  107. return nil, fmt.Errorf("NewBigIntFromHashBytes: Value not inside the Finite Field")
  108. }
  109. return bi, nil
  110. }
  111. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  112. func NewHashFromBigInt(b *big.Int) *Hash {
  113. r := &Hash{}
  114. copy(r[:], common.SwapEndianness(b.Bytes()))
  115. return r
  116. }
  117. // NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
  118. // in the process. This is the intended method to get a *Hash from a byte array
  119. // that previously has ben generated by the Hash.Bytes() method.
  120. func NewHashFromBytes(b []byte) (*Hash, error) {
  121. if len(b) != ElemBytesLen {
  122. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  123. }
  124. var h Hash
  125. copy(h[:], common.SwapEndianness(b))
  126. return &h, nil
  127. }
  128. // NewHashFromHex returns a *Hash representation of the given hex string
  129. func NewHashFromHex(h string) (*Hash, error) {
  130. h = strings.TrimPrefix(h, "0x")
  131. b, err := hex.DecodeString(h)
  132. if err != nil {
  133. return nil, err
  134. }
  135. return NewHashFromBytes(common.SwapEndianness(b[:]))
  136. }
  137. // NewHashFromString returns a *Hash representation of the given decimal string
  138. func NewHashFromString(s string) (*Hash, error) {
  139. bi, ok := new(big.Int).SetString(s, 10)
  140. if !ok {
  141. return nil, fmt.Errorf("Can not parse string to Hash")
  142. }
  143. return NewHashFromBigInt(bi), nil
  144. }
// MerkleTree is the struct with the main elements of the MerkleTree.
type MerkleTree struct {
	// The embedded lock is taken for writing by Add, Update and Delete and
	// for reading by Snapshot.
	sync.RWMutex
	// db is the key-value storage holding the nodes and the current root.
	db db.Storage
	// rootKey is the key of the current root node.
	rootKey *Hash
	// writable is false for snapshots; write methods then return
	// ErrNotWritable.
	writable bool
	// maxLevels is the maximum depth that traversals of the tree will reach.
	maxLevels int
}
  153. // NewMerkleTree loads a new Merkletree. If in the sotrage already exists one
  154. // will open that one, if not, will create a new one.
  155. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  156. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  157. root, err := mt.dbGetRoot()
  158. if err == db.ErrNotFound {
  159. tx, err := mt.db.NewTx()
  160. if err != nil {
  161. return nil, err
  162. }
  163. mt.rootKey = &HashZero
  164. err = tx.Put(dbKeyRootNode, mt.rootKey[:])
  165. if err != nil {
  166. return nil, err
  167. }
  168. err = tx.Commit()
  169. if err != nil {
  170. return nil, err
  171. }
  172. return &mt, nil
  173. } else if err != nil {
  174. return nil, err
  175. }
  176. mt.rootKey = root
  177. return &mt, nil
  178. }
  179. func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
  180. v, err := mt.db.Get(dbKeyRootNode)
  181. if err != nil {
  182. return nil, err
  183. }
  184. var root Hash
  185. // Skip first byte which contains the NodeType
  186. copy(root[:], v[1:])
  187. return &root, nil
  188. }
// DB returns the db.Storage that backs the MerkleTree.
func (mt *MerkleTree) DB() db.Storage {
	return mt.db
}
// Root returns the key of the current root node. The field is read without
// taking the tree's lock.
func (mt *MerkleTree) Root() *Hash {
	return mt.rootKey
}
// MaxLevels returns the maximum number of levels the MT was created with.
func (mt *MerkleTree) MaxLevels() int {
	return mt.maxLevels
}
  201. // Snapshot returns a read-only copy of the MerkleTree
  202. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  203. mt.RLock()
  204. defer mt.RUnlock()
  205. _, err := mt.GetNode(rootKey)
  206. if err != nil {
  207. return nil, err
  208. }
  209. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  210. }
  211. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the
  212. // path from the Root to the Leaf.
  213. func (mt *MerkleTree) Add(k, v *big.Int) error {
  214. // verify that the MerkleTree is writable
  215. if !mt.writable {
  216. return ErrNotWritable
  217. }
  218. // verfy that k & v are valid and fit inside the Finite Field.
  219. if !cryptoUtils.CheckBigIntInField(k) {
  220. return errors.New("Key not inside the Finite Field")
  221. }
  222. if !cryptoUtils.CheckBigIntInField(v) {
  223. return errors.New("Value not inside the Finite Field")
  224. }
  225. tx, err := mt.db.NewTx()
  226. if err != nil {
  227. return err
  228. }
  229. mt.Lock()
  230. defer mt.Unlock()
  231. kHash := NewHashFromBigInt(k)
  232. vHash := NewHashFromBigInt(v)
  233. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  234. path := getPath(mt.maxLevels, kHash[:])
  235. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  236. if err != nil {
  237. return err
  238. }
  239. mt.rootKey = newRootKey
  240. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  241. if err != nil {
  242. return err
  243. }
  244. if err := tx.Commit(); err != nil {
  245. return err
  246. }
  247. return nil
  248. }
  249. // AddAndGetCircomProof does an Add, and returns a CircomProcessorProof
  250. func (mt *MerkleTree) AddAndGetCircomProof(k, v *big.Int) (*CircomProcessorProof, error) {
  251. var cp CircomProcessorProof
  252. cp.Fnc = 2
  253. cp.OldRoot = mt.rootKey
  254. gettedK, gettedV, _, err := mt.Get(k)
  255. if err != nil && err != ErrKeyNotFound {
  256. return nil, err
  257. }
  258. cp.OldKey = NewHashFromBigInt(gettedK)
  259. cp.OldValue = NewHashFromBigInt(gettedV)
  260. if bytes.Equal(cp.OldKey[:], HashZero[:]) {
  261. cp.IsOld0 = true
  262. }
  263. _, _, siblings, err := mt.Get(k)
  264. if err != nil && err != ErrKeyNotFound {
  265. return nil, err
  266. }
  267. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  268. err = mt.Add(k, v)
  269. if err != nil {
  270. return nil, err
  271. }
  272. cp.NewKey = NewHashFromBigInt(k)
  273. cp.NewValue = NewHashFromBigInt(v)
  274. cp.NewRoot = mt.rootKey
  275. return &cp, nil
  276. }
  277. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  278. // from newLeaf, at which point both leafs are stored, all while updating the
  279. // path.
  280. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  281. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  282. if lvl > mt.maxLevels-2 {
  283. return nil, ErrReachedMaxLevel
  284. }
  285. var newNodeMiddle *Node
  286. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  287. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  288. if err != nil {
  289. return nil, err
  290. }
  291. if pathNewLeaf[lvl] {
  292. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  293. } else {
  294. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  295. }
  296. return mt.addNode(tx, newNodeMiddle)
  297. } else {
  298. oldLeafKey, err := oldLeaf.Key()
  299. if err != nil {
  300. return nil, err
  301. }
  302. newLeafKey, err := newLeaf.Key()
  303. if err != nil {
  304. return nil, err
  305. }
  306. if pathNewLeaf[lvl] {
  307. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  308. } else {
  309. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  310. }
  311. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  312. _, err = mt.addNode(tx, newLeaf)
  313. if err != nil {
  314. return nil, err
  315. }
  316. return mt.addNode(tx, newNodeMiddle)
  317. }
  318. }
  319. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  320. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  321. lvl int, path []bool) (*Hash, error) {
  322. var err error
  323. var nextKey *Hash
  324. if lvl > mt.maxLevels-1 {
  325. return nil, ErrReachedMaxLevel
  326. }
  327. n, err := mt.GetNode(key)
  328. if err != nil {
  329. return nil, err
  330. }
  331. switch n.Type {
  332. case NodeTypeEmpty:
  333. // We can add newLeaf now
  334. return mt.addNode(tx, newLeaf)
  335. case NodeTypeLeaf:
  336. nKey := n.Entry[0]
  337. // Check if leaf node found contains the leaf node we are trying to add
  338. newLeafKey := newLeaf.Entry[0]
  339. if bytes.Equal(nKey[:], newLeafKey[:]) {
  340. return nil, ErrEntryIndexAlreadyExists
  341. }
  342. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  343. // We need to push newLeaf down until its path diverges from n's path
  344. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  345. case NodeTypeMiddle:
  346. // We need to go deeper, continue traversing the tree, left or
  347. // right depending on path
  348. var newNodeMiddle *Node
  349. if path[lvl] {
  350. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  351. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  352. } else {
  353. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  354. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  355. }
  356. if err != nil {
  357. return nil, err
  358. }
  359. // Update the node to reflect the modified child
  360. return mt.addNode(tx, newNodeMiddle)
  361. default:
  362. return nil, ErrInvalidNodeFound
  363. }
  364. }
  365. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  366. // they are all the same and assumed to always exist.
  367. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  368. // verify that the MerkleTree is writable
  369. if !mt.writable {
  370. return nil, ErrNotWritable
  371. }
  372. if n.Type == NodeTypeEmpty {
  373. return n.Key()
  374. }
  375. k, err := n.Key()
  376. if err != nil {
  377. return nil, err
  378. }
  379. v := n.Value()
  380. // Check that the node key doesn't already exist
  381. if _, err := tx.Get(k[:]); err == nil {
  382. return nil, ErrNodeKeyAlreadyExists
  383. }
  384. err = tx.Put(k[:], v)
  385. return k, err
  386. }
  387. // updateNode updates an existing node in the MT. Empty nodes are not stored
  388. // in the tree; they are all the same and assumed to always exist.
  389. func (mt *MerkleTree) updateNode(tx db.Tx, n *Node) (*Hash, error) {
  390. // verify that the MerkleTree is writable
  391. if !mt.writable {
  392. return nil, ErrNotWritable
  393. }
  394. if n.Type == NodeTypeEmpty {
  395. return n.Key()
  396. }
  397. k, err := n.Key()
  398. if err != nil {
  399. return nil, err
  400. }
  401. v := n.Value()
  402. err = tx.Put(k[:], v)
  403. return k, err
  404. }
  405. // Get returns the value of the leaf for the given key
  406. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  407. // verfy that k is valid and fit inside the Finite Field.
  408. if !cryptoUtils.CheckBigIntInField(k) {
  409. return nil, nil, nil, errors.New("Key not inside the Finite Field")
  410. }
  411. kHash := NewHashFromBigInt(k)
  412. path := getPath(mt.maxLevels, kHash[:])
  413. nextKey := mt.rootKey
  414. var siblings []*Hash
  415. for i := 0; i < mt.maxLevels; i++ {
  416. n, err := mt.GetNode(nextKey)
  417. if err != nil {
  418. return nil, nil, nil, err
  419. }
  420. switch n.Type {
  421. case NodeTypeEmpty:
  422. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  423. case NodeTypeLeaf:
  424. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  425. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  426. } else {
  427. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  428. }
  429. case NodeTypeMiddle:
  430. if path[i] {
  431. nextKey = n.ChildR
  432. siblings = append(siblings, n.ChildL)
  433. } else {
  434. nextKey = n.ChildL
  435. siblings = append(siblings, n.ChildR)
  436. }
  437. default:
  438. return nil, nil, nil, ErrInvalidNodeFound
  439. }
  440. }
  441. return nil, nil, nil, ErrReachedMaxLevel
  442. }
  443. // Update updates the value of a specified key in the MerkleTree, and updates
  444. // the path from the leaf to the Root with the new values. Returns the
  445. // CircomProcessorProof.
  446. func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
  447. // verify that the MerkleTree is writable
  448. if !mt.writable {
  449. return nil, ErrNotWritable
  450. }
  451. // verfy that k & are valid and fit inside the Finite Field.
  452. if !cryptoUtils.CheckBigIntInField(k) {
  453. return nil, errors.New("Key not inside the Finite Field")
  454. }
  455. if !cryptoUtils.CheckBigIntInField(v) {
  456. return nil, errors.New("Key not inside the Finite Field")
  457. }
  458. tx, err := mt.db.NewTx()
  459. if err != nil {
  460. return nil, err
  461. }
  462. mt.Lock()
  463. defer mt.Unlock()
  464. kHash := NewHashFromBigInt(k)
  465. vHash := NewHashFromBigInt(v)
  466. path := getPath(mt.maxLevels, kHash[:])
  467. var cp CircomProcessorProof
  468. cp.Fnc = 1
  469. cp.OldRoot = mt.rootKey
  470. cp.OldKey = kHash
  471. cp.NewKey = kHash
  472. cp.NewValue = vHash
  473. nextKey := mt.rootKey
  474. var siblings []*Hash
  475. for i := 0; i < mt.maxLevels; i++ {
  476. n, err := mt.GetNode(nextKey)
  477. if err != nil {
  478. return nil, err
  479. }
  480. switch n.Type {
  481. case NodeTypeEmpty:
  482. return nil, ErrKeyNotFound
  483. case NodeTypeLeaf:
  484. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  485. cp.OldValue = n.Entry[1]
  486. cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
  487. // update leaf and upload to the root
  488. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  489. _, err := mt.updateNode(tx, newNodeLeaf)
  490. if err != nil {
  491. return nil, err
  492. }
  493. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
  494. if err != nil {
  495. return nil, err
  496. }
  497. mt.rootKey = newRootKey
  498. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  499. if err != nil {
  500. return nil, err
  501. }
  502. cp.NewRoot = newRootKey
  503. if err := tx.Commit(); err != nil {
  504. return nil, err
  505. }
  506. return &cp, nil
  507. } else {
  508. return nil, ErrKeyNotFound
  509. }
  510. case NodeTypeMiddle:
  511. if path[i] {
  512. nextKey = n.ChildR
  513. siblings = append(siblings, n.ChildL)
  514. } else {
  515. nextKey = n.ChildL
  516. siblings = append(siblings, n.ChildR)
  517. }
  518. default:
  519. return nil, ErrInvalidNodeFound
  520. }
  521. }
  522. return nil, ErrKeyNotFound
  523. }
  524. // Delete removes the specified Key from the MerkleTree and updates the path
  525. // from the deleted key to the Root with the new values. This method removes
  526. // the key from the MerkleTree, but does not remove the old nodes from the
  527. // key-value database; this means that if the tree is accessed by an old Root
  528. // where the key was not deleted yet, the key will still exist. If is desired
  529. // to remove the key-values from the database that are not under the current
  530. // Root, an option could be to dump all the leafs (using mt.DumpLeafs) and
  531. // import them in a new MerkleTree in a new database (using
  532. // mt.ImportDumpedLeafs), but this will loose all the Root history of the
  533. // MerkleTree
  534. func (mt *MerkleTree) Delete(k *big.Int) error {
  535. // verify that the MerkleTree is writable
  536. if !mt.writable {
  537. return ErrNotWritable
  538. }
  539. // verfy that k is valid and fit inside the Finite Field.
  540. if !cryptoUtils.CheckBigIntInField(k) {
  541. return errors.New("Key not inside the Finite Field")
  542. }
  543. tx, err := mt.db.NewTx()
  544. if err != nil {
  545. return err
  546. }
  547. mt.Lock()
  548. defer mt.Unlock()
  549. kHash := NewHashFromBigInt(k)
  550. path := getPath(mt.maxLevels, kHash[:])
  551. nextKey := mt.rootKey
  552. var siblings []*Hash
  553. for i := 0; i < mt.maxLevels; i++ {
  554. n, err := mt.GetNode(nextKey)
  555. if err != nil {
  556. return err
  557. }
  558. switch n.Type {
  559. case NodeTypeEmpty:
  560. return ErrKeyNotFound
  561. case NodeTypeLeaf:
  562. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  563. // remove and go up with the sibling
  564. err = mt.rmAndUpload(tx, path, kHash, siblings)
  565. return err
  566. } else {
  567. return ErrKeyNotFound
  568. }
  569. case NodeTypeMiddle:
  570. if path[i] {
  571. nextKey = n.ChildR
  572. siblings = append(siblings, n.ChildL)
  573. } else {
  574. nextKey = n.ChildL
  575. siblings = append(siblings, n.ChildR)
  576. }
  577. default:
  578. return ErrInvalidNodeFound
  579. }
  580. }
  581. return ErrKeyNotFound
  582. }
  583. // rmAndUpload removes the key, and goes up until the root updating all the nodes with the new values.
  584. func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
  585. if len(siblings) == 0 {
  586. mt.rootKey = &HashZero
  587. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  588. if err != nil {
  589. return err
  590. }
  591. return tx.Commit()
  592. }
  593. toUpload := siblings[len(siblings)-1]
  594. if len(siblings) < 2 { //nolint:gomnd
  595. mt.rootKey = siblings[0]
  596. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  597. if err != nil {
  598. return err
  599. }
  600. return tx.Commit()
  601. }
  602. for i := len(siblings) - 2; i >= 0; i-- { //nolint:gomnd
  603. if !bytes.Equal(siblings[i][:], HashZero[:]) {
  604. var newNode *Node
  605. if path[i] {
  606. newNode = NewNodeMiddle(siblings[i], toUpload)
  607. } else {
  608. newNode = NewNodeMiddle(toUpload, siblings[i])
  609. }
  610. _, err := mt.addNode(tx, newNode)
  611. if err != ErrNodeKeyAlreadyExists && err != nil {
  612. return err
  613. }
  614. // go up until the root
  615. newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i])
  616. if err != nil {
  617. return err
  618. }
  619. mt.rootKey = newRootKey
  620. err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  621. if err != nil {
  622. return err
  623. }
  624. break
  625. }
  626. // if i==0 (root position), stop and store the sibling of the deleted leaf as root
  627. if i == 0 {
  628. mt.rootKey = toUpload
  629. err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
  630. if err != nil {
  631. return err
  632. }
  633. break
  634. }
  635. }
  636. if err := tx.Commit(); err != nil {
  637. return err
  638. }
  639. return nil
  640. }
  641. // recalculatePathUntilRoot recalculates the nodes until the Root
  642. func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node, siblings []*Hash) (*Hash, error) {
  643. for i := len(siblings) - 1; i >= 0; i-- {
  644. nodeKey, err := node.Key()
  645. if err != nil {
  646. return nil, err
  647. }
  648. if path[i] {
  649. node = NewNodeMiddle(siblings[i], nodeKey)
  650. } else {
  651. node = NewNodeMiddle(nodeKey, siblings[i])
  652. }
  653. _, err = mt.addNode(tx, node)
  654. if err != ErrNodeKeyAlreadyExists && err != nil {
  655. return nil, err
  656. }
  657. }
  658. // return last node added, which is the root
  659. nodeKey, err := node.Key()
  660. return nodeKey, err
  661. }
  662. // dbInsert is a helper function to insert a node into a key in an open db
  663. // transaction.
  664. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) error {
  665. v := append([]byte{byte(t)}, data...)
  666. return tx.Put(k, v)
  667. }
  668. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  669. // tree; they are all the same and assumed to always exist.
  670. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  671. if bytes.Equal(key[:], HashZero[:]) {
  672. return NewNodeEmpty(), nil
  673. }
  674. nBytes, err := mt.db.Get(key[:])
  675. if err != nil {
  676. return nil, err
  677. }
  678. return NewNodeFromBytes(nBytes)
  679. }
  680. // getPath returns the binary path, from the root to the leaf.
  681. func getPath(numLevels int, k []byte) []bool {
  682. path := make([]bool, numLevels)
  683. for n := 0; n < numLevels; n++ {
  684. path[n] = common.TestBit(k[:], uint(n))
  685. }
  686. return path
  687. }
// NodeAux contains the auxiliary node used in a non-existence proof: the key
// and value of the leaf that was found in place of the requested key.
type NodeAux struct {
	Key   *Hash
	Value *Hash
}
// Proof defines the required elements for a MT proof of existence or non-existence.
type Proof struct {
	// Existence indicates whether this is a proof of existence or
	// non-existence.
	Existence bool
	// depth indicates how deep in the tree the proof goes.
	depth uint
	// notempties is a bitmap with one bit per level, set for the levels
	// whose sibling is non-empty and therefore present in Siblings.
	notempties [ElemBytesLen - proofFlagsLen]byte
	// Siblings is a list of non-empty sibling keys.
	Siblings []*Hash
	// NodeAux is set in non-existence proofs where a different leaf was
	// found at the end of the path; otherwise it is nil.
	NodeAux *NodeAux
}
  705. // NewProofFromBytes parses a byte array into a Proof.
  706. func NewProofFromBytes(bs []byte) (*Proof, error) {
  707. if len(bs) < ElemBytesLen {
  708. return nil, ErrInvalidProofBytes
  709. }
  710. p := &Proof{}
  711. if (bs[0] & 0x01) == 0 {
  712. p.Existence = true
  713. }
  714. p.depth = uint(bs[1])
  715. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  716. siblingBytes := bs[ElemBytesLen:]
  717. sibIdx := 0
  718. for i := uint(0); i < p.depth; i++ {
  719. if common.TestBitBigEndian(p.notempties[:], i) {
  720. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  721. return nil, ErrInvalidProofBytes
  722. }
  723. var sib Hash
  724. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  725. p.Siblings = append(p.Siblings, &sib)
  726. sibIdx++
  727. }
  728. }
  729. if !p.Existence && ((bs[0] & 0x02) != 0) {
  730. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  731. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  732. if len(nodeAuxBytes) != 2*ElemBytesLen {
  733. return nil, ErrInvalidProofBytes
  734. }
  735. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  736. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  737. }
  738. return p, nil
  739. }
  740. // Bytes serializes a Proof into a byte array.
  741. func (p *Proof) Bytes() []byte {
  742. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  743. if p.NodeAux != nil {
  744. bsLen += 2 * ElemBytesLen //nolint:gomnd
  745. }
  746. bs := make([]byte, bsLen)
  747. if !p.Existence {
  748. bs[0] |= 0x01
  749. }
  750. bs[1] = byte(p.depth)
  751. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  752. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  753. for i, k := range p.Siblings {
  754. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  755. }
  756. if p.NodeAux != nil {
  757. bs[0] |= 0x02
  758. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  759. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  760. }
  761. return bs
  762. }
  763. // SiblingsFromProof returns all the siblings of the proof.
  764. func SiblingsFromProof(proof *Proof) []*Hash {
  765. sibIdx := 0
  766. var siblings []*Hash
  767. for lvl := 0; lvl < int(proof.depth); lvl++ {
  768. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  769. siblings = append(siblings, proof.Siblings[sibIdx])
  770. sibIdx++
  771. } else {
  772. siblings = append(siblings, &HashZero)
  773. }
  774. }
  775. return siblings
  776. }
// AllSiblings returns all the siblings of the proof. It is a method
// shorthand for SiblingsFromProof(p).
func (p *Proof) AllSiblings() []*Hash {
	return SiblingsFromProof(p)
}
  781. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  782. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  783. // Add the rest of empty levels to the siblings
  784. for i := len(siblings); i < levels+1; i++ {
  785. siblings = append(siblings, &HashZero)
  786. }
  787. return siblings
  788. }
// CircomProcessorProof defines the ProcessorProof compatible with circom. It
// is the data of the proof between the transition from one state to another.
type CircomProcessorProof struct {
	// OldRoot and NewRoot are the tree roots before and after the operation.
	OldRoot *Hash `json:"oldRoot"`
	NewRoot *Hash `json:"newRoot"`
	// Siblings are the siblings along the path of the key, padded with
	// HashZero by CircomSiblingsFromSiblings.
	Siblings []*Hash `json:"siblings"`
	// OldKey and OldValue describe the leaf found on the path before the
	// operation.
	OldKey   *Hash `json:"oldKey"`
	OldValue *Hash `json:"oldValue"`
	// NewKey and NewValue describe the leaf written by the operation.
	NewKey   *Hash `json:"newKey"`
	NewValue *Hash `json:"newValue"`
	// IsOld0 is set when OldKey is HashZero.
	IsOld0 bool `json:"isOld0"`
	Fnc    int  `json:"fnc"` // 0: NOP, 1: Update, 2: Insert, 3: Delete
}
  802. // String returns a human readable string representation of the
  803. // CircomProcessorProof
  804. func (p CircomProcessorProof) String() string {
  805. buf := bytes.NewBufferString("{")
  806. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  807. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  808. fmt.Fprintf(buf, " Siblings: [\n ")
  809. for _, s := range p.Siblings {
  810. fmt.Fprintf(buf, "%v, ", s)
  811. }
  812. fmt.Fprintf(buf, "\n ],\n")
  813. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  814. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  815. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  816. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  817. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  818. fmt.Fprintf(buf, "}\n")
  819. return buf.String()
  820. }
// CircomVerifierProof defines the VerifierProof compatible with circom. It is
// the data of the proof that a certain leaf exists (or does not exist) in the
// MerkleTree.
type CircomVerifierProof struct {
	// Root is the root the proof was generated against.
	Root *Hash `json:"root"`
	// Siblings are the siblings along the path of the key, padded with
	// HashZero by CircomSiblingsFromSiblings.
	Siblings []*Hash `json:"siblings"`
	OldKey   *Hash   `json:"oldKey"`
	OldValue *Hash   `json:"oldValue"`
	IsOld0   bool    `json:"isOld0"`
	// Key is the key the proof was requested for; Value is the value
	// returned by GenerateProof for it.
	Key   *Hash `json:"key"`
	Value *Hash `json:"value"`
	Fnc   int   `json:"fnc"` // 0: inclusion, 1: non inclusion
}
  833. // GenerateCircomVerifierProof returns the CircomVerifierProof for a certain
  834. // key in the MerkleTree. If the rootKey is nil, the current merkletree root
  835. // is used.
  836. func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int, rootKey *Hash) (*CircomVerifierProof, error) {
  837. if rootKey == nil {
  838. rootKey = mt.Root()
  839. }
  840. p, v, err := mt.GenerateProof(k, rootKey)
  841. if err != nil && err != ErrKeyNotFound {
  842. return nil, err
  843. }
  844. var cp CircomVerifierProof
  845. cp.Root = rootKey
  846. cp.Siblings = CircomSiblingsFromSiblings(p.AllSiblings(), mt.maxLevels)
  847. cp.OldKey = &HashZero
  848. cp.OldValue = &HashZero
  849. cp.Key = NewHashFromBigInt(k)
  850. cp.Value = NewHashFromBigInt(v)
  851. if p.Existence {
  852. cp.Fnc = 0 // inclusion
  853. } else {
  854. cp.Fnc = 1 // non inclusion
  855. }
  856. return &cp, nil
  857. }
  858. // GenerateProof generates the proof of existence (or non-existence) of an
  859. // Entry's hash Index for a Merkle Tree given the root.
  860. // If the rootKey is nil, the current merkletree root is used
  861. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, *big.Int, error) {
  862. p := &Proof{}
  863. var siblingKey *Hash
  864. kHash := NewHashFromBigInt(k)
  865. path := getPath(mt.maxLevels, kHash[:])
  866. if rootKey == nil {
  867. rootKey = mt.Root()
  868. }
  869. nextKey := rootKey
  870. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  871. n, err := mt.GetNode(nextKey)
  872. if err != nil {
  873. return nil, nil, err
  874. }
  875. switch n.Type {
  876. case NodeTypeEmpty:
  877. return p, big.NewInt(0), nil
  878. case NodeTypeLeaf:
  879. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  880. p.Existence = true
  881. return p, n.Entry[1].BigInt(), nil
  882. } else {
  883. // We found a leaf whose entry didn't match hIndex
  884. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  885. return p, n.Entry[1].BigInt(), nil
  886. }
  887. case NodeTypeMiddle:
  888. if path[p.depth] {
  889. nextKey = n.ChildR
  890. siblingKey = n.ChildL
  891. } else {
  892. nextKey = n.ChildL
  893. siblingKey = n.ChildR
  894. }
  895. default:
  896. return nil, nil, ErrInvalidNodeFound
  897. }
  898. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  899. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  900. p.Siblings = append(p.Siblings, siblingKey)
  901. }
  902. }
  903. return nil, nil, ErrKeyNotFound
  904. }
  905. // VerifyProof verifies the Merkle Proof for the entry and root.
  906. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  907. rootFromProof, err := RootFromProof(proof, k, v)
  908. if err != nil {
  909. return false
  910. }
  911. return bytes.Equal(rootKey[:], rootFromProof[:])
  912. }
  913. // RootFromProof calculates the root that would correspond to a tree whose
  914. // siblings are the ones in the proof with the leaf hashing to hIndex and
  915. // hValue.
  916. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  917. kHash := NewHashFromBigInt(k)
  918. vHash := NewHashFromBigInt(v)
  919. sibIdx := len(proof.Siblings) - 1
  920. var err error
  921. var midKey *Hash
  922. if proof.Existence {
  923. midKey, err = LeafKey(kHash, vHash)
  924. if err != nil {
  925. return nil, err
  926. }
  927. } else {
  928. if proof.NodeAux == nil {
  929. midKey = &HashZero
  930. } else {
  931. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  932. return nil, fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  933. }
  934. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  935. if err != nil {
  936. return nil, err
  937. }
  938. }
  939. }
  940. path := getPath(int(proof.depth), kHash[:])
  941. var siblingKey *Hash
  942. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  943. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  944. siblingKey = proof.Siblings[sibIdx]
  945. sibIdx--
  946. } else {
  947. siblingKey = &HashZero
  948. }
  949. if path[lvl] {
  950. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  951. if err != nil {
  952. return nil, err
  953. }
  954. } else {
  955. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  956. if err != nil {
  957. return nil, err
  958. }
  959. }
  960. }
  961. return midKey, nil
  962. }
  963. // walk is a helper recursive function to iterate over all tree branches
  964. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  965. n, err := mt.GetNode(key)
  966. if err != nil {
  967. return err
  968. }
  969. switch n.Type {
  970. case NodeTypeEmpty:
  971. f(n)
  972. case NodeTypeLeaf:
  973. f(n)
  974. case NodeTypeMiddle:
  975. f(n)
  976. if err := mt.walk(n.ChildL, f); err != nil {
  977. return err
  978. }
  979. if err := mt.walk(n.ChildR, f); err != nil {
  980. return err
  981. }
  982. default:
  983. return ErrInvalidNodeFound
  984. }
  985. return nil
  986. }
  987. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  988. // if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree.
  989. // For each node, it calls the f function given in the parameters.
  990. // See some examples of the Walk function usage in the merkletree.go and
  991. // merkletree_test.go
  992. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  993. if rootKey == nil {
  994. rootKey = mt.Root()
  995. }
  996. err := mt.walk(rootKey, f)
  997. return err
  998. }
  999. // GraphViz uses Walk function to generate a string GraphViz representation of the
  1000. // tree and writes it to w
  1001. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  1002. fmt.Fprintf(w, `digraph hierarchy {
  1003. node [fontname=Monospace,fontsize=10,shape=box]
  1004. `)
  1005. cnt := 0
  1006. var errIn error
  1007. err := mt.Walk(rootKey, func(n *Node) {
  1008. k, err := n.Key()
  1009. if err != nil {
  1010. errIn = err
  1011. }
  1012. switch n.Type {
  1013. case NodeTypeEmpty:
  1014. case NodeTypeLeaf:
  1015. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  1016. case NodeTypeMiddle:
  1017. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  1018. emptyNodes := ""
  1019. for i := range lr {
  1020. if lr[i] == "0" {
  1021. lr[i] = fmt.Sprintf("empty%v", cnt)
  1022. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  1023. cnt++
  1024. }
  1025. }
  1026. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  1027. fmt.Fprint(w, emptyNodes)
  1028. default:
  1029. }
  1030. })
  1031. fmt.Fprintf(w, "}\n")
  1032. if errIn != nil {
  1033. return errIn
  1034. }
  1035. return err
  1036. }
  1037. // PrintGraphViz prints directly the GraphViz() output
  1038. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  1039. if rootKey == nil {
  1040. rootKey = mt.Root()
  1041. }
  1042. w := bytes.NewBufferString("")
  1043. fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  1044. err := mt.GraphViz(w, nil)
  1045. if err != nil {
  1046. return err
  1047. }
  1048. fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  1049. fmt.Println(w)
  1050. return nil
  1051. }
  1052. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  1053. // is given (nil), it uses the current Root of the MerkleTree.
  1054. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  1055. var b []byte
  1056. err := mt.Walk(rootKey, func(n *Node) {
  1057. if n.Type == NodeTypeLeaf {
  1058. l := n.Entry[0].Bytes()
  1059. r := n.Entry[1].Bytes()
  1060. b = append(b, append(l[:], r[:]...)...)
  1061. }
  1062. })
  1063. return b, err
  1064. }
  1065. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  1066. // from the DumpLeafs function.
  1067. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  1068. for i := 0; i < len(b); i += 64 {
  1069. lr := b[i : i+64]
  1070. lB, err := NewBigIntFromHashBytes(lr[:32])
  1071. if err != nil {
  1072. return err
  1073. }
  1074. rB, err := NewBigIntFromHashBytes(lr[32:])
  1075. if err != nil {
  1076. return err
  1077. }
  1078. err = mt.Add(lB, rB)
  1079. if err != nil {
  1080. return err
  1081. }
  1082. }
  1083. return nil
  1084. }