You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

594 lines
16 KiB

  1. package merkletree
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/big"
  8. "sync"
  9. "github.com/iden3/go-iden3-core/common"
  10. "github.com/iden3/go-iden3-core/db"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. )
  13. const (
  14. // proofFlagsLen is the byte length of the flags in the proof header (first 32
  15. // bytes).
  16. proofFlagsLen = 2
  17. // ElemBytesLen is the length of the Hash byte array
  18. ElemBytesLen = 32
  19. )
  20. var (
  21. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  22. ErrNodeKeyAlreadyExists = errors.New("node already exists")
  23. // ErrEntryIndexNotFound is used when no entry is found for an index.
  24. ErrEntryIndexNotFound = errors.New("node index not found in the DB")
  25. // ErrNodeDataBadSize is used when the data of a node has an incorrect
  26. // size and can't be parsed.
  27. ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB")
  28. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  29. // maximum level.
  30. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  31. // ErrInvalidNodeFound is used when an invalid node is found and can't
  32. // be parsed.
  33. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  34. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  35. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  36. // ErrInvalidDBValue is used when a value in the key value DB is
  37. // invalid (for example, it doen't contain a byte header and a []byte
  38. // body of at least len=1.
  39. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  40. // ErrEntryIndexAlreadyExists is used when the entry index already
  41. // exists in the tree.
  42. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  43. // ErrNotWritable is used when the MerkleTree is not writable and a write function is called
  44. ErrNotWritable = errors.New("Merkle Tree not writable")
  45. rootNodeValue = []byte("currentroot")
  46. // HashZero is used at Empty nodes
  47. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  48. )
  // Hash is the generic type stored in the MerkleTree. Its 32 bytes hold a
// little-endian number: NewHashFromBigInt and BigInt byte-swap between this
// layout and big.Int's big-endian encoding.
type Hash [32]byte

