You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

586 lines
16 KiB

  1. package merkletree
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "math/big"
  8. "sync"
  9. "github.com/iden3/go-iden3-core/common"
  10. "github.com/iden3/go-iden3-core/db"
  11. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  12. )
const (
	// proofFlagsLen is the number of leading bytes of a serialized proof
	// header (the first ElemBytesLen bytes) that hold flags: Proof.Bytes
	// writes the existence/NodeAux flag bits into byte 0 and the depth into
	// byte 1. The remaining header bytes hold the notempties bitmap.
	proofFlagsLen = 2
	// ElemBytesLen is the length of the Hash byte array, and the size of
	// every fixed-size element (header, sibling, NodeAux key/value) of a
	// serialized proof.
	ElemBytesLen = 32
)
var (
	// ErrNodeKeyAlreadyExists is used when a node key already exists.
	ErrNodeKeyAlreadyExists = errors.New("node already exists")
	// ErrEntryIndexNotFound is used when no entry is found for an index
	// (GenerateProof returns it when maxLevels is exhausted).
	ErrEntryIndexNotFound = errors.New("node index not found in the DB")
	// ErrNodeDataBadSize is used when the data of a node has an incorrect
	// size and can't be parsed.
	ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB")
	// ErrReachedMaxLevel is used when a traversal of the MT reaches the
	// maximum level.
	ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
	// ErrInvalidNodeFound is used when an invalid node is found and can't
	// be parsed.
	ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
	// ErrInvalidProofBytes is used when a serialized proof is invalid.
	ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
	// ErrInvalidDBValue is used when a value in the key value DB is
	// invalid (for example, it doesn't contain a one-byte header followed
	// by a []byte body of at least len=1).
	ErrInvalidDBValue = errors.New("the value in the DB is invalid")
	// ErrEntryIndexAlreadyExists is used when the entry index already
	// exists in the tree.
	ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
	// ErrNotWritable is used when the MerkleTree is not writable and a
	// write function is called.
	ErrNotWritable = errors.New("Merkle Tree not writable")
	// rootNodeValue is the DB key under which the current root key is
	// stored.
	rootNodeValue = []byte("currentroot")
	// HashZero is the all-zero Hash. GetNode maps this key to an empty node
	// without a DB lookup, and it is the root key of a freshly created tree.
	HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
  48. type Hash [32]byte
  49. func (h Hash) String() string {
  50. return new(big.Int).SetBytes(h[:]).String()
  51. }
  52. func (h *Hash) BigInt() *big.Int {
  53. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  54. }
  55. func NewHashFromBigInt(b *big.Int) *Hash {
  56. r := &Hash{}
  57. copy(r[:], common.SwapEndianness(b.Bytes()))
  58. return r
  59. }
// MerkleTree is a binary Merkle tree whose nodes are persisted in a
// db.Storage under their node key; each leaf sits on the path given by the
// bits of its entry index (see getPath).
type MerkleTree struct {
	// RWMutex is taken for writing by Add while it updates the tree.
	sync.RWMutex
	// db holds the tree nodes and, under rootNodeValue, the current root key.
	db db.Storage
	// rootKey is the key of the current root node (all zeros for an empty
	// tree).
	rootKey *Hash
	// writable gates Add and addNode; when false they return ErrNotWritable.
	writable bool
	// maxLevels is the maximum depth of the tree and the length of the paths
	// produced by getPath.
	maxLevels int
}
  67. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  68. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  69. v, err := mt.db.Get(rootNodeValue)
  70. if err != nil {
  71. tx, err := mt.db.NewTx()
  72. if err != nil {
  73. return nil, err
  74. }
  75. mt.rootKey = &HashZero
  76. tx.Put(rootNodeValue, mt.rootKey[:])
  77. err = tx.Commit()
  78. if err != nil {
  79. return nil, err
  80. }
  81. return &mt, nil
  82. }
  83. mt.rootKey = &Hash{}
  84. copy(mt.rootKey[:], v)
  85. return &mt, nil
  86. }
  87. func (mt *MerkleTree) Root() *Hash {
  88. return mt.rootKey
  89. }
  90. func (mt *MerkleTree) Add(k, v *big.Int) error {
  91. // verify that the MerkleTree is writable
  92. if !mt.writable {
  93. return ErrNotWritable
  94. }
  95. // verfy that the ElemBytes are valid and fit inside the Finite Field.
  96. if !cryptoUtils.CheckBigIntInField(k) {
  97. return errors.New("Key not inside the Finite Field")
  98. }
  99. if !cryptoUtils.CheckBigIntInField(v) {
  100. return errors.New("Value not inside the Finite Field")
  101. }
  102. tx, err := mt.db.NewTx()
  103. if err != nil {
  104. return err
  105. }
  106. mt.Lock()
  107. defer mt.Unlock()
  108. kHash := NewHashFromBigInt(k)
  109. vHash := NewHashFromBigInt(v)
  110. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  111. path := getPath(mt.maxLevels, kHash[:])
  112. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  113. if err != nil {
  114. return err
  115. }
  116. mt.rootKey = newRootKey
  117. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  118. if err := tx.Commit(); err != nil {
  119. return err
  120. }
  121. return nil
  122. }
  123. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  124. // from newLeaf, at which point both leafs are stored, all while updating the
  125. // path.
  126. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  127. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  128. if lvl > mt.maxLevels-2 {
  129. return nil, ErrReachedMaxLevel
  130. }
  131. var newNodeMiddle *Node
  132. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  133. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  134. if err != nil {
  135. return nil, err
  136. }
  137. if pathNewLeaf[lvl] {
  138. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  139. } else {
  140. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  141. }
  142. return mt.addNode(tx, newNodeMiddle)
  143. } else {
  144. oldLeafKey, err := oldLeaf.Key()
  145. if err != nil {
  146. return nil, err
  147. }
  148. newLeafKey, err := newLeaf.Key()
  149. if err != nil {
  150. return nil, err
  151. }
  152. if pathNewLeaf[lvl] {
  153. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  154. } else {
  155. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  156. }
  157. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  158. _, err = mt.addNode(tx, newLeaf)
  159. if err != nil {
  160. return nil, err
  161. }
  162. return mt.addNode(tx, newNodeMiddle)
  163. }
  164. }
  165. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  166. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  167. lvl int, path []bool) (*Hash, error) {
  168. var err error
  169. var nextKey *Hash
  170. if lvl > mt.maxLevels-1 {
  171. return nil, ErrReachedMaxLevel
  172. }
  173. n, err := mt.GetNode(key)
  174. if err != nil {
  175. return nil, err
  176. }
  177. switch n.Type {
  178. case NodeTypeEmpty:
  179. // We can add newLeaf now
  180. return mt.addNode(tx, newLeaf)
  181. case NodeTypeLeaf:
  182. nKey := n.Entry[0]
  183. // Check if leaf node found contains the leaf node we are trying to add
  184. newLeafKey := newLeaf.Entry[0]
  185. if bytes.Equal(nKey[:], newLeafKey[:]) {
  186. return nil, ErrEntryIndexAlreadyExists
  187. }
  188. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  189. // We need to push newLeaf down until its path diverges from n's path
  190. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  191. case NodeTypeMiddle:
  192. // We need to go deeper, continue traversing the tree, left or right depending on path
  193. var newNodeMiddle *Node
  194. if path[lvl] {
  195. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  196. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  197. } else {
  198. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  199. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  200. }
  201. if err != nil {
  202. return nil, err
  203. }
  204. // Update the node to reflect the modified child
  205. return mt.addNode(tx, newNodeMiddle)
  206. default:
  207. return nil, ErrInvalidNodeFound
  208. }
  209. }
  210. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  211. // they are all the same and assumed to always exist.
  212. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  213. // verify that the MerkleTree is writable
  214. if !mt.writable {
  215. return nil, ErrNotWritable
  216. }
  217. if n.Type == NodeTypeEmpty {
  218. return n.Key()
  219. }
  220. k, err := n.Key()
  221. if err != nil {
  222. return nil, err
  223. }
  224. v := n.Value()
  225. // Check that the node key doesn't already exist
  226. if _, err := tx.Get(k[:]); err == nil {
  227. return nil, ErrNodeKeyAlreadyExists
  228. }
  229. tx.Put(k[:], v)
  230. return k, nil
  231. }
  232. // dbGet is a helper function to get the node of a key from the internal
  233. // storage.
  234. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  235. if bytes.Equal(k, HashZero[:]) {
  236. return 0, nil, nil
  237. }
  238. value, err := mt.db.Get(k)
  239. if err != nil {
  240. return 0, nil, err
  241. }
  242. if len(value) < 2 {
  243. return 0, nil, ErrInvalidDBValue
  244. }
  245. nodeType := value[0]
  246. nodeBytes := value[1:]
  247. return NodeType(nodeType), nodeBytes, nil
  248. }
  249. // dbInsert is a helper function to insert a node into a key in an open db
  250. // transaction.
  251. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  252. v := append([]byte{byte(t)}, data...)
  253. tx.Put(k, v)
  254. }
  255. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  256. // tree; they are all the same and assumed to always exist.
  257. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  258. if bytes.Equal(key[:], HashZero[:]) {
  259. return NewNodeEmpty(), nil
  260. }
  261. nBytes, err := mt.db.Get(key[:])
  262. if err != nil {
  263. return nil, err
  264. }
  265. return NewNodeFromBytes(nBytes)
  266. }
  267. // getPath returns the binary path, from the root to the leaf.
  268. func getPath(numLevels int, k []byte) []bool {
  269. path := make([]bool, numLevels)
  270. for n := 0; n < numLevels; n++ {
  271. path[n] = common.TestBit(k[:], uint(n))
  272. }
  273. return path
  274. }
