You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

432 lines
12 KiB

  1. package merkletree
  2. import (
  3. "bytes"
  4. "errors"
  5. "math/big"
  6. "sync"
  7. "github.com/iden3/go-iden3-core/common"
  8. "github.com/iden3/go-iden3-core/db"
  9. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  10. )
const (
	// proofFlagsLen is the number of leading bytes of a serialized proof
	// header reserved for flags: byte 0 holds the existence/NodeAux flag bits
	// and byte 1 holds the proof depth. The rest of the ElemBytesLen-byte
	// header (ElemBytesLen - proofFlagsLen bytes) is the notempties bitmap.
	proofFlagsLen = 2
	// ElemBytesLen is the length of the Hash byte array, and also the size of
	// each serialized proof element (header, sibling, NodeAux key and value).
	ElemBytesLen = 32
)
  18. var (
  19. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  20. ErrNodeKeyAlreadyExists = errors.New("node already exists")
  21. // ErrEntryIndexNotFound is used when no entry is found for an index.
  22. ErrEntryIndexNotFound = errors.New("node index not found in the DB")
  23. // ErrNodeDataBadSize is used when the data of a node has an incorrect
  24. // size and can't be parsed.
  25. ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB")
  26. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  27. // maximum level.
  28. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  29. // ErrInvalidNodeFound is used when an invalid node is found and can't
  30. // be parsed.
  31. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  32. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  33. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  34. // ErrInvalidDBValue is used when a value in the key value DB is
  35. // invalid (for example, it doen't contain a byte header and a []byte
  36. // body of at least len=1.
  37. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  38. // ErrEntryIndexAlreadyExists is used when the entry index already
  39. // exists in the tree.
  40. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  41. // ErrNotWritable is used when the MerkleTree is not writable and a write function is called
  42. ErrNotWritable = errors.New("Merkle Tree not writable")
  43. rootNodeValue = []byte("currentroot")
  44. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  45. )
  46. type Hash [32]byte
  47. func (h Hash) String() string {
  48. return new(big.Int).SetBytes(h[:]).String()
  49. }
  50. func (h *Hash) BigInt() *big.Int {
  51. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  52. }
  53. func NewHashFromBigInt(b *big.Int) *Hash {
  54. r := &Hash{}
  55. copy(r[:], common.SwapEndianness(b.Bytes()))
  56. return r
  57. }
