You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1206 lines
33 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math/big"
  9. "strings"
  10. "sync"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. "github.com/iden3/go-merkletree/db"
  13. )
  // Sizes used by the Hash type and the proof serialization format.
const (
	// proofFlagsLen is the number of leading bytes of a serialized proof
	// header (the header itself spans the first ElemBytesLen bytes) that hold
	// the flags byte and the depth byte rather than the notempties bitmap.
	proofFlagsLen = 2

	// ElemBytesLen is the length in bytes of a Hash.
	ElemBytesLen = 32

	// numCharPrint is how many leading decimal digits Hash.String keeps
	// before truncating.
	numCharPrint = 8
)
  22. var (
  23. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  24. ErrNodeKeyAlreadyExists = errors.New("key already exists")
  25. // ErrKeyNotFound is used when a key is not found in the MerkleTree.
  26. ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
  27. // ErrNodeBytesBadSize is used when the data of a node has an incorrect
  28. // size and can't be parsed.
  29. ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
  30. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  31. // maximum level.
  32. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  33. // ErrInvalidNodeFound is used when an invalid node is found and can't
  34. // be parsed.
  35. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  36. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  37. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  38. // ErrInvalidDBValue is used when a value in the key value DB is
  39. // invalid (for example, it doen't contain a byte header and a []byte
  40. // body of at least len=1.
  41. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  42. // ErrEntryIndexAlreadyExists is used when the entry index already
  43. // exists in the tree.
  44. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  45. // ErrNotWritable is used when the MerkleTree is not writable and a
  46. // write function is called
  47. ErrNotWritable = errors.New("Merkle Tree not writable")
  48. dbKeyRootNode = []byte("currentroot")
  49. // HashZero is used at Empty nodes
  50. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  51. 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  52. )
  // Hash is the generic 32-byte element stored in the MerkleTree. Its bytes
// hold a number in little-endian order (see BigInt and NewHashFromBigInt).
type Hash [32]byte
  55. // MarshalText implements the marshaler for the Hash type
  56. func (h Hash) MarshalText() ([]byte, error) {
  57. return []byte(h.BigInt().String()), nil
  58. }
  59. // UnmarshalText implements the unmarshaler for the Hash type
  60. func (h *Hash) UnmarshalText(b []byte) error {
  61. ha, err := NewHashFromString(string(b))
  62. copy(h[:], ha[:])
  63. return err
  64. }
  65. // String returns decimal representation in string format of the Hash
  66. func (h Hash) String() string {
  67. s := h.BigInt().String()
  68. if len(s) < numCharPrint {
  69. return s
  70. }
  71. return s[0:numCharPrint] + "..."
  72. }
  // Hex returns the lowercase hexadecimal encoding of the Hash bytes in their
// stored (little-endian) order. NewHashFromHex parses this form back.
func (h Hash) Hex() string {
	raw := h[:]
	return hex.EncodeToString(raw)
}
  // BigInt returns the numeric value of the Hash as a newly allocated
// *big.Int, interpreting the Hash bytes as a little-endian number.
func (h *Hash) BigInt() *big.Int {
	// big.Int.SetBytes expects big-endian input, hence the swap. SetBytes
	// never returns nil, so no fallback value is needed.
	return new(big.Int).SetBytes(SwapEndianness(h[:]))
}
  89. // Bytes returns the []byte representation of the *Hash, which always is 32
  90. // bytes length.
  91. func (h *Hash) Bytes() []byte {
  92. bi := new(big.Int).SetBytes(h[:]).Bytes()
  93. b := [32]byte{}
  94. copy(b[:], SwapEndianness(bi[:]))
  95. return b[:]
  96. }
  // NewBigIntFromHashBytes returns the *big.Int encoded by b. b must be a
// 32-byte big-endian slice, such as the one produced by Hash.Bytes().
// Hash.Bytes() already emits big-endian bytes, so no endianness swap is done
// here. It returns an error if b is not exactly ElemBytesLen bytes long or if
// the value is not inside the finite field.
func NewBigIntFromHashBytes(b []byte) (*big.Int, error) {
	if len(b) != ElemBytesLen {
		return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
	}
	bi := new(big.Int).SetBytes(b)
	if !cryptoUtils.CheckBigIntInField(bi) {
		return nil, errors.New("NewBigIntFromHashBytes: Value not inside the Finite Field")
	}
	return bi, nil
}
  111. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  112. func NewHashFromBigInt(b *big.Int) *Hash {
  113. r := &Hash{}
  114. copy(r[:], SwapEndianness(b.Bytes()))
  115. return r
  116. }
  117. // NewHashFromBytes returns a *Hash from a byte array, swapping the endianness
  118. // in the process. This is the intended method to get a *Hash from a byte array
  119. // that previously has ben generated by the Hash.Bytes() method.
  120. func NewHashFromBytes(b []byte) (*Hash, error) {
  121. if len(b) != ElemBytesLen {
  122. return nil, fmt.Errorf("Expected 32 bytes, found %d bytes", len(b))
  123. }
  124. var h Hash
  125. copy(h[:], SwapEndianness(b))
  126. return &h, nil
  127. }
  // NewHashFromHex parses a hexadecimal string of the Hash bytes (the format