// String returns the decimal representation of the number stored in h.
// The bytes are read as little-endian, so the result agrees with
// h.BigInt().String() and with the value given to NewHashFromBigInt.
func (h Hash) String() string {
	// big.Int.SetBytes expects big-endian input, so reverse the bytes first.
	var be [32]byte
	for i, b := range h {
		be[len(h)-1-i] = b
	}
	return new(big.Int).SetBytes(be[:]).String()
}
  54. // BigInt returns the *big.Int representation of the *Hash
  55. func (h *Hash) BigInt() *big.Int {
  56. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  57. }
  58. // NewHashFromBigInt returns a *Hash representation of the given *big.Int
  59. func NewHashFromBigInt(b *big.Int) *Hash {
  60. r := &Hash{}
  61. copy(r[:], common.SwapEndianness(b.Bytes()))
  62. return r
  63. }
  64. // MerkleTree is the struct with the main elements of the MerkleTree
  65. type MerkleTree struct {
  66. sync.RWMutex
  67. db db.Storage
  68. rootKey *Hash
  69. writable bool
  70. maxLevels int
  71. }
  72. // NewMerkleTree loads a new Merkletree. If in the sotrage already exists one will open that one, if not, will create a new one.
  73. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  74. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  75. v, err := mt.db.Get(rootNodeValue)
  76. if err != nil {
  77. tx, err := mt.db.NewTx()
  78. if err != nil {
  79. return nil, err
  80. }
  81. mt.rootKey = &HashZero
  82. tx.Put(rootNodeValue, mt.rootKey[:])
  83. err = tx.Commit()
  84. if err != nil {
  85. return nil, err
  86. }
  87. return &mt, nil
  88. }
  89. mt.rootKey = &Hash{}
  90. copy(mt.rootKey[:], v)
  91. return &mt, nil
  92. }
  93. // Root returns the MerkleRoot
  94. func (mt *MerkleTree) Root() *Hash {
  95. return mt.rootKey
  96. }
  97. // Add adds a Key & Value into the MerkleTree. Where the `k` determines the path from the Root to the Leaf.
  98. func (mt *MerkleTree) Add(k, v *big.Int) error {
  99. // verify that the MerkleTree is writable
  100. if !mt.writable {
  101. return ErrNotWritable
  102. }
  103. // verfy that the ElemBytes are valid and fit inside the Finite Field.
  104. if !cryptoUtils.CheckBigIntInField(k) {
  105. return errors.New("Key not inside the Finite Field")
  106. }
  107. if !cryptoUtils.CheckBigIntInField(v) {
  108. return errors.New("Value not inside the Finite Field")
  109. }
  110. tx, err := mt.db.NewTx()
  111. if err != nil {
  112. return err
  113. }
  114. mt.Lock()
  115. defer mt.Unlock()
  116. kHash := NewHashFromBigInt(k)
  117. vHash := NewHashFromBigInt(v)
  118. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  119. path := getPath(mt.maxLevels, kHash[:])
  120. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  121. if err != nil {
  122. return err
  123. }
  124. mt.rootKey = newRootKey
  125. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  126. if err := tx.Commit(); err != nil {
  127. return err
  128. }
  129. return nil
  130. }
  131. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  132. // from newLeaf, at which point both leafs are stored, all while updating the
  133. // path.
  134. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  135. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  136. if lvl > mt.maxLevels-2 {
  137. return nil, ErrReachedMaxLevel
  138. }
  139. var newNodeMiddle *Node
  140. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  141. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  142. if err != nil {
  143. return nil, err
  144. }
  145. if pathNewLeaf[lvl] {
  146. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  147. } else {
  148. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  149. }
  150. return mt.addNode(tx, newNodeMiddle)
  151. } else {
  152. oldLeafKey, err := oldLeaf.Key()
  153. if err != nil {
  154. return nil, err
  155. }
  156. newLeafKey, err := newLeaf.Key()
  157. if err != nil {
  158. return nil, err
  159. }
  160. if pathNewLeaf[lvl] {
  161. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  162. } else {
  163. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  164. }
  165. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  166. _, err = mt.addNode(tx, newLeaf)
  167. if err != nil {
  168. return nil, err
  169. }
  170. return mt.addNode(tx, newNodeMiddle)
  171. }
  172. }
  173. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  174. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  175. lvl int, path []bool) (*Hash, error) {
  176. var err error
  177. var nextKey *Hash
  178. if lvl > mt.maxLevels-1 {
  179. return nil, ErrReachedMaxLevel
  180. }
  181. n, err := mt.GetNode(key)
  182. if err != nil {
  183. return nil, err
  184. }
  185. switch n.Type {
  186. case NodeTypeEmpty:
  187. // We can add newLeaf now
  188. return mt.addNode(tx, newLeaf)
  189. case NodeTypeLeaf:
  190. nKey := n.Entry[0]
  191. // Check if leaf node found contains the leaf node we are trying to add
  192. newLeafKey := newLeaf.Entry[0]
  193. if bytes.Equal(nKey[:], newLeafKey[:]) {
  194. return nil, ErrEntryIndexAlreadyExists
  195. }
  196. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  197. // We need to push newLeaf down until its path diverges from n's path
  198. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  199. case NodeTypeMiddle:
  200. // We need to go deeper, continue traversing the tree, left or right depending on path
  201. var newNodeMiddle *Node
  202. if path[lvl] {
  203. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  204. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  205. } else {
  206. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  207. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  208. }
  209. if err != nil {
  210. return nil, err
  211. }
  212. // Update the node to reflect the modified child
  213. return mt.addNode(tx, newNodeMiddle)
  214. default:
  215. return nil, ErrInvalidNodeFound
  216. }
  217. }
  218. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  219. // they are all the same and assumed to always exist.
  220. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  221. // verify that the MerkleTree is writable
  222. if !mt.writable {
  223. return nil, ErrNotWritable
  224. }
  225. if n.Type == NodeTypeEmpty {
  226. return n.Key()
  227. }
  228. k, err := n.Key()
  229. if err != nil {
  230. return nil, err
  231. }
  232. v := n.Value()
  233. // Check that the node key doesn't already exist
  234. if _, err := tx.Get(k[:]); err == nil {
  235. return nil, ErrNodeKeyAlreadyExists
  236. }
  237. tx.Put(k[:], v)
  238. return k, nil
  239. }
  240. // dbGet is a helper function to get the node of a key from the internal
  241. // storage.
  242. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  243. if bytes.Equal(k, HashZero[:]) {
  244. return 0, nil, nil
  245. }
  246. value, err := mt.db.Get(k)
  247. if err != nil {
  248. return 0, nil, err
  249. }
  250. if len(value) < 2 {
  251. return 0, nil, ErrInvalidDBValue
  252. }
  253. nodeType := value[0]
  254. nodeBytes := value[1:]
  255. return NodeType(nodeType), nodeBytes, nil
  256. }
  257. // dbInsert is a helper function to insert a node into a key in an open db
  258. // transaction.
  259. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  260. v := append([]byte{byte(t)}, data...)
  261. tx.Put(k, v)
  262. }
  263. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  264. // tree; they are all the same and assumed to always exist.
  265. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  266. if bytes.Equal(key[:], HashZero[:]) {
  267. return NewNodeEmpty(), nil
  268. }
  269. nBytes, err := mt.db.Get(key[:])
  270. if err != nil {
  271. return nil, err
  272. }
  273. return NewNodeFromBytes(nBytes)
  274. }
  275. // getPath returns the binary path, from the root to the leaf.
  276. func getPath(numLevels int, k []byte) []bool {
  277. path := make([]bool, numLevels)
  278. for n := 0; n < numLevels; n++ {
  279. path[n] = common.TestBit(k[:], uint(n))
  280. }
  281. return path
  282. }
  283. // NodeAux contains the auxiliary node used in a non-existence proof.
  284. type NodeAux struct {
  285. Key *Hash
  286. Value *Hash
  287. }
  288. // Proof defines the required elements for a MT proof of existence or non-existence.
  289. type Proof struct {
  290. // existence indicates wether this is a proof of existence or non-existence.
  291. Existence bool
  292. // depth indicates how deep in the tree the proof goes.
  293. depth uint
  294. // notempties is a bitmap of non-empty Siblings found in Siblings.
  295. notempties [ElemBytesLen - proofFlagsLen]byte
  296. // Siblings is a list of non-empty sibling keys.
  297. Siblings []*Hash
  298. NodeAux *NodeAux
  299. }
  300. // NewProofFromBytes parses a byte array into a Proof.
  301. func NewProofFromBytes(bs []byte) (*Proof, error) {
  302. if len(bs) < ElemBytesLen {
  303. return nil, ErrInvalidProofBytes
  304. }
  305. p := &Proof{}
  306. if (bs[0] & 0x01) == 0 {
  307. p.Existence = true
  308. }
  309. p.depth = uint(bs[1])
  310. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  311. siblingBytes := bs[ElemBytesLen:]
  312. sibIdx := 0
  313. for i := uint(0); i < p.depth; i++ {
  314. if common.TestBitBigEndian(p.notempties[:], i) {
  315. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  316. return nil, ErrInvalidProofBytes
  317. }
  318. var sib Hash
  319. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  320. p.Siblings = append(p.Siblings, &sib)
  321. sibIdx++
  322. }
  323. }
  324. if !p.Existence && ((bs[0] & 0x02) != 0) {
  325. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  326. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  327. if len(nodeAuxBytes) != 2*ElemBytesLen {
  328. return nil, ErrInvalidProofBytes
  329. }
  330. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  331. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  332. }
  333. return p, nil
  334. }
  335. // Bytes serializes a Proof into a byte array.
  336. func (p *Proof) Bytes() []byte {
  337. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  338. if p.NodeAux != nil {
  339. bsLen += 2 * ElemBytesLen
  340. }
  341. bs := make([]byte, bsLen)
  342. if !p.Existence {
  343. bs[0] |= 0x01
  344. }
  345. bs[1] = byte(p.depth)
  346. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  347. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  348. for i, k := range p.Siblings {
  349. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  350. }
  351. if p.NodeAux != nil {
  352. bs[0] |= 0x02
  353. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  354. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  355. }
  356. return bs
  357. }
  358. // GenerateProof generates the proof of existence (or non-existence) of an
  359. // Entry's hash Index for a Merkle Tree given the root.
  360. // If the rootKey is nil, the current merkletree root is used
  361. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
  362. p := &Proof{}
  363. var siblingKey *Hash
  364. kHash := NewHashFromBigInt(k)
  365. path := getPath(mt.maxLevels, kHash[:])
  366. if rootKey == nil {
  367. rootKey = mt.Root()
  368. }
  369. nextKey := rootKey
  370. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  371. n, err := mt.GetNode(nextKey)
  372. if err != nil {
  373. return nil, err
  374. }
  375. switch n.Type {
  376. case NodeTypeEmpty:
  377. return p, nil
  378. case NodeTypeLeaf:
  379. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  380. p.Existence = true
  381. return p, nil
  382. } else {
  383. // We found a leaf whose entry didn't match hIndex
  384. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  385. return p, nil
  386. }
  387. case NodeTypeMiddle:
  388. if path[p.depth] {
  389. nextKey = n.ChildR
  390. siblingKey = n.ChildL
  391. } else {
  392. nextKey = n.ChildL
  393. siblingKey = n.ChildR
  394. }
  395. default:
  396. return nil, ErrInvalidNodeFound
  397. }
  398. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  399. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  400. p.Siblings = append(p.Siblings, siblingKey)
  401. }
  402. }
  403. return nil, ErrEntryIndexNotFound
  404. }
  405. // VerifyProof verifies the Merkle Proof for the entry and root.
  406. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  407. rootFromProof, err := RootFromProof(proof, k, v)
  408. if err != nil {
  409. return false
  410. }
  411. return bytes.Equal(rootKey[:], rootFromProof[:])
  412. }
  413. // RootFromProof calculates the root that would correspond to a tree whose
  414. // siblings are the ones in the proof with the claim hashing to hIndex and
  415. // hValue.
  416. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  417. kHash := NewHashFromBigInt(k)
  418. vHash := NewHashFromBigInt(v)
  419. sibIdx := len(proof.Siblings) - 1
  420. var err error
  421. var midKey *Hash
  422. if proof.Existence {
  423. midKey, err = LeafKey(kHash, vHash)
  424. if err != nil {
  425. return nil, err
  426. }
  427. } else {
  428. if proof.NodeAux == nil {
  429. midKey = &HashZero
  430. } else {
  431. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  432. return nil, fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  433. }
  434. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  435. if err != nil {
  436. return nil, err
  437. }
  438. }
  439. }
  440. path := getPath(int(proof.depth), kHash[:])
  441. var siblingKey *Hash
  442. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  443. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  444. siblingKey = proof.Siblings[sibIdx]
  445. sibIdx--
  446. } else {
  447. siblingKey = &HashZero
  448. }
  449. if path[lvl] {
  450. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  451. if err != nil {
  452. return nil, err
  453. }
  454. } else {
  455. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  456. if err != nil {
  457. return nil, err
  458. }
  459. }
  460. }
  461. return midKey, nil
  462. }
  463. // walk is a helper recursive function to iterate over all tree branches
  464. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  465. n, err := mt.GetNode(key)
  466. if err != nil {
  467. return err
  468. }
  469. switch n.Type {
  470. case NodeTypeEmpty:
  471. f(n)
  472. case NodeTypeLeaf:
  473. f(n)
  474. case NodeTypeMiddle:
  475. f(n)
  476. if err := mt.walk(n.ChildL, f); err != nil {
  477. return err
  478. }
  479. if err := mt.walk(n.ChildR, f); err != nil {
  480. return err
  481. }
  482. default:
  483. return ErrInvalidNodeFound
  484. }
  485. return nil
  486. }
  487. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  488. // if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree.
  489. // For each node, it calls the f function given in the parameters.
  490. // See some examples of the Walk function usage in the merkletree_test.go
  491. // test functions: TestMTWalk, TestMTWalkGraphViz, TestMTWalkDumpClaims
  492. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  493. if rootKey == nil {
  494. rootKey = mt.Root()
  495. }
  496. err := mt.walk(rootKey, f)
  497. return err
  498. }
  499. // GraphViz uses Walk function to generate a string GraphViz representation of the
  500. // tree and writes it to w
  501. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  502. fmt.Fprintf(w, `digraph hierarchy {
  503. node [fontname=Monospace,fontsize=10,shape=box]
  504. `)
  505. cnt := 0
  506. var errIn error
  507. err := mt.Walk(rootKey, func(n *Node) {
  508. k, err := n.Key()
  509. if err != nil {
  510. errIn = err
  511. }
  512. switch n.Type {
  513. case NodeTypeEmpty:
  514. case NodeTypeLeaf:
  515. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.BigInt().String())
  516. case NodeTypeMiddle:
  517. lr := [2]string{n.ChildL.BigInt().String(), n.ChildR.BigInt().String()}
  518. for i := range lr {
  519. if lr[i] == "0" {
  520. lr[i] = fmt.Sprintf("empty%v", cnt)
  521. fmt.Fprintf(w, "\"%v\" [style=dashed,label=0];\n", lr[i])
  522. cnt++
  523. }
  524. }
  525. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.BigInt().String(), lr[0], lr[1])
  526. default:
  527. }
  528. })
  529. fmt.Fprintf(w, "}\n")
  530. if errIn != nil {
  531. return errIn
  532. }
  533. return err
  534. }
  535. // PrintGraphViz prints directly the GraphViz() output
  536. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  537. if rootKey == nil {
  538. rootKey = mt.Root()
  539. }
  540. w := bytes.NewBufferString("")
  541. fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  542. err := mt.GraphViz(w, nil)
  543. if err != nil {
  544. return err
  545. }
  546. fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  547. fmt.Println(w)
  548. return nil
  549. }