// MerkleTree is a binary Merkle tree whose non-empty nodes are persisted in a
// db.Storage under their node keys. Empty nodes are never stored.
type MerkleTree struct {
	// RWMutex is held for writing by Add while it modifies the tree.
	sync.RWMutex
	db        db.Storage // backing key-value storage for nodes and the root key
	rootKey   *Hash      // key of the current root node (HashZero for an empty tree)
	writable  bool       // when false, Add and addNode return ErrNotWritable
	maxLevels int        // maximum depth of the tree; also the length of key paths
}
  65. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  66. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  67. v, err := mt.db.Get(rootNodeValue)
  68. if err != nil {
  69. tx, err := mt.db.NewTx()
  70. if err != nil {
  71. return nil, err
  72. }
  73. mt.rootKey = &HashZero
  74. tx.Put(rootNodeValue, mt.rootKey[:])
  75. err = tx.Commit()
  76. if err != nil {
  77. return nil, err
  78. }
  79. return &mt, nil
  80. }
  81. mt.rootKey = &Hash{}
  82. copy(mt.rootKey[:], v)
  83. return &mt, nil
  84. }
  85. func (mt *MerkleTree) Root() *Hash {
  86. return mt.rootKey
  87. }
  88. func (mt *MerkleTree) Add(k, v *big.Int) error {
  89. // verify that the MerkleTree is writable
  90. if !mt.writable {
  91. return ErrNotWritable
  92. }
  93. // verfy that the ElemBytes are valid and fit inside the Finite Field.
  94. if !cryptoUtils.CheckBigIntInField(k) {
  95. return errors.New("Key not inside the Finite Field")
  96. }
  97. if !cryptoUtils.CheckBigIntInField(v) {
  98. return errors.New("Value not inside the Finite Field")
  99. }
  100. tx, err := mt.db.NewTx()
  101. if err != nil {
  102. return err
  103. }
  104. mt.Lock()
  105. defer mt.Unlock()
  106. kHash := NewHashFromBigInt(k)
  107. vHash := NewHashFromBigInt(v)
  108. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  109. path := getPath(mt.maxLevels, kHash[:])
  110. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  111. if err != nil {
  112. return err
  113. }
  114. mt.rootKey = newRootKey
  115. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  116. if err := tx.Commit(); err != nil {
  117. return err
  118. }
  119. return nil
  120. }
  121. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  122. // from newLeaf, at which point both leafs are stored, all while updating the
  123. // path.
  124. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  125. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  126. if lvl > mt.maxLevels-2 {
  127. return nil, ErrReachedMaxLevel
  128. }
  129. var newNodeMiddle *Node
  130. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  131. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  132. if err != nil {
  133. return nil, err
  134. }
  135. if pathNewLeaf[lvl] {
  136. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  137. } else {
  138. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  139. }
  140. return mt.addNode(tx, newNodeMiddle)
  141. } else {
  142. oldLeafKey, err := oldLeaf.Key()
  143. if err != nil {
  144. return nil, err
  145. }
  146. newLeafKey, err := newLeaf.Key()
  147. if err != nil {
  148. return nil, err
  149. }
  150. if pathNewLeaf[lvl] {
  151. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  152. } else {
  153. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  154. }
  155. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  156. _, err = mt.addNode(tx, newLeaf)
  157. if err != nil {
  158. return nil, err
  159. }
  160. return mt.addNode(tx, newNodeMiddle)
  161. }
  162. }
  163. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  164. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  165. lvl int, path []bool) (*Hash, error) {
  166. var err error
  167. var nextKey *Hash
  168. if lvl > mt.maxLevels-1 {
  169. return nil, ErrReachedMaxLevel
  170. }
  171. n, err := mt.GetNode(key)
  172. if err != nil {
  173. return nil, err
  174. }
  175. switch n.Type {
  176. case NodeTypeEmpty:
  177. // We can add newLeaf now
  178. return mt.addNode(tx, newLeaf)
  179. case NodeTypeLeaf:
  180. nKey := n.Entry[0]
  181. // Check if leaf node found contains the leaf node we are trying to add
  182. newLeafKey := newLeaf.Entry[0]
  183. if bytes.Equal(nKey[:], newLeafKey[:]) {
  184. return nil, ErrEntryIndexAlreadyExists
  185. }
  186. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  187. // We need to push newLeaf down until its path diverges from n's path
  188. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  189. case NodeTypeMiddle:
  190. // We need to go deeper, continue traversing the tree, left or right depending on path
  191. var newNodeMiddle *Node
  192. if path[lvl] {
  193. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  194. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  195. } else {
  196. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  197. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  198. }
  199. if err != nil {
  200. return nil, err
  201. }
  202. // Update the node to reflect the modified child
  203. return mt.addNode(tx, newNodeMiddle)
  204. default:
  205. return nil, ErrInvalidNodeFound
  206. }
  207. }
  208. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  209. // they are all the same and assumed to always exist.
  210. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  211. // verify that the MerkleTree is writable
  212. if !mt.writable {
  213. return nil, ErrNotWritable
  214. }
  215. if n.Type == NodeTypeEmpty {
  216. return n.Key()
  217. }
  218. k, err := n.Key()
  219. if err != nil {
  220. return nil, err
  221. }
  222. v := n.Value()
  223. // Check that the node key doesn't already exist
  224. if _, err := tx.Get(k[:]); err == nil {
  225. return nil, ErrNodeKeyAlreadyExists
  226. }
  227. tx.Put(k[:], v)
  228. return k, nil
  229. }
  230. // dbGet is a helper function to get the node of a key from the internal
  231. // storage.
  232. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  233. if bytes.Equal(k, HashZero[:]) {
  234. return 0, nil, nil
  235. }
  236. value, err := mt.db.Get(k)
  237. if err != nil {
  238. return 0, nil, err
  239. }
  240. if len(value) < 2 {
  241. return 0, nil, ErrInvalidDBValue
  242. }
  243. nodeType := value[0]
  244. nodeBytes := value[1:]
  245. return NodeType(nodeType), nodeBytes, nil
  246. }
  247. // dbInsert is a helper function to insert a node into a key in an open db
  248. // transaction.
  249. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  250. v := append([]byte{byte(t)}, data...)
  251. tx.Put(k, v)
  252. }
  253. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  254. // tree; they are all the same and assumed to always exist.
  255. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  256. if bytes.Equal(key[:], HashZero[:]) {
  257. return NewNodeEmpty(), nil
  258. }
  259. nBytes, err := mt.db.Get(key[:])
  260. if err != nil {
  261. return nil, err
  262. }
  263. return NewNodeFromBytes(nBytes)
  264. }
  265. // getPath returns the binary path, from the root to the leaf.
  266. func getPath(numLevels int, k []byte) []bool {
  267. path := make([]bool, numLevels)
  268. for n := 0; n < numLevels; n++ {
  269. path[n] = common.TestBit(k[:], uint(n))
  270. }
  271. return path
  272. }