// produced by Hash.Hex; an optional "0x" prefix is accepted) into a *Hash.
func NewHashFromHex(h string) (*Hash, error) {
	raw, err := hex.DecodeString(strings.TrimPrefix(h, "0x"))
	if err != nil {
		return nil, err
	}
	// The decoded bytes are in stored order. NewHashFromBytes swaps its input,
	// so reverse once here for the round trip to yield the same bytes.
	return NewHashFromBytes(SwapEndianness(raw))
}
  // NewHashFromString parses a base-10 string into a *Hash.
func NewHashFromString(s string) (*Hash, error) {
	bi, ok := new(big.Int).SetString(s, 10)
	if ok {
		return NewHashFromBigInt(bi), nil
	}
	return nil, fmt.Errorf("Can not parse string to Hash")
}
  145. // MerkleTree is the struct with the main elements of the MerkleTree
  146. type MerkleTree struct {
  147. sync.RWMutex
  148. db db.Storage
  149. rootKey *Hash
  150. writable bool
  151. maxLevels int
  152. }
  // NewMerkleTree opens the MerkleTree kept in storage. If storage already
// holds a root, that tree is opened. Otherwise a new empty tree is created and
// its zero root is persisted.
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
	mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
	root, err := mt.dbGetRoot()
	if err == nil {
		mt.rootKey = root
		return &mt, nil
	}
	if !errors.Is(err, db.ErrNotFound) {
		return nil, err
	}

	tx, err := mt.db.NewTx()
	if err != nil {
		return nil, err
	}
	mt.rootKey = &HashZero
	// Write the root with the same one-byte type header used by every other
	// root write (dbInsert). dbGetRoot strips that header when reading.
	if err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]); err != nil {
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, err
	}
	return &mt, nil
}
  179. func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
  180. v, err := mt.db.Get(dbKeyRootNode)
  181. if err != nil {
  182. return nil, err
  183. }
  184. var root Hash
  185. // Skip first byte which contains the NodeType
  186. copy(root[:], v[1:])
  187. return &root, nil
  188. }
  189. // DB returns the MerkleTree.DB()
  190. func (mt *MerkleTree) DB() db.Storage {
  191. return mt.db
  192. }
  193. // Root returns the MerkleRoot
  194. func (mt *MerkleTree) Root() *Hash {
  195. return mt.rootKey
  196. }
  197. // MaxLevels returns the MT maximum level
  198. func (mt *MerkleTree) MaxLevels() int {
  199. return mt.maxLevels
  200. }
  201. // Snapshot returns a read-only copy of the MerkleTree
  202. func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
  203. mt.RLock()
  204. defer mt.RUnlock()
  205. _, err := mt.GetNode(rootKey)
  206. if err != nil {
  207. return nil, err
  208. }
  209. return &MerkleTree{db: mt.db, maxLevels: mt.maxLevels, rootKey: rootKey, writable: false}, nil
  210. }
  // Add inserts the key/value pair (k, v) into the MerkleTree. k determines
// the path from the root to the new leaf, and both k and v must be inside the
// finite field. The new root is persisted, and it becomes the in-memory root
// only after the transaction commits successfully.
func (mt *MerkleTree) Add(k, v *big.Int) error {
	if !mt.writable {
		return ErrNotWritable
	}
	if !cryptoUtils.CheckBigIntInField(k) {
		return errors.New("Key not inside the Finite Field")
	}
	if !cryptoUtils.CheckBigIntInField(v) {
		return errors.New("Value not inside the Finite Field")
	}

	tx, err := mt.db.NewTx()
	if err != nil {
		return err
	}

	mt.Lock()
	defer mt.Unlock()

	kHash := NewHashFromBigInt(k)
	vHash := NewHashFromBigInt(v)
	newNodeLeaf := NewNodeLeaf(kHash, vHash)
	path := getPath(mt.maxLevels, kHash[:])

	newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
	if err != nil {
		return err
	}
	if err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, newRootKey[:]); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return err
	}
	// Publish the new root only once it is committed, so a failed write does
	// not leave the in-memory root pointing at uncommitted nodes.
	mt.rootKey = newRootKey
	return nil
}
  // AddAndGetCircomProof performs Add(k, v) and returns the
