You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

302 lines
8.3 KiB

  1. package merkletree
  2. import (
  3. "bytes"
  4. "errors"
  5. "math/big"
  6. "sync"
  7. "github.com/iden3/go-iden3-core/common"
  8. "github.com/iden3/go-iden3-core/db"
  9. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  10. )
  // Fixed byte sizes used by the tree's serialization.
  const (
  	// ElemBytesLen is the length of the Hash byte array.
  	ElemBytesLen = 32
  	// proofFlagsLen is the byte length of the flags stored in the proof
  	// header (the first 32 bytes of a serialized proof).
  	proofFlagsLen = 2
  )
  18. var (
  19. // ErrNodeKeyAlreadyExists is used when a node key already exists.
  20. ErrNodeKeyAlreadyExists = errors.New("node already exists")
  21. // ErrEntryIndexNotFound is used when no entry is found for an index.
  22. ErrEntryIndexNotFound = errors.New("node index not found in the DB")
  23. // ErrNodeDataBadSize is used when the data of a node has an incorrect
  24. // size and can't be parsed.
  25. ErrNodeDataBadSize = errors.New("node data has incorrect size in the DB")
  26. // ErrReachedMaxLevel is used when a traversal of the MT reaches the
  27. // maximum level.
  28. ErrReachedMaxLevel = errors.New("reached maximum level of the merkle tree")
  29. // ErrInvalidNodeFound is used when an invalid node is found and can't
  30. // be parsed.
  31. ErrInvalidNodeFound = errors.New("found an invalid node in the DB")
  32. // ErrInvalidProofBytes is used when a serialized proof is invalid.
  33. ErrInvalidProofBytes = errors.New("the serialized proof is invalid")
  34. // ErrInvalidDBValue is used when a value in the key value DB is
  35. // invalid (for example, it doen't contain a byte header and a []byte
  36. // body of at least len=1.
  37. ErrInvalidDBValue = errors.New("the value in the DB is invalid")
  38. // ErrEntryIndexAlreadyExists is used when the entry index already
  39. // exists in the tree.
  40. ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
  41. // ErrNotWritable is used when the MerkleTree is not writable and a write function is called
  42. ErrNotWritable = errors.New("Merkle Tree not writable")
  43. rootNodeValue = []byte("currentroot")
  44. HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  45. )
  46. type Hash [32]byte
  47. func (h Hash) String() string {
  48. return new(big.Int).SetBytes(h[:]).String()
  49. }
  50. func (h *Hash) BigInt() *big.Int {
  51. return new(big.Int).SetBytes(common.SwapEndianness(h[:]))
  52. }
  53. func NewHashFromBigInt(b *big.Int) *Hash {
  54. r := &Hash{}
  55. copy(r[:], common.SwapEndianness(b.Bytes()))
  56. return r
  57. }
  58. type MerkleTree struct {
  59. sync.RWMutex
  60. db db.Storage
  61. rootKey *Hash
  62. writable bool
  63. maxLevels int
  64. }
  65. func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
  66. mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
  67. v, err := mt.db.Get(rootNodeValue)
  68. if err != nil {
  69. tx, err := mt.db.NewTx()
  70. if err != nil {
  71. return nil, err
  72. }
  73. mt.rootKey = &HashZero
  74. tx.Put(rootNodeValue, mt.rootKey[:])
  75. err = tx.Commit()
  76. if err != nil {
  77. return nil, err
  78. }
  79. return &mt, nil
  80. }
  81. mt.rootKey = &Hash{}
  82. copy(mt.rootKey[:], v)
  83. return &mt, nil
  84. }
  85. func (mt *MerkleTree) Root() *Hash {
  86. return mt.rootKey
  87. }
  88. func (mt *MerkleTree) Add(k, v *big.Int) error {
  89. // verify that the MerkleTree is writable
  90. if !mt.writable {
  91. return ErrNotWritable
  92. }
  93. // verfy that the ElemBytes are valid and fit inside the Finite Field.
  94. if !cryptoUtils.CheckBigIntInField(k) {
  95. return errors.New("Key not inside the Finite Field")
  96. }
  97. if !cryptoUtils.CheckBigIntInField(v) {
  98. return errors.New("Value not inside the Finite Field")
  99. }
  100. tx, err := mt.db.NewTx()
  101. if err != nil {
  102. return err
  103. }
  104. mt.Lock()
  105. defer mt.Unlock()
  106. kHash := NewHashFromBigInt(k)
  107. vHash := NewHashFromBigInt(v)
  108. newNodeLeaf := NewNodeLeaf(kHash, vHash)
  109. path := getPath(mt.maxLevels, kHash[:])
  110. newRootKey, err := mt.addLeaf(tx, newNodeLeaf, mt.rootKey, 0, path)
  111. if err != nil {
  112. return err
  113. }
  114. mt.rootKey = newRootKey
  115. mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
  116. if err := tx.Commit(); err != nil {
  117. return err
  118. }
  119. return nil
  120. }
  121. // pushLeaf recursively pushes an existing oldLeaf down until its path diverges
  122. // from newLeaf, at which point both leafs are stored, all while updating the
  123. // path.
  124. func (mt *MerkleTree) pushLeaf(tx db.Tx, newLeaf *Node, oldLeaf *Node,
  125. lvl int, pathNewLeaf []bool, pathOldLeaf []bool) (*Hash, error) {
  126. if lvl > mt.maxLevels-2 {
  127. return nil, ErrReachedMaxLevel
  128. }
  129. var newNodeMiddle *Node
  130. if pathNewLeaf[lvl] == pathOldLeaf[lvl] { // We need to go deeper!
  131. nextKey, err := mt.pushLeaf(tx, newLeaf, oldLeaf, lvl+1, pathNewLeaf, pathOldLeaf)
  132. if err != nil {
  133. return nil, err
  134. }
  135. if pathNewLeaf[lvl] {
  136. newNodeMiddle = NewNodeMiddle(&HashZero, nextKey) // go right
  137. } else {
  138. newNodeMiddle = NewNodeMiddle(nextKey, &HashZero) // go left
  139. }
  140. return mt.addNode(tx, newNodeMiddle)
  141. } else {
  142. oldLeafKey, err := oldLeaf.Key()
  143. if err != nil {
  144. return nil, err
  145. }
  146. newLeafKey, err := newLeaf.Key()
  147. if err != nil {
  148. return nil, err
  149. }
  150. if pathNewLeaf[lvl] {
  151. newNodeMiddle = NewNodeMiddle(oldLeafKey, newLeafKey)
  152. } else {
  153. newNodeMiddle = NewNodeMiddle(newLeafKey, oldLeafKey)
  154. }
  155. // We can add newLeaf now. We don't need to add oldLeaf because it's already in the tree.
  156. _, err = mt.addNode(tx, newLeaf)
  157. if err != nil {
  158. return nil, err
  159. }
  160. return mt.addNode(tx, newNodeMiddle)
  161. }
  162. }
  163. // addLeaf recursively adds a newLeaf in the MT while updating the path.
  164. func (mt *MerkleTree) addLeaf(tx db.Tx, newLeaf *Node, key *Hash,
  165. lvl int, path []bool) (*Hash, error) {
  166. var err error
  167. var nextKey *Hash
  168. if lvl > mt.maxLevels-1 {
  169. return nil, ErrReachedMaxLevel
  170. }
  171. n, err := mt.GetNode(key)
  172. if err != nil {
  173. return nil, err
  174. }
  175. switch n.Type {
  176. case NodeTypeEmpty:
  177. // We can add newLeaf now
  178. return mt.addNode(tx, newLeaf)
  179. case NodeTypeLeaf:
  180. nKey := n.Entry[0]
  181. // Check if leaf node found contains the leaf node we are trying to add
  182. newLeafKey := newLeaf.Entry[0]
  183. if bytes.Equal(nKey[:], newLeafKey[:]) {
  184. return nil, ErrEntryIndexAlreadyExists
  185. }
  186. pathOldLeaf := getPath(mt.maxLevels, nKey[:])
  187. // We need to push newLeaf down until its path diverges from n's path
  188. return mt.pushLeaf(tx, newLeaf, n, lvl, path, pathOldLeaf)
  189. case NodeTypeMiddle:
  190. // We need to go deeper, continue traversing the tree, left or right depending on path
  191. var newNodeMiddle *Node
  192. if path[lvl] {
  193. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildR, lvl+1, path) // go right
  194. newNodeMiddle = NewNodeMiddle(n.ChildL, nextKey)
  195. } else {
  196. nextKey, err = mt.addLeaf(tx, newLeaf, n.ChildL, lvl+1, path) // go left
  197. newNodeMiddle = NewNodeMiddle(nextKey, n.ChildR)
  198. }
  199. if err != nil {
  200. return nil, err
  201. }
  202. // Update the node to reflect the modified child
  203. return mt.addNode(tx, newNodeMiddle)
  204. default:
  205. return nil, ErrInvalidNodeFound
  206. }
  207. }
  208. // addNode adds a node into the MT. Empty nodes are not stored in the tree;
  209. // they are all the same and assumed to always exist.
  210. func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
  211. // verify that the MerkleTree is writable
  212. if !mt.writable {
  213. return nil, ErrNotWritable
  214. }
  215. if n.Type == NodeTypeEmpty {
  216. return n.Key()
  217. }
  218. k, err := n.Key()
  219. if err != nil {
  220. return nil, err
  221. }
  222. v := n.Value()
  223. // Check that the node key doesn't already exist
  224. if _, err := tx.Get(k[:]); err == nil {
  225. return nil, ErrNodeKeyAlreadyExists
  226. }
  227. tx.Put(k[:], v)
  228. return k, nil
  229. }
  230. // dbGet is a helper function to get the node of a key from the internal
  231. // storage.
  232. func (mt *MerkleTree) dbGet(k []byte) (NodeType, []byte, error) {
  233. if bytes.Equal(k, HashZero[:]) {
  234. return 0, nil, nil
  235. }
  236. value, err := mt.db.Get(k)
  237. if err != nil {
  238. return 0, nil, err
  239. }
  240. if len(value) < 2 {
  241. return 0, nil, ErrInvalidDBValue
  242. }
  243. nodeType := value[0]
  244. nodeBytes := value[1:]
  245. return NodeType(nodeType), nodeBytes, nil
  246. }
  247. // dbInsert is a helper function to insert a node into a key in an open db
  248. // transaction.
  249. func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) {
  250. v := append([]byte{byte(t)}, data...)
  251. tx.Put(k, v)
  252. }
  253. // GetNode gets a node by key from the MT. Empty nodes are not stored in the
  254. // tree; they are all the same and assumed to always exist.
  255. func (mt *MerkleTree) GetNode(key *Hash) (*Node, error) {
  256. if bytes.Equal(key[:], HashZero[:]) {
  257. return NewNodeEmpty(), nil
  258. }
  259. nBytes, err := mt.db.Get(key[:])
  260. if err != nil {
  261. return nil, err
  262. }
  263. return NewNodeFromBytes(nBytes)
  264. }
  265. // getPath returns the binary path, from the root to the leaf.
  266. func getPath(numLevels int, k []byte) []bool {
  267. path := make([]bool, numLevels)
  268. for n := 0; n < numLevels; n++ {
  269. path[n] = common.TestBit(k[:], uint(n))
  270. }
  271. return path
  272. }