// NodeAux contains the auxiliary node used in a non-existence proof: the leaf
// found at the end of the requested key's path, which holds a different key.
type NodeAux struct {
	Key   *Hash // key (Entry[0]) of that leaf
	Value *Hash // value (Entry[1]) of that leaf
}
// Proof defines the required elements for a MT proof of existence or non-existence.
type Proof struct {
	// Existence indicates whether this is a proof of existence (true) or of
	// non-existence (false).
	Existence bool
	// depth is the number of levels traversed from the root before reaching
	// the node that terminates the proof.
	depth uint
	// notempties is a bitmap with one bit per level, set via
	// common.SetBitBigEndian, marking the levels whose sibling is non-empty
	// and therefore present in Siblings.
	notempties [ElemBytesLen - proofFlagsLen]byte
	// Siblings is the list of non-empty sibling keys, ordered from the root down.
	Siblings []*Hash
	// NodeAux, when non-nil, is the leaf with a different key that was found
	// at the end of the path, for a proof of non-existence.
	NodeAux *NodeAux
}
  290. // NewProofFromBytes parses a byte array into a Proof.
  291. func NewProofFromBytes(bs []byte) (*Proof, error) {
  292. if len(bs) < ElemBytesLen {
  293. return nil, ErrInvalidProofBytes
  294. }
  295. p := &Proof{}
  296. if (bs[0] & 0x01) == 0 {
  297. p.Existence = true
  298. }
  299. p.depth = uint(bs[1])
  300. copy(p.notempties[:], bs[proofFlagsLen:ElemBytesLen])
  301. siblingBytes := bs[ElemBytesLen:]
  302. sibIdx := 0
  303. for i := uint(0); i < p.depth; i++ {
  304. if common.TestBitBigEndian(p.notempties[:], i) {
  305. if len(siblingBytes) < (sibIdx+1)*ElemBytesLen {
  306. return nil, ErrInvalidProofBytes
  307. }
  308. var sib Hash
  309. copy(sib[:], siblingBytes[sibIdx*ElemBytesLen:(sibIdx+1)*ElemBytesLen])
  310. p.Siblings = append(p.Siblings, &sib)
  311. sibIdx++
  312. }
  313. }
  314. if !p.Existence && ((bs[0] & 0x02) != 0) {
  315. p.NodeAux = &NodeAux{Key: &Hash{}, Value: &Hash{}}
  316. nodeAuxBytes := siblingBytes[len(p.Siblings)*ElemBytesLen:]
  317. if len(nodeAuxBytes) != 2*ElemBytesLen {
  318. return nil, ErrInvalidProofBytes
  319. }
  320. copy(p.NodeAux.Key[:], nodeAuxBytes[:ElemBytesLen])
  321. copy(p.NodeAux.Value[:], nodeAuxBytes[ElemBytesLen:2*ElemBytesLen])
  322. }
  323. return p, nil
  324. }
  325. // Bytes serializes a Proof into a byte array.
  326. func (p *Proof) Bytes() []byte {
  327. bsLen := proofFlagsLen + len(p.notempties) + ElemBytesLen*len(p.Siblings)
  328. if p.NodeAux != nil {
  329. bsLen += 2 * ElemBytesLen
  330. }
  331. bs := make([]byte, bsLen)
  332. if !p.Existence {
  333. bs[0] |= 0x01
  334. }
  335. bs[1] = byte(p.depth)
  336. copy(bs[proofFlagsLen:len(p.notempties)+proofFlagsLen], p.notempties[:])
  337. siblingsBytes := bs[len(p.notempties)+proofFlagsLen:]
  338. for i, k := range p.Siblings {
  339. copy(siblingsBytes[i*ElemBytesLen:(i+1)*ElemBytesLen], k[:])
  340. }
  341. if p.NodeAux != nil {
  342. bs[0] |= 0x02
  343. copy(bs[len(bs)-2*ElemBytesLen:], p.NodeAux.Key[:])
  344. copy(bs[len(bs)-1*ElemBytesLen:], p.NodeAux.Value[:])
  345. }
  346. return bs
  347. }
  348. // GenerateProof generates the proof of existence (or non-existence) of an
  349. // Entry's hash Index for a Merkle Tree given the root.
  350. // If the rootKey is nil, the current merkletree root is used
  351. func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
  352. p := &Proof{}
  353. var siblingKey *Hash
  354. kHash := NewHashFromBigInt(k)
  355. path := getPath(mt.maxLevels, kHash[:])
  356. if rootKey == nil {
  357. rootKey = mt.Root()
  358. }
  359. nextKey := rootKey
  360. for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
  361. n, err := mt.GetNode(nextKey)
  362. if err != nil {
  363. return nil, err
  364. }
  365. switch n.Type {
  366. case NodeTypeEmpty:
  367. return p, nil
  368. case NodeTypeLeaf:
  369. if bytes.Equal(kHash[:], n.Entry[0][:]) {
  370. p.Existence = true
  371. return p, nil
  372. } else {
  373. // We found a leaf whose entry didn't match hIndex
  374. p.NodeAux = &NodeAux{Key: n.Entry[0], Value: n.Entry[1]}
  375. return p, nil
  376. }
  377. case NodeTypeMiddle:
  378. if path[p.depth] {
  379. nextKey = n.ChildR
  380. siblingKey = n.ChildL
  381. } else {
  382. nextKey = n.ChildL
  383. siblingKey = n.ChildR
  384. }
  385. default:
  386. return nil, ErrInvalidNodeFound
  387. }
  388. if !bytes.Equal(siblingKey[:], HashZero[:]) {
  389. common.SetBitBigEndian(p.notempties[:], uint(p.depth))
  390. p.Siblings = append(p.Siblings, siblingKey)
  391. }
  392. }
  393. return nil, ErrEntryIndexNotFound
  394. }