// CircomProcessorProof (Fnc = 2, insert) for the transition from the old root
// to the new one.
func (mt *MerkleTree) AddAndGetCircomProof(k,
	v *big.Int) (*CircomProcessorProof, error) {
	var cp CircomProcessorProof
	cp.Fnc = 2
	cp.OldRoot = mt.rootKey

	// One lookup returns both the leaf met on k's path (zero if the path ends
	// empty) and the siblings along it. ErrKeyNotFound is expected for an
	// insert.
	oldK, oldV, siblings, err := mt.Get(k)
	if err != nil && err != ErrKeyNotFound {
		return nil, err
	}
	cp.OldKey = NewHashFromBigInt(oldK)
	cp.OldValue = NewHashFromBigInt(oldV)
	cp.IsOld0 = bytes.Equal(cp.OldKey[:], HashZero[:])
	cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)

	if err := mt.Add(k, v); err != nil {
		return nil, err
	}
	cp.NewKey = NewHashFromBigInt(k)
	cp.NewValue = NewHashFromBigInt(v)
	cp.NewRoot = mt.rootKey
	return &cp, nil
}
  278. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  279. // from newLeaf, at which point both leafs are stored, all while updating the
  280. // path.
  281. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node, lvl int,
  282. pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  283. if lvl > mt.maxLevels-2 {
  284. return nil, ErrReachedMaxLevel
  285. }
  286. var newNodeMiddle *Node
  287. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  288. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  289. if err != nil {
  290. return nil, err
  291. }
  292. if pathNewLeaf[lvl] { // go right
  293. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey)
  294. } else { // go left
  295. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero)
  296. }
  297. return mt.addNode(tx, newNodeMiddle)
  298. }
  299. oldLeafKey, err := oldLeaf.Key()
  300. if err != nil {
  301. return nil, err
  302. }
  303. newLeafKey, err := newLeaf.Key()
  304. if err != nil {
  305. return nil, err
  306. }
  307. if pathNewLeaf[lvl] {
  308. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  309. } else {
  310. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  311. }
  312. // We can add newLeaf now. We don't need to add oldLeaf because it's
  313. // already in the tree.
  314. _, err = mt.addNode(tx, newLeaf)
  315. if err != nil {
  316. return nil, err
  317. }
  318. return mt.addNode(tx, newNodeMiddle)
  319. }
  320. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  321. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  322. lvl int, path []bool) (*Hash, error) {
  323. var err error
  324. var nextKey *Hash
  325. if lvl > mt.maxLevels-1 {
  326. return nil, ErrReachedMaxLevel
  327. }
  328. n, err := mt.GetNode(key)
  329. if err != nil {
  330. return nil, err
  331. }
  332. switch n.Type {
  333. case NodeTypeEmpty:
  334. // We can add newLeaf now
  335. return mt.addNode(tx, newLeaf)
  336. case NodeTypeLeaf:
  337. nKey := n.Entry[0]
  338. // Check if leaf node found contains the leaf node we are
  339. // trying to add
  340. newLeafKey := newLeaf.Entry[0]
  341. if bytes.Equal(nKey[:], newLeafKey[:]) {
  342. return nil, ErrEntryIndexAlreadyExists
  343. }
  344. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  345. // We need to push newLeaf down until its path diverges from
  346. // n's path
  347. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  348. case NodeTypeMiddle:
  349. // We need to go deeper, continue traversing the tree, left or
  350. // right depending on path
  351. var newNodeMiddle *Node
  352. if path[lvl] { // go right
  353. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path)
  354. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  355. } else { // go left
  356. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path)
  357. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  358. }
  359. if err != nil {
  360. return nil, err
  361. }
  362. // Update the node to reflect the modified child
  363. return mt.addNode(tx, newNodeMiddle)
  364. default:
  365. return nil, ErrInvalidNodeFound
  366. }
  367. }
  368. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  369. // they are all the same and assumed to always exist.
  370. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  371. // verify that the MerkleTree is writable
  372. if !mt.writable {
  373. return nil, ErrNotWritable
  374. }
  375. if n.Type == NodeTypeEmpty {
  376. return n.Key()
  377. }
  378. k, err := n.Key()
  379. if err != nil {
  380. return nil, err
  381. }
  382. v := n.Value()
  383. // Check that the node key doesn't already exist
  384. if _, err := tx.Get(k[:]); err == nil {
  385. return nil, ErrNodeKeyAlreadyExists
  386. }
  387. err = tx.Put(k[:], v)
  388. return k, err
  389. }
  390. // updateNode updates an existing node in the MT. Empty nodes are not stored
  391. // in the tree; they are all the same and assumed to always exist.
  392. func (mt *MerkleTree) updateNode(tx db.Tx, n *Node) (*Hash, error) {
  393. // verify that the MerkleTree is writable
  394. if !mt.writable {
  395. return nil, ErrNotWritable
  396. }
  397. if n.Type == NodeTypeEmpty {
  398. return n.Key()
  399. }
  400. k, err := n.Key()
  401. if err != nil {
  402. return nil, err
  403. }
  404. v := n.Value()
  405. err = tx.Put(k[:], v)
  406. return k, err
  407. }
  408. // Get returns the value of the leaf for the given key
  409. func (mt *MerkleTree) Get(k *big.Int) (*big.Int, *big.Int, []*Hash, error) {
  410. // verfy that k is valid and fit inside the Finite Field.
  411. if !cryptoUtils.CheckBigIntInField(k) {
  412. return nil, nil, nil, errors.New("Key not inside the Finite Field")
  413. }
  414. kHash := NewHashFromBigInt(k)
  415. path := getPath(mt.maxLevels, kHash[:])
  416. nextKey := mt.rootKey
  417. siblings := []*Hash{}
  418. for i := 0; i < mt.maxLevels; i++ {
  419. n, err := mt.GetNode(nextKey)
  420. if err != nil {
  421. return nil, nil, nil, err
  422. }
  423. switch n.Type {
  424. case NodeTypeEmpty:
  425. return big.NewInt(0), big.NewInt(0), siblings, ErrKeyNotFound
  426. case NodeTypeLeaf:
  427. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  428. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, nil
  429. }
  430. return n.Entry[0].BigInt(), n.Entry[1].BigInt(), siblings, ErrKeyNotFound
  431. case NodeTypeMiddle:
  432. if path[i] {
  433. nextKey = n.ChildR
  434. siblings = append(siblings, n.ChildL)
  435. } else {
  436. nextKey = n.ChildL
  437. siblings = append(siblings, n.ChildR)
  438. }
  439. default:
  440. return nil, nil, nil, ErrInvalidNodeFound
  441. }
  442. }
  443. return nil, nil, nil, ErrReachedMaxLevel
  444. }
  // Update sets the value of the existing key k to v and recomputes the path