// NodeAux contains the auxiliary node used in a non-existence proof: the
// leaf that GenerateProof found on the requested key's path in place of
// that key.
type NodeAux struct {
	// Key is the entry index (Entry[0]) of that leaf.
	Key *Hash
	// Value is the entry value (Entry[1]) of that leaf.
	Value *Hash
}
// Proof defines the required elements for a MT proof of existence or non-existence.
type Proof struct {
	// Existence indicates whether this is a proof of existence or
	// non-existence.
	Existence bool
	// depth indicates how deep in the tree the proof goes.
	depth uint
	// notempties is a bitmap of the levels (0..depth-1) whose sibling is
	// non-empty; bits are set with common.SetBitBigEndian and read with
	// common.TestBitBigEndian.
	notempties [ElemBytesLen - proofFlagsLen]byte
	// Siblings is a list of non-empty sibling keys, ordered from the root
	// downwards (GenerateProof appends them by increasing depth).
	Siblings []*Hash
	// NodeAux is set in a non-existence proof whose path ended at a
	// different leaf; it is nil when the path ended at an empty node.
	NodeAux *NodeAux
}
  292. // NewProofFromBytes parses a byte array into a Proof.
  293. func NewProofFromBytes(bs []byte) (*Proof, error) {
  294. if len(bs) < ElemBytesLen {
  295. return nil, ErrInvalidProofBytes
  296. }
  297. p := &Proof{}
  298. if (bs[0] & 0x01) == 0 {
  299. p.Existence = true
  300. }
  301. p.depth = uint(bs[1])
  302. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  303. siblingBytes := bs[ElemBytesLen:]
  304. sibIdx := 0
  305. for i := uint(0); i < p.depth; i++ {
  306. if common.TestBitBigEndian(p.notempties[:], i) {
  307. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  308. return nil, ErrInvalidProofBytes
  309. }
  310. var sib Hash
  311. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  312. p.Siblings = append(p.Siblings, &sib)
  313. sibIdx++
  314. }
  315. }
  316. if !p.Existence && ((bs[0] & 0x02) != 0) {
  317. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  318. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  319. if len(nodeAuxBytes) != 2*ElemBytesLen {
  320. return nil, ErrInvalidProofBytes
  321. }
  322. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  323. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  324. }
  325. return p, nil
  326. }
  327. // Bytes serializes a Proof into a byte array.
  328. func (p *Proof) Bytes() []byte {
  329. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  330. if p.NodeAux != nil {
  331. bsLen += 2 * ElemBytesLen
  332. }
  333. bs := make([]byte, bsLen)
  334. if !p.Existence {
  335. bs[0] |= 0x01
  336. }
  337. bs[1] = byte(p.depth)
  338. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  339. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  340. for i, k := range p.Siblings {
  341. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  342. }
  343. if p.NodeAux != nil {
  344. bs[0] |= 0x02
  345. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  346. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  347. }
  348. return bs
  349. }
  350. // GenerateProof generates the proof of existence (or non-existence) of an
  351. // Entry's hash Index for a Merkle Tree given the root.
  352. // If the rootKey is nil, the current merkletree root is used
  353. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
  354. p := &Proof{}
  355. var siblingKey *Hash
  356. kHash := NewHashFromBigInt(k)
  357. path := getPath(mt.maxLevels, kHash[:])
  358. if rootKey == nil {
  359. rootKey = mt.Root()
  360. }
  361. nextKey := rootKey
  362. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  363. n, err := mt.GetNode(nextKey)
  364. if err != nil {
  365. return nil, err
  366. }
  367. switch n.Type {
  368. case NodeTypeEmpty:
  369. return p, nil
  370. case NodeTypeLeaf:
  371. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  372. p.Existence = true
  373. return p, nil
  374. } else {
  375. // We found a leaf whose entry didn't match hIndex
  376. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  377. return p, nil
  378. }
  379. case NodeTypeMiddle:
  380. if path[p.depth] {
  381. nextKey = n.ChildR
  382. siblingKey = n.ChildL
  383. } else {
  384. nextKey = n.ChildL
  385. siblingKey = n.ChildR
  386. }
  387. default:
  388. return nil, ErrInvalidNodeFound
  389. }
  390. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  391. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  392. p.Siblings = append(p.Siblings, siblingKey)
  393. }
  394. }
  395. return nil, ErrEntryIndexNotFound
  396. }
  397. // VerifyProof verifies the Merkle Proof for the entry and root.
  398. func VerifyProof(rootKey *Hash, proof *Proof, k, v *big.Int) bool {
  399. rootFromProof, err := RootFromProof(proof, k, v)
  400. if err != nil {
  401. return false
  402. }
  403. return bytes.Equal(rootKey[:], rootFromProof[:])
  404. }
  405. // RootFromProof calculates the root that would correspond to a tree whose
  406. // siblings are the ones in the proof with the claim hashing to hIndex and
  407. // hValue.
  408. func RootFromProof(proof *Proof, k, v *big.Int) (*Hash, error) {
  409. kHash := NewHashFromBigInt(k)
  410. vHash := NewHashFromBigInt(v)
  411. sibIdx := len(proof.Siblings) - 1
  412. var err error
  413. var midKey *Hash
  414. if proof.Existence {
  415. midKey, err = LeafKey(kHash, vHash)
  416. if err != nil {
  417. return nil, err
  418. }
  419. } else {
  420. if proof.NodeAux == nil {
  421. midKey = &HashZero
  422. } else {
  423. if bytes.Equal(kHash[:], proof.NodeAux.Key[:]) {
  424. return nil, fmt.Errorf("Non-existence proof being checked against hIndex equal to nodeAux")
  425. }
  426. midKey, err = LeafKey(proof.NodeAux.Key, proof.NodeAux.Value)
  427. if err != nil {
  428. return nil, err
  429. }
  430. }
  431. }
  432. path := getPath(int(proof.depth), kHash[:])
  433. var siblingKey *Hash
  434. for lvl := int(proof.depth) - 1; lvl >= 0; lvl-- {
  435. if common.TestBitBigEndian(proof.notempties[:], uint(lvl)) {
  436. siblingKey = proof.Siblings[sibIdx]
  437. sibIdx--
  438. } else {
  439. siblingKey = &HashZero
  440. }
  441. if path[lvl] {
  442. midKey, err = NewNodeMiddle(siblingKey, midKey).Key()
  443. if err != nil {
  444. return nil, err
  445. }
  446. } else {
  447. midKey, err = NewNodeMiddle(midKey, siblingKey).Key()
  448. if err != nil {
  449. return nil, err
  450. }
  451. }
  452. }
  453. return midKey, nil
  454. }
  455. // walk is a helper recursive function to iterate over all tree branches
  456. func (mt *MerkleTree) walk(key *Hash, f func(*Node)) error {
  457. n, err := mt.GetNode(key)
  458. if err != nil {
  459. return err
  460. }
  461. switch n.Type {
  462. case NodeTypeEmpty:
  463. f(n)
  464. case NodeTypeLeaf:
  465. f(n)
  466. case NodeTypeMiddle:
  467. f(n)
  468. if err := mt.walk(n.ChildL, f); err != nil {
  469. return err
  470. }
  471. if err := mt.walk(n.ChildR, f); err != nil {
  472. return err
  473. }
  474. default:
  475. return ErrInvalidNodeFound
  476. }
  477. return nil
  478. }
  479. // Walk iterates over all the branches of a MerkleTree with the given rootKey
  480. // if rootKey is nil, it will get the current RootKey of the current state of the MerkleTree.
  481. // For each node, it calls the f function given in the parameters.
  482. // See some examples of the Walk function usage in the merkletree_test.go
  483. // test functions: TestMTWalk, TestMTWalkGraphViz, TestMTWalkDumpClaims
  484. func (mt *MerkleTree) Walk(rootKey *Hash, f func(*Node)) error {
  485. if rootKey == nil {
  486. rootKey = mt.Root()
  487. }
  488. err := mt.walk(rootKey, f)
  489. return err
  490. }
  491. // GraphViz uses Walk function to generate a string GraphViz representation of the
  492. // tree and writes it to w
  493. func (mt *MerkleTree) GraphViz(w io.Writer, rootKey *Hash) error {
  494. fmt.Fprintf(w, `digraph hierarchy {
  495. node [fontname=Monospace,fontsize=10,shape=box]
  496. `)
  497. cnt := 0
  498. var errIn error
  499. err := mt.Walk(rootKey, func(n *Node) {
  500. k, err := n.Key()
  501. if err != nil {
  502. errIn = err
  503. }
  504. switch n.Type {
  505. case NodeTypeEmpty:
  506. case NodeTypeLeaf:
  507. fmt.Fprintf(w, "\"%v\" [style=filled];\n", k.BigInt().String())
  508. case NodeTypeMiddle:
  509. lr := [2]string{n.ChildL.BigInt().String(), n.ChildR.BigInt().String()}
  510. for i := range lr {
  511. if lr[i] == "0" {
  512. lr[i] = fmt.Sprintf("empty%v", cnt)
  513. fmt.Fprintf(w, "\"%v\" [style=dashed,label=0];\n", lr[i])
  514. cnt++
  515. }
  516. }
  517. fmt.Fprintf(w, "\"%v\" -> {\"%v\" \"%v\"}\n", k.BigInt().String(), lr[0], lr[1])
  518. default:
  519. }
  520. })
  521. fmt.Fprintf(w, "}\n")
  522. if errIn != nil {
  523. return errIn
  524. }
  525. return err
  526. }
  527. // PrintGraphViz prints directly the GraphViz() output
  528. func (mt *MerkleTree) PrintGraphViz(rootKey *Hash) error {
  529. if rootKey == nil {
  530. rootKey = mt.Root()
  531. }
  532. w := bytes.NewBufferString("")
  533. fmt.Fprintf(w, "--------\nGraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n")
  534. err := mt.GraphViz(w, nil)
  535. if err != nil {
  536. return err
  537. }
  538. fmt.Fprintf(w, "End of GraphViz of the MerkleTree with RootKey "+rootKey.BigInt().String()+"\n--------\n")
  539. fmt.Println(w)
  540. return nil
  541. }