You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

352 lines
10 KiB

6 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "errors"
  5. "github.com/syndtr/goleveldb/leveldb"
  6. )
  7. const (
  8. // EmptyNodeType indicates the type of an EmptyNodeValue Node
  9. EmptyNodeType = 00
  10. // NormalNodeType indicates the type of a middle Node
  11. normalNodeType = 01
  12. // FinalNodeType indicates the type of middle Node that is in an optimized branch, then in the value contains the value of the final leaf node of that branch
  13. finalNodeType = 02
  14. // ValueNodeType indicates the type of a value Node
  15. valueNodeType = 03
  16. // RootNodeType indicates the type of a root Node
  17. rootNodeType = 04
  18. )
  19. var (
  20. // ErrNodeAlreadyExists is an error that indicates that a node already exists in the merkletree database
  21. ErrNodeAlreadyExists = errors.New("node already exists")
  22. rootNodeValue = HashBytes([]byte("root"))
  23. // EmptyNodeValue is a [32]byte EmptyNodeValue array, all to zero
  24. EmptyNodeValue = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
  25. )
  26. // Hash used in this tree, is the [32]byte keccak()
  27. type Hash [32]byte
  28. // Value is the interface of a generic leaf, a key value object stored in the leveldb
  29. type Value interface {
  30. IndexLength() uint32 // returns the index length value
  31. Bytes() []byte // returns the value in byte array representation
  32. }
  33. //MerkleTree struct with the main elements of the Merkle Tree
  34. type MerkleTree struct {
  35. // sync.RWMutex
  36. storage *leveldb.DB
  37. root Hash
  38. numLevels int // Height of the Merkle Tree, number of levels
  39. }
  40. // New generates a new Merkle Tree
  41. func New(storage *leveldb.DB, numLevels int) (*MerkleTree, error) {
  42. var mt MerkleTree
  43. mt.storage = storage
  44. mt.numLevels = numLevels
  45. var err error
  46. _, _, rootHash, err := mt.Get(rootNodeValue)
  47. if err != nil {
  48. mt.root = EmptyNodeValue
  49. err = mt.Insert(rootNodeValue, rootNodeType, 0, mt.root[:])
  50. if err != nil {
  51. return nil, err
  52. }
  53. }
  54. copy(mt.root[:], rootHash)
  55. return &mt, nil
  56. }
  57. // Root returns the merkletree.Root
  58. func (mt *MerkleTree) Root() Hash {
  59. return mt.root
  60. }
  61. // NumLevels returns the merkletree.NumLevels
  62. func (mt *MerkleTree) NumLevels() int {
  63. return mt.numLevels
  64. }
  65. // Add adds the leaf to the MT
  66. func (mt *MerkleTree) Add(v Value) error {
  67. // add the leaf that we are adding
  68. mt.Insert(HashBytes(v.Bytes()), valueNodeType, v.IndexLength(), v.Bytes())
  69. hi := HashBytes(v.Bytes()[:v.IndexLength()])
  70. path := getPath(mt.numLevels, hi)
  71. nodeHash := mt.root
  72. var siblings []Hash
  73. for i := mt.numLevels - 2; i >= 0; i-- {
  74. nodeType, indexLength, nodeBytes, err := mt.Get(nodeHash)
  75. if err != nil {
  76. return err
  77. }
  78. if nodeType == byte(finalNodeType) {
  79. hiChild := HashBytes(nodeBytes[:indexLength])
  80. pathChild := getPath(mt.numLevels, hiChild)
  81. posDiff := comparePaths(pathChild, path)
  82. if posDiff == -1 {
  83. return ErrNodeAlreadyExists
  84. }
  85. finalNode1Hash := calcHashFromLeafAndLevel(posDiff, pathChild, HashBytes(nodeBytes))
  86. mt.Insert(finalNode1Hash, finalNodeType, indexLength, nodeBytes)
  87. finalNode2Hash := calcHashFromLeafAndLevel(posDiff, path, HashBytes(v.Bytes()))
  88. mt.Insert(finalNode2Hash, finalNodeType, v.IndexLength(), v.Bytes())
  89. // now the parent
  90. var parentNode treeNode
  91. if path[posDiff] {
  92. parentNode = treeNode{
  93. ChildL: finalNode1Hash,
  94. ChildR: finalNode2Hash,
  95. }
  96. } else {
  97. parentNode = treeNode{
  98. ChildL: finalNode2Hash,
  99. ChildR: finalNode1Hash,
  100. }
  101. }
  102. siblings = append(siblings, getEmptiesBetweenIAndPosHash(mt, i, posDiff+1)...)
  103. if mt.root, err = mt.replaceLeaf(siblings, path[posDiff+1:], parentNode.Ht(), normalNodeType, 0, parentNode.Bytes()); err != nil {
  104. return err
  105. }
  106. mt.Insert(rootNodeValue, rootNodeType, 0, mt.root[:])
  107. return nil
  108. }
  109. node := parseNodeBytes(nodeBytes)
  110. var sibling Hash
  111. if !path[i] {
  112. nodeHash = node.ChildL
  113. sibling = node.ChildR
  114. } else {
  115. nodeHash = node.ChildR
  116. sibling = node.ChildL
  117. }
  118. siblings = append(siblings, sibling)
  119. if bytes.Equal(nodeHash[:], EmptyNodeValue[:]) {
  120. // if the node is EmptyNodeValue, the leaf data will go directly at that height, as a Final Node
  121. if i == mt.numLevels-2 && bytes.Equal(siblings[len(siblings)-1][:], EmptyNodeValue[:]) {
  122. // if the pt node is the unique in the tree, just put it into the root node
  123. // this means to be in i==mt.NumLevels-2 && nodeHash==EmptyNodeValue
  124. finalNodeHash := calcHashFromLeafAndLevel(i+1, path, HashBytes(v.Bytes()))
  125. mt.Insert(finalNodeHash, finalNodeType, v.IndexLength(), v.Bytes())
  126. mt.root = finalNodeHash
  127. mt.Insert(rootNodeValue, rootNodeType, 0, mt.root[:])
  128. return nil
  129. }
  130. finalNodeHash := calcHashFromLeafAndLevel(i, path, HashBytes(v.Bytes()))
  131. if mt.root, err = mt.replaceLeaf(siblings, path[i:], finalNodeHash, finalNodeType, v.IndexLength(), v.Bytes()); err != nil {
  132. return err
  133. }
  134. mt.Insert(rootNodeValue, rootNodeType, 0, mt.root[:])
  135. return nil
  136. }
  137. }
  138. var err error
  139. mt.root, err = mt.replaceLeaf(siblings, path, HashBytes(v.Bytes()), valueNodeType, v.IndexLength(), v.Bytes())
  140. if err != nil {
  141. return err
  142. }
  143. mt.Insert(rootNodeValue, rootNodeType, 0, mt.root[:])
  144. return nil
  145. }
  146. // GenerateProof generates the Merkle Proof from a given leafHash for the current root
  147. func (mt *MerkleTree) GenerateProof(hi Hash) ([]byte, error) {
  148. var empties [32]byte
  149. path := getPath(mt.numLevels, hi)
  150. var siblings []Hash
  151. nodeHash := mt.root
  152. for level := 0; level < mt.numLevels-1; level++ {
  153. nodeType, indexLength, nodeBytes, err := mt.Get(nodeHash)
  154. if err != nil {
  155. return nil, err
  156. }
  157. if nodeType == byte(finalNodeType) {
  158. realValueInPos, err := mt.GetValueInPos(hi)
  159. if err != nil {
  160. return nil, err
  161. }
  162. if bytes.Equal(realValueInPos[:], EmptyNodeValue[:]) {
  163. // go until the path is different, then get the nodes between this FinalNode and the node in the diffPath, they will be the siblings of the merkle proof
  164. leafHi := HashBytes(nodeBytes[:indexLength]) // hi of element that was in the end of the branch (the finalNode)
  165. pathChild := getPath(mt.numLevels, leafHi)
  166. // get the position where the path is different
  167. posDiff := comparePaths(pathChild, path)
  168. if posDiff == -1 {
  169. return nil, ErrNodeAlreadyExists
  170. }
  171. if posDiff != mt.NumLevels()-1-level {
  172. sibling := calcHashFromLeafAndLevel(posDiff, pathChild, HashBytes(nodeBytes))
  173. setbitmap(empties[:], uint(mt.NumLevels()-2-posDiff))
  174. siblings = append([]Hash{sibling}, siblings...)
  175. }
  176. }
  177. break
  178. }
  179. node := parseNodeBytes(nodeBytes)
  180. var sibling Hash
  181. if !path[mt.numLevels-level-2] {
  182. nodeHash = node.ChildL
  183. sibling = node.ChildR
  184. } else {
  185. nodeHash = node.ChildR
  186. sibling = node.ChildL
  187. }
  188. if !bytes.Equal(sibling[:], EmptyNodeValue[:]) {
  189. setbitmap(empties[:], uint(level))
  190. siblings = append([]Hash{sibling}, siblings...)
  191. }
  192. }
  193. // merge empties and siblings
  194. var mp []byte
  195. mp = append(mp, empties[:]...)
  196. for k := range siblings {
  197. mp = append(mp, siblings[k][:]...)
  198. }
  199. return mp, nil
  200. }
  201. // GetValueInPos returns the merkletree value in the position of the Hash of the Index (Hi)
  202. func (mt *MerkleTree) GetValueInPos(hi Hash) ([]byte, error) {
  203. path := getPath(mt.numLevels, hi)
  204. nodeHash := mt.root
  205. for i := mt.numLevels - 2; i >= 0; i-- {
  206. nodeType, indexLength, nodeBytes, err := mt.Get(nodeHash)
  207. if err != nil {
  208. return nodeBytes, err
  209. }
  210. if nodeType == byte(finalNodeType) {
  211. // check if nodeBytes path is different of hi
  212. index := nodeBytes[:indexLength]
  213. hi := HashBytes(index)
  214. nodePath := getPath(mt.numLevels, hi)
  215. posDiff := comparePaths(path, nodePath)
  216. // if is different, return an EmptyNodeValue, else return the nodeBytes
  217. if posDiff != -1 {
  218. return EmptyNodeValue[:], nil
  219. }
  220. return nodeBytes, nil
  221. }
  222. node := parseNodeBytes(nodeBytes)
  223. if !path[i] {
  224. nodeHash = node.ChildL
  225. } else {
  226. nodeHash = node.ChildR
  227. }
  228. }
  229. _, _, valueBytes, err := mt.Get(nodeHash)
  230. if err != nil {
  231. return valueBytes, err
  232. }
  233. return valueBytes, nil
  234. }
  235. func calcHashFromLeafAndLevel(untilLevel int, path []bool, leafHash Hash) Hash {
  236. nodeCurrLevel := leafHash
  237. for i := 0; i < untilLevel; i++ {
  238. if path[i] {
  239. node := treeNode{
  240. ChildL: EmptyNodeValue,
  241. ChildR: nodeCurrLevel,
  242. }
  243. nodeCurrLevel = node.Ht()
  244. } else {
  245. node := treeNode{
  246. ChildL: nodeCurrLevel,
  247. ChildR: EmptyNodeValue,
  248. }
  249. nodeCurrLevel = node.Ht()
  250. }
  251. }
  252. return nodeCurrLevel
  253. }
  254. func (mt *MerkleTree) replaceLeaf(siblings []Hash, path []bool, newLeafHash Hash, nodetype byte, indexLength uint32, newLeafValue []byte) (Hash, error) {
  255. // add the new leaf
  256. mt.Insert(newLeafHash, nodetype, indexLength, newLeafValue)
  257. currNode := newLeafHash
  258. // here the path is only the path[posDiff+1]
  259. for i := 0; i < len(siblings); i++ {
  260. if !path[i] {
  261. node := treeNode{
  262. ChildL: currNode,
  263. ChildR: siblings[len(siblings)-1-i],
  264. }
  265. mt.Insert(node.Ht(), normalNodeType, 0, node.Bytes())
  266. currNode = node.Ht()
  267. } else {
  268. node := treeNode{
  269. ChildL: siblings[len(siblings)-1-i],
  270. ChildR: currNode,
  271. }
  272. mt.Insert(node.Ht(), normalNodeType, 0, node.Bytes())
  273. currNode = node.Ht()
  274. }
  275. }
  276. return currNode, nil // currNode = root
  277. }
  278. // CheckProof validates the Merkle Proof for the leafHash and root
  279. func CheckProof(root Hash, proof []byte, hi Hash, ht Hash, numLevels int) bool {
  280. var empties [32]byte
  281. copy(empties[:], proof[:len(empties)])
  282. hashLen := len(EmptyNodeValue)
  283. var siblings []Hash
  284. for i := len(empties); i < len(proof); i += hashLen {
  285. var siblingHash Hash
  286. copy(siblingHash[:], proof[i:i+hashLen])
  287. siblings = append(siblings, siblingHash)
  288. }
  289. path := getPath(numLevels, hi)
  290. nodeHash := ht
  291. siblingUsedPos := 0
  292. for level := numLevels - 2; level >= 0; level-- {
  293. var sibling Hash
  294. if testbitmap(empties[:], uint(level)) {
  295. sibling = siblings[siblingUsedPos]
  296. siblingUsedPos++
  297. } else {
  298. sibling = EmptyNodeValue
  299. }
  300. // calculate the nodeHash with the current nodeHash and the sibling
  301. var node treeNode
  302. if path[numLevels-level-2] {
  303. node = treeNode{
  304. ChildL: sibling,
  305. ChildR: nodeHash,
  306. }
  307. } else {
  308. node = treeNode{
  309. ChildL: nodeHash,
  310. ChildR: sibling,
  311. }
  312. }
  313. // if both childs are EmptyNodeValue, the parent will be EmptyNodeValue
  314. if bytes.Equal(nodeHash[:], EmptyNodeValue[:]) && bytes.Equal(sibling[:], EmptyNodeValue[:]) {
  315. nodeHash = EmptyNodeValue
  316. } else {
  317. nodeHash = node.Ht()
  318. }
  319. }
  320. return bytes.Equal(nodeHash[:], root[:])
  321. }