// from the leaf up to the root. It returns the CircomProcessorProof
// (Fnc = 1, update) of the transition, or ErrKeyNotFound if k is not in the
// tree. Both k and v must be inside the finite field. The in-memory root
// changes only after the transaction commits.
func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
	if !mt.writable {
		return nil, ErrNotWritable
	}
	if !cryptoUtils.CheckBigIntInField(k) {
		return nil, errors.New("Key not inside the Finite Field")
	}
	if !cryptoUtils.CheckBigIntInField(v) {
		// Name the argument that actually failed the check, as Add does.
		return nil, errors.New("Value not inside the Finite Field")
	}
	tx, err := mt.db.NewTx()
	if err != nil {
		return nil, err
	}
	mt.Lock()
	defer mt.Unlock()

	kHash := NewHashFromBigInt(k)
	vHash := NewHashFromBigInt(v)
	path := getPath(mt.maxLevels, kHash[:])

	var cp CircomProcessorProof
	cp.Fnc = 1
	cp.OldRoot = mt.rootKey
	cp.OldKey = kHash
	cp.NewKey = kHash
	cp.NewValue = vHash

	nextKey := mt.rootKey
	siblings := []*Hash{}
	for i := 0; i < mt.maxLevels; i++ {
		n, err := mt.GetNode(nextKey)
		if err != nil {
			return nil, err
		}
		switch n.Type {
		case NodeTypeEmpty:
			return nil, ErrKeyNotFound
		case NodeTypeLeaf:
			if !bytes.Equal(kHash[:], n.Entry[0][:]) {
				return nil, ErrKeyNotFound
			}
			cp.OldValue = n.Entry[1]
			cp.Siblings = CircomSiblingsFromSiblings(siblings, mt.maxLevels)
			// Store the updated leaf, then rebuild the path up to the root.
			newNodeLeaf := NewNodeLeaf(kHash, vHash)
			if _, err := mt.updateNode(tx, newNodeLeaf); err != nil {
				return nil, err
			}
			newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
			if err != nil {
				return nil, err
			}
			if err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, newRootKey[:]); err != nil {
				return nil, err
			}
			if err := tx.Commit(); err != nil {
				return nil, err
			}
			// Publish the new root only after the commit succeeded.
			mt.rootKey = newRootKey
			cp.NewRoot = newRootKey
			return &cp, nil
		case NodeTypeMiddle:
			if path[i] {
				nextKey = n.ChildR
				siblings = append(siblings, n.ChildL)
			} else {
				nextKey = n.ChildL
				siblings = append(siblings, n.ChildR)
			}
		default:
			return nil, ErrInvalidNodeFound
		}
	}
	return nil, ErrKeyNotFound
}
  // Delete removes key k from the MerkleTree and recomputes the path from the
// removed leaf up to the root.
//
// Old nodes are not removed from the key-value database, so k is still
// reachable from roots that predate the deletion. To drop key-values that are
// not under the current root, one option is to dump all leafs (mt.DumpLeafs)
// and import them into a new MerkleTree on a new database
// (mt.ImportDumpedLeafs). Doing so loses the root history.
func (mt *MerkleTree) Delete(k *big.Int) error {
	if !mt.writable {
		return ErrNotWritable
	}
	if !cryptoUtils.CheckBigIntInField(k) {
		return errors.New("Key not inside the Finite Field")
	}
	tx, err := mt.db.NewTx()
	if err != nil {
		return err
	}
	mt.Lock()
	defer mt.Unlock()

	kHash := NewHashFromBigInt(k)
	path := getPath(mt.maxLevels, kHash[:])
	siblings := []*Hash{}
	cur := mt.rootKey
	for lvl := 0; lvl < mt.maxLevels; lvl++ {
		n, err := mt.GetNode(cur)
		if err != nil {
			return err
		}
		switch n.Type {
		case NodeTypeEmpty:
			return ErrKeyNotFound
		case NodeTypeLeaf:
			if !bytes.Equal(kHash[:], n.Entry[0][:]) {
				return ErrKeyNotFound
			}
			// Found: drop the leaf and rebuild the path using its sibling.
			return mt.rmAndUpload(tx, path, kHash, siblings)
		case NodeTypeMiddle:
			sibling := n.ChildR
			cur = n.ChildL
			if path[lvl] {
				cur, sibling = n.ChildR, n.ChildL
			}
			siblings = append(siblings, sibling)
		default:
			return ErrInvalidNodeFound
		}
	}
	return ErrKeyNotFound
}
  // rmAndUpload removes the leaf reached through path (siblings holds the
// sibling keys from the root down), rewrites the affected nodes up to the
// root, persists the new root and commits tx. The in-memory root changes only
// after the commit succeeds. kHash is currently unused.
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
	newRootKey, err := mt.rootAfterRemoval(tx, path, siblings)
	if err != nil {
		return err
	}
	if err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, newRootKey[:]); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return err
	}
	mt.rootKey = newRootKey
	return nil
}

// rootAfterRemoval stores the nodes that replace the removed leaf's branch and
// returns the resulting root key.
//
// If the removed leaf's sibling is a middle node, that subtree must stay at
// its depth, because the positions of the leaves beneath it depend on their
// path bits. In that case the leaf slot simply becomes empty and the path is
// recalculated. Otherwise the sibling (a leaf, or empty) is moved up to just
// below the deepest ancestor that has a non-empty sibling, which collapses the
// middle nodes that are no longer needed.
func (mt *MerkleTree) rootAfterRemoval(tx db.Tx, path []bool, siblings []*Hash) (*Hash, error) {
	depth := len(siblings)
	if depth == 0 {
		// The removed leaf was the root, so the tree is now empty.
		return &HashZero, nil
	}
	toUpload := siblings[depth-1]

	sibling, err := mt.GetNode(toUpload)
	if err != nil {
		return nil, err
	}
	if sibling.Type == NodeTypeMiddle {
		// Keep the subtree in place and leave an empty slot for the removed leaf.
		var newNode *Node
		if path[depth-1] {
			newNode = NewNodeMiddle(toUpload, &HashZero)
		} else {
			newNode = NewNodeMiddle(&HashZero, toUpload)
		}
		if _, err := mt.addNode(tx, newNode); err != nil && err != ErrNodeKeyAlreadyExists {
			return nil, err
		}
		return mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:depth-1])
	}

	for i := depth - 2; i >= 0; i-- {
		if bytes.Equal(siblings[i][:], HashZero[:]) {
			continue
		}
		var newNode *Node
		if path[i] {
			newNode = NewNodeMiddle(siblings[i], toUpload)
		} else {
			newNode = NewNodeMiddle(toUpload, siblings[i])
		}
		if _, err := mt.addNode(tx, newNode); err != nil && err != ErrNodeKeyAlreadyExists {
			return nil, err
		}
		return mt.recalculatePathUntilRoot(tx, path, newNode, siblings[:i])
	}
	// Every ancestor's other child is empty, so the sibling becomes the root.
	return toUpload, nil
}
  645. // recalculatePathUntilRoot recalculates the nodes until the Root
  646. func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node,
  647. siblings []*Hash) (*Hash, error) {
  648. for i := len(siblings) - 1; i >= 0; i-- {
  649. nodeKey, err := node.Key()
  650. if err != nil {
  651. return nil, err
  652. }
  653. if path[i] {
  654. node = NewNodeMiddle(siblings[i], nodeKey)
  655. } else {
  656. node = NewNodeMiddle(nodeKey, siblings[i])
  657. }
  658. _, err = mt.addNode(tx, node)
  659. if err != ErrNodeKeyAlreadyExists && err != nil {
  660. return nil, err
  661. }
  662. }
  663. // return last node added, which is the root
  664. nodeKey, err := node.Key()
  665. return nodeKey, err
  666. }
  667. // dbInsert is a helper function to insert a node into a key in an open db
  668. // transaction.
  669. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) error {
  670. v := append([]byte{byte(t)}, data...)
  671. return tx.Put(k, v)
  672. }
  673. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  674. // tree; they are all the same and assumed to always exist.
  675. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  676. if bytes.Equal(key[:], HashZero[:]) {
  677. return NewNodeEmpty(), nil
  678. }
  679. nBytes, err := mt.db.Get(key[:])
  680. if err != nil {
  681. return nil, err
  682. }
  683. return NewNodeFromBytes(nBytes)
  684. }
  685. // getPath returns the binary path, from the root to the leaf.
  686. func getPath(numLevels int, k []byte) []bool {
  687. path := make([]bool, numLevels)
  688. for n := 0; n < numLevels; n++ {
  689. path[n] = TestBit(k[:], uint(n))
  690. }
  691. return path
  692. }
  693. // NodeAux contains the auxiliary node used in a non-existence proof.
  694. type NodeAux struct {
  695. Key *Hash
  696. Value *Hash
  697. }
  698. // Proof defines the required elements for a MT proof of existence or
  699. // non-existence.
  700. type Proof struct {
  701. // existence indicates wether this is a proof of existence or
  702. // non-existence.
  703. Existence bool
  704. // depth indicates how deep in the tree the proof goes.
  705. depth uint
  706. // notempties is a bitmap of non-empty Siblings found in Siblings.
  707. notempties [ElemBytesLen - proofFlagsLen]byte
  708. // Siblings is a list of non-empty sibling keys.
  709. Siblings []*Hash
  710. NodeAux *NodeAux
  711. }
  712. // NewProofFromBytes parses a byte array into a Proof.
  713. func NewProofFromBytes(bs []byte) (*Proof, error) {
  714. if len(bs) < ElemBytesLen {
  715. return nil, ErrInvalidProofBytes
  716. }
  717. p := &Proof{}
  718. if (bs[0] & 0x01) == 0 {
  719. p.Existence = true
  720. }
  721. p.depth = uint(bs[1])
  722. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  723. siblingBytes := bs[ElemBytesLen:]
  724. sibIdx := 0
  725. for i := uint(0); i < p.depth; i++ {
  726. if TestBitBigEndian(p.notempties[:], i) {
  727. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  728. return nil, ErrInvalidProofBytes
  729. }
  730. var sib Hash
  731. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  732. p.Siblings = append(p.Siblings, &sib)
  733. sibIdx++
  734. }
  735. }
  736. if !p.Existence && ((bs[0] & 0x02) != 0) {
  737. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  738. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  739. if len(nodeAuxBytes) != 2*ElemBytesLen {
  740. return nil, ErrInvalidProofBytes
  741. }
  742. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  743. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  744. }
  745. return p, nil
  746. }
  747. // Bytes serializes a Proof into a byte array.
  748. func (p *Proof) Bytes() []byte {
  749. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  750. if p.NodeAux != nil {
  751. bsLen += 2 * ElemBytesLen //nolint:gomnd
  752. }
  753. bs := make([]byte, bsLen)
  754. if !p.Existence {
  755. bs[0] |= 0x01
  756. }
  757. bs[1] = byte(p.depth)
  758. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  759. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  760. for i, k := range p.Siblings {
  761. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  762. }
  763. if p.NodeAux != nil {
  764. bs[0] |= 0x02
  765. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  766. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  767. }
  768. return bs
  769. }
  770. // SiblingsFromProof returns all the siblings of the proof.
  771. func SiblingsFromProof(proof *Proof) []*Hash {
  772. sibIdx := 0
  773. siblings := []*Hash{}
  774. for lvl := 0; lvl < int(proof.depth); lvl++ {
  775. if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  776. siblings = append(siblings, proof.Siblings[sibIdx])
  777. sibIdx++
  778. } else {
  779. siblings = append(siblings, &HashZero)
  780. }
  781. }
  782. return siblings
  783. }
  784. // AllSiblings returns all the siblings of the proof.
  785. func (p *Proof) AllSiblings() []*Hash {
  786. return SiblingsFromProof(p)
  787. }
  788. // CircomSiblingsFromSiblings returns the full siblings compatible with circom
  789. func CircomSiblingsFromSiblings(siblings []*Hash, levels int) []*Hash {
  790. // Add the rest of empty levels to the siblings
  791. for i := len(siblings); i < levels+1; i++ {
  792. siblings = append(siblings, &HashZero)
  793. }
  794. return siblings
  795. }
  796. // CircomProcessorProof defines the ProcessorProof compatible with circom. Is
  797. // the data of the proof between the transition from one state to another.
  798. type CircomProcessorProof struct {
  799. OldRoot *Hash `json:"oldRoot"`
  800. NewRoot *Hash `json:"newRoot"`
  801. Siblings []*Hash `json:"siblings"`
  802. OldKey *Hash `json:"oldKey"`
  803. OldValue *Hash `json:"oldValue"`
  804. NewKey *Hash `json:"newKey"`
  805. NewValue *Hash `json:"newValue"`
  806. IsOld0 bool `json:"isOld0"`
  807. // 0: NOP, 1: Update, 2: Insert, 3: Delete
  808. Fnc int `json:"fnc"`
  809. }
  810. // String returns a human readable string representation of the
  811. // CircomProcessorProof
  812. func (p CircomProcessorProof) String() string {
  813. buf := bytes.NewBufferString("{")
  814. fmt.Fprintf(buf, " OldRoot: %v,\n", p.OldRoot)
  815. fmt.Fprintf(buf, " NewRoot: %v,\n", p.NewRoot)
  816. fmt.Fprintf(buf, " Siblings: [\n ")
  817. for _, s := range p.Siblings {
  818. fmt.Fprintf(buf, "%v, ", s)
  819. }
  820. fmt.Fprintf(buf, "\n ],\n")
  821. fmt.Fprintf(buf, " OldKey: %v,\n", p.OldKey)
  822. fmt.Fprintf(buf, " OldValue: %v,\n", p.OldValue)
  823. fmt.Fprintf(buf, " NewKey: %v,\n", p.NewKey)
  824. fmt.Fprintf(buf, " NewValue: %v,\n", p.NewValue)
  825. fmt.Fprintf(buf, " IsOld0: %v,\n", p.IsOld0)
  826. fmt.Fprintf(buf, "}\n")
  827. return buf.String()
  828. }
  829. // CircomVerifierProof defines the VerifierProof compatible with circom. Is the
  830. // data of the proof that a certain leaf exists in the MerkleTree.
  831. type CircomVerifierProof struct {
  832. Root *Hash `json:"root"`
  833. Siblings []*Hash `json:"siblings"`
  834. OldKey *Hash `json:"oldKey"`
  835. OldValue *Hash `json:"oldValue"`
  836. IsOld0 bool `json:"isOld0"`
  837. Key *Hash `json:"key"`
  838. Value *Hash `json:"value"`
  839. Fnc int `json:"fnc"` // 0: inclusion, 1: non inclusion
  840. }
  // GenerateCircomVerifierProof returns the CircomVerifierProof for key k, with
// its siblings padded to maxLevels+1 entries as circom expects. If rootKey is
// nil, the current root is used.
func (mt *MerkleTree) GenerateCircomVerifierProof(k *big.Int,
	rootKey *Hash) (*CircomVerifierProof, error) {
	proof, err := mt.GenerateSCVerifierProof(k, rootKey)
	if err != nil {
		return nil, err
	}
	proof.Siblings = CircomSiblingsFromSiblings(proof.Siblings, mt.maxLevels)
	return proof, nil
}
  // GenerateSCVerifierProof returns the CircomVerifierProof for key k. Its
// siblings cover only the levels the proof descends and are not padded to the
// maxLevels+1 entries circom needs, which makes the proof straightforward to
// verify inside a Smart Contract. If rootKey is nil, the current root is used.
// Fnc is 0 for inclusion and 1 for non-inclusion.
func (mt *MerkleTree) GenerateSCVerifierProof(k *big.Int,
	rootKey *Hash) (*CircomVerifierProof, error) {
	if rootKey == nil {
		rootKey = mt.Root()
	}
	p, v, err := mt.GenerateProof(k, rootKey)
	if err != nil && err != ErrKeyNotFound {
		return nil, err
	}
	if p == nil {
		// GenerateProof returns no proof (together with ErrKeyNotFound) when it
		// exhausts maxLevels. Using p below would dereference nil.
		return nil, ErrKeyNotFound
	}

	var cp CircomVerifierProof
	cp.Root = rootKey
	cp.Siblings = p.AllSiblings()
	if p.NodeAux != nil {
		cp.OldKey = p.NodeAux.Key
		cp.OldValue = p.NodeAux.Value
	} else {
		cp.OldKey = &HashZero
		cp.OldValue = &HashZero
	}
	cp.Key = NewHashFromBigInt(k)
	cp.Value = NewHashFromBigInt(v)
	if p.Existence {
		cp.Fnc = 0 // inclusion
	} else {
		cp.Fnc = 1 // non inclusion
	}
	return &cp, nil
}
  885. // GenerateProof generates the proof of existence (or non-existence) of an
  886. // Entry's hash Index for a Merkle Tree given the root.
  887. // If the rootKey is nil, the current merkletree root is used
  888. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof,
  889. *big.Int, error) {
  890. p := &Proof{}
  891. var siblingKey *Hash
  892. kHash := NewHashFromBigInt(k)
  893. path := getPath(mt.maxLevels, kHash[:])
  894. if rootKey == nil {
  895. rootKey = mt.Root()
  896. }
  897. nextKey := rootKey
  898. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  899. n, err := mt.GetNode(nextKey)
  900. if err != nil {
  901. return nil, nil, err
  902. }
  903. switch n.Type {
  904. case NodeTypeEmpty:
  905. return p, big.NewInt(0), nil
  906. case NodeTypeLeaf:
  907. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  908. p.Existence = true
  909. return p, n.Entry[1].BigInt(), nil
  910. }
  911. // We found a leaf whose entry didn't match hIndex
  912. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  913. return p, n.Entry[1].BigInt(), nil
  914. case NodeTypeMiddle:
  915. if path[p.depth] {
  916. nextKey = n.ChildR
  917. siblingKey = n.ChildL
  918. } else {
  919. nextKey = n.ChildL
  920. siblingKey = n.ChildR
  921. }
  922. default:
  923. return nil, nil, ErrInvalidNodeFound
  924. }
  925. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  926. SetBitBigEndian(p.notempties[:], uint(p.depth))
  927. p.Siblings = append(p.Siblings, siblingKey)
  928. }
  929. }
  930. return nil, nil, ErrKeyNotFound
  931. }
  932. // VerifyProof verifies the Merkle Proof for the entry and root.
  933. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  934. rootFromProof, err := RootFromProof(proof, k, v)
  935. if err != nil {
  936. return false
  937. }
  938. return bytes.Equal(rootKey[:], rootFromProof[:])
  939. }
  940. // RootFromProof calculates the root that would correspond to a tree whose
  941. // siblings are the ones in the proof with the leaf hashing to hIndex and
  942. // hValue.
  943. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  944. kHash := NewHashFromBigInt(k)
  945. vHash := NewHashFromBigInt(v)
  946. sibIdx := len(proof.Siblings) - 1
  947. var err error
  948. var midKey *Hash
  949. if proof.Existence {
  950. midKey, err = LeafKey(kHash, vHash)
  951. if err != nil {
  952. return nil, err
  953. }
  954. } else {
  955. if proof.NodeAux == nil {
  956. midKey = &HashZero
  957. } else {
  958. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  959. return nil,
  960. fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  961. }
  962. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  963. if err != nil {
  964. return nil, err
  965. }
  966. }
  967. }
  968. path := getPath(int(proof.depth), kHash[:])
  969. var siblingKey *Hash
  970. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  971. if TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  972. siblingKey = proof.Siblings[sibIdx]
  973. sibIdx--
  974. } else {
  975. siblingKey = &HashZero
  976. }
  977. if path[lvl] {
  978. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  979. if err != nil {
  980. return nil, err
  981. }
  982. } else {
  983. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  984. if err != nil {
  985. return nil, err
  986. }
  987. }
  988. }
  989. return midKey, nil
  990. }
  991. // walk is a helper recursive function to iterate over all tree branches
  992. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  993. n, err := mt.GetNode(key)
  994. if err != nil {
  995. return err
  996. }
  997. switch n.Type {
  998. case NodeTypeEmpty:
  999. f(n)
  1000. case NodeTypeLeaf:
  1001. f(n)
  1002. case NodeTypeMiddle:
  1003. f(n)
  1004. if err := mt.walk(n.ChildL, f); err != nil {
  1005. return err
  1006. }
  1007. if err := mt.walk(n.ChildR, f); err != nil {
  1008. return err
  1009. }
  1010. default:
  1011. return ErrInvalidNodeFound
  1012. }
  1013. return nil
  1014. }
  1015. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  1016. // if rootKey is nil, it will get the current RootKey of the current state of
  1017. // the MerkleTree. For each node, it calls the f function given in the
  1018. // parameters. See some examples of the Walk function usage in the
  1019. // merkletree.go and merkletree_test.go
  1020. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  1021. if rootKey == nil {
  1022. rootKey = mt.Root()
  1023. }
  1024. err := mt.walk(rootKey, f)
  1025. return err
  1026. }
  1027. // GraphViz uses Walk function to generate a string GraphViz representation of
  1028. // the tree and writes it to w
  1029. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  1030. fmt.Fprintf(w, `digraph hierarchy {
  1031. node [fontname=Monospace,fontsize=10,shape=box]
  1032. `)
  1033. cnt := 0
  1034. var errIn error
  1035. err := mt.Walk(rootKey, func(n *Node) {
  1036. k, err := n.Key()
  1037. if err != nil {
  1038. errIn = err
  1039. }
  1040. switch n.Type {
  1041. case NodeTypeEmpty:
  1042. case NodeTypeLeaf:
  1043. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.String())
  1044. case NodeTypeMiddle:
  1045. lr := [2]string{n.ChildL.String(), n.ChildR.String()}
  1046. emptyNodes := ""
  1047. for i := range lr {
  1048. if lr[i] == "0" {
  1049. lr[i] = fmt.Sprintf("empty%v", cnt)
  1050. emptyNodes += fmt.Sprintf("\"%v\" [style=dashed,label=0];\n", lr[i])
  1051. cnt++
  1052. }
  1053. }
  1054. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.String(), lr[0], lr[1])
  1055. fmt.Fprint(w, emptyNodes)
  1056. default:
  1057. }
  1058. })
  1059. fmt.Fprintf(w, "}\n")
  1060. if errIn != nil {
  1061. return errIn
  1062. }
  1063. return err
  1064. }
  1065. // PrintGraphViz prints directly the GraphViz() output
  1066. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  1067. if rootKey == nil {
  1068. rootKey = mt.Root()
  1069. }
  1070. w := bytes.NewBufferString("")
  1071. fmt.Fprintf(w,
  1072. "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  1073. err := mt.GraphViz(w, nil)
  1074. if err != nil {
  1075. return err
  1076. }
  1077. fmt.Fprintf(w,
  1078. "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  1079. fmt.Println(w)
  1080. return nil
  1081. }
  1082. // DumpLeafs returns all the Leafs that exist under the given Root. If no Root
  1083. // is given (nil), it uses the current Root of the MerkleTree.
  1084. func (mt *MerkleTree) DumpLeafs(rootKey *Hash) ([]byte, error) {
  1085. var b []byte
  1086. err := mt.Walk(rootKey, func(n *Node) {
  1087. if n.Type == NodeTypeLeaf {
  1088. l := n.Entry[0].Bytes()
  1089. r := n.Entry[1].Bytes()
  1090. b = append(b, append(l[:], r[:]...)...)
  1091. }
  1092. })
  1093. return b, err
  1094. }
  1095. // ImportDumpedLeafs parses and adds to the MerkleTree the dumped list of leafs
  1096. // from the DumpLeafs function.
  1097. func (mt *MerkleTree) ImportDumpedLeafs(b []byte) error {
  1098. for i := 0; i < len(b); i += 64 {
  1099. lr := b[i : i+64]
  1100. lB, err := NewBigIntFromHashBytes(lr[:32])
  1101. if err != nil {
  1102. return err
  1103. }
  1104. rB, err := NewBigIntFromHashBytes(lr[32:])
  1105. if err != nil {
  1106. return err
  1107. }
  1108. err = mt.Add(lB, rB)
  1109. if err != nil {
  1110. return err
  1111. }
  1112. }
  1113. return nil
  1114. }