You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

415 lines
11 KiB

  1. package asmtree
  2. import (
  3. "bytes"
  4. "fmt"
  5. "path"
  6. "sync/atomic"
  7. "time"
  8. "go.vocdoni.io/dvote/censustree"
  9. "go.vocdoni.io/dvote/log"
  10. "git.sr.ht/~sircmpwn/go-bare"
  11. "github.com/p4u/asmt/db"
  12. asmt "github.com/p4u/asmt/smt"
  13. )
  14. // We use go-bare for export/import the trie. In order to support
  15. // big census (up to 8 Million entries) we need to increase the maximums.
  16. const bareMaxArrayLength uint64 = 1024 * 1014 * 8 // 8 Million
  17. const bareMaxUnmarshalBytes uint64 = bareMaxArrayLength * 32 * 2 // 512 MiB
// Tree wraps an asmt trie together with its backing database and the
// bookkeeping needed to publish it, track access time and take snapshots.
type Tree struct {
	// Tree is the underlying sparse merkle trie.
	Tree *asmt.Trie
	// db is the database the trie was opened on; it also stores the last
	// committed root. It is left unset on trees built by Snapshot.
	db db.DB
	// public is 1 when the tree is published and 0 otherwise; accessed via sync/atomic.
	public uint32
	lastAccessUnix int64 // a unix timestamp, used via sync/atomic
	// size caches the leaf count computed by count(); 0 means "not cached".
	size uint64
	snapshotRoot []byte // if not nil, this trie is considered an immutable snapshot
	// snapshotSize is the leaf count reported by Size for a snapshot.
	snapshotSize uint64
}
// Proof is the bare-serialized merkle proof produced by GenProof and
// consumed by CheckProof. The fields carry the results of the trie's
// compressed merkle proof call.
type Proof struct {
	// Bitmap is the bitmap returned alongside the compressed audit path.
	Bitmap []byte
	// Length is the proof length (height) returned by the trie.
	Length int
	// Siblings is the compressed audit path.
	Siblings [][]byte
	// Value is the value stored in the trie for the key, i.e. the hash of
	// the claim value.
	Value []byte
}
// exportElement is a single key/value pair of the trie as written by Dump
// and read back by ImportDump.
type exportElement struct {
	Key []byte `bare:"key"`
	Value []byte `bare:"value"`
}
// exportData is the top-level bare message holding every exported element
// of a trie.
type exportData struct {
	Elements []exportElement `bare:"elements"`
}
  40. const (
  41. MaxKeySize = 32
  42. MaxValueSize = 64
  43. dbRootPrefix = "this is the last root for the SMT tree"
  44. )
  45. // NewTree initializes a new AergoSMT tree following the censustree.Tree interface specification.
  46. func NewTree(name, storageDir string) (censustree.Tree, error) {
  47. tr := &Tree{}
  48. err := tr.Init(name, storageDir)
  49. return tr, err
  50. }
  51. // newTree opens or creates a merkle tree under the given storage.
  52. func newTree(name, storageDir string) (*asmt.Trie, db.DB, error) {
  53. dir := path.Join(storageDir, name)
  54. log.Debugf("creating new tree on %s", dir)
  55. d := db.NewDB(db.LevelImpl, dir)
  56. root := d.Get([]byte(dbRootPrefix))
  57. tr := asmt.NewTrie(root, asmt.Hasher, d)
  58. if root != nil {
  59. if err := tr.LoadCache(root); err != nil {
  60. return nil, nil, err
  61. }
  62. }
  63. return tr, d, nil
  64. }
  65. // Init initializes a new asmt tree
  66. func (t *Tree) Init(name, storageDir string) error {
  67. var err error
  68. t.Tree, t.db, err = newTree(name, storageDir)
  69. t.updateAccessTime()
  70. t.size = 0
  71. return err
  72. }
  73. func (t *Tree) MaxKeySize() int {
  74. return MaxKeySize
  75. }
  76. // LastAccess returns the last time the Tree was accessed, in the form of a unix
  77. // timestamp.
  78. func (t *Tree) LastAccess() int64 {
  79. return atomic.LoadInt64(&t.lastAccessUnix)
  80. }
  81. func (t *Tree) updateAccessTime() {
  82. atomic.StoreInt64(&t.lastAccessUnix, time.Now().Unix())
  83. }
  84. // Publish makes a merkle tree available for queries.
  85. // Application layer should check IsPublish() before considering the Tree available.
  86. func (t *Tree) Publish() {
  87. atomic.StoreUint32(&t.public, 1)
  88. }
  89. // UnPublish makes a merkle tree not available for queries
  90. func (t *Tree) UnPublish() {
  91. atomic.StoreUint32(&t.public, 0)
  92. }
  93. // IsPublic returns true if the tree is available
  94. func (t *Tree) IsPublic() bool {
  95. return atomic.LoadUint32(&t.public) == 1
  96. }
  97. // Commit saves permanently the tree on disk
  98. func (t *Tree) Commit() error {
  99. if t.snapshotRoot != nil {
  100. return fmt.Errorf("cannot commit to a snapshot trie")
  101. }
  102. err := t.Tree.Commit()
  103. if err != nil {
  104. return err
  105. }
  106. t.db.Set([]byte(dbRootPrefix), t.Root())
  107. return nil
  108. }
  109. // Add adds a new claim to the merkle tree
  110. // A claim is composed of two parts: index and value
  111. // 1.index is mandatory, the data will be used for indexing the claim into to merkle tree
  112. // 2.value is optional, the data will not affect the indexing
  113. func (t *Tree) Add(index, value []byte) error {
  114. t.updateAccessTime()
  115. if t.snapshotRoot != nil {
  116. return fmt.Errorf("cannot add to a snapshot trie")
  117. }
  118. if len(index) < 4 || len(index) > MaxKeySize {
  119. return fmt.Errorf("wrong key size: %d", len(index))
  120. }
  121. if len(value) > MaxValueSize {
  122. return fmt.Errorf("index or value claim data too big")
  123. }
  124. _, err := t.Tree.Update([][]byte{asmt.Hasher(index)}, [][]byte{asmt.Hasher(value)})
  125. if err != nil {
  126. return err
  127. }
  128. atomic.StoreUint64(&t.size, 0) // TBD: improve this
  129. return t.Commit()
  130. }
  131. // AddBatch adds a list of indexes and values.
  132. // The commit to disk is executed only once.
  133. // The values slince could be empty or as long as indexes.
  134. func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) {
  135. var wrongIndexes []int
  136. t.updateAccessTime()
  137. if t.snapshotRoot != nil {
  138. return wrongIndexes, fmt.Errorf("cannot add to a snapshot trie")
  139. }
  140. if len(values) > 0 && len(indexes) != len(values) {
  141. return wrongIndexes, fmt.Errorf("indexes and values have different size")
  142. }
  143. var value []byte
  144. for i, key := range indexes {
  145. if len(key) < 4 || len(key) > MaxKeySize {
  146. wrongIndexes = append(wrongIndexes, i)
  147. continue
  148. }
  149. value = nil
  150. if len(values) > 0 {
  151. if len(values[i]) > MaxValueSize {
  152. wrongIndexes = append(wrongIndexes, i)
  153. continue
  154. }
  155. value = values[i]
  156. }
  157. _, err := t.Tree.Update([][]byte{asmt.Hasher(key)}, [][]byte{asmt.Hasher(value)})
  158. if err != nil {
  159. return wrongIndexes, err
  160. }
  161. }
  162. atomic.StoreUint64(&t.size, 0) // TBD: improve this
  163. return wrongIndexes, t.Commit()
  164. }
  165. // Get returns the value of a key
  166. func (t *Tree) Get(key []byte) []byte { // Do something with error
  167. var value []byte
  168. if t.snapshotRoot != nil {
  169. value, _ = t.Tree.GetWithRoot(key, t.snapshotRoot)
  170. } else {
  171. value, _ = t.Tree.Get(key)
  172. }
  173. return value
  174. }
  175. // GenProof generates a merkle tree proof that can be later used on CheckProof() to validate it.
  176. func (t *Tree) GenProof(index, value []byte) ([]byte, error) {
  177. t.updateAccessTime()
  178. var err error
  179. var ap [][]byte
  180. var pvalue, bitmap []byte
  181. var length int
  182. var included bool
  183. key := asmt.Hasher(index)
  184. if t.snapshotRoot != nil {
  185. bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressedR(key,
  186. t.snapshotRoot)
  187. if err != nil {
  188. return nil, err
  189. }
  190. } else {
  191. bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressed(key)
  192. if err != nil {
  193. return nil, err
  194. }
  195. }
  196. if !included {
  197. return nil, nil
  198. }
  199. if !bytes.Equal(pvalue, asmt.Hasher(value)) {
  200. return nil, fmt.Errorf("incorrect value or key on genProof")
  201. }
  202. return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue})
  203. }
  204. // CheckProof validates a merkle proof and its data.
  205. func (t *Tree) CheckProof(index, value, root, mproof []byte) (bool, error) {
  206. t.updateAccessTime()
  207. p := Proof{}
  208. if err := bare.Unmarshal(mproof, &p); err != nil {
  209. return false, err
  210. }
  211. if !bytes.Equal(p.Value, asmt.Hasher(value)) {
  212. return false, fmt.Errorf("values mismatch %x != %x", p.Value, asmt.Hasher(value))
  213. }
  214. if root != nil {
  215. return t.Tree.VerifyInclusionWithRootC(
  216. root,
  217. p.Bitmap,
  218. asmt.Hasher(index),
  219. p.Value,
  220. p.Siblings,
  221. p.Length), nil
  222. }
  223. if t.snapshotRoot != nil {
  224. return t.Tree.VerifyInclusionWithRootC(
  225. t.snapshotRoot,
  226. p.Bitmap,
  227. asmt.Hasher(index),
  228. p.Value,
  229. p.Siblings,
  230. p.Length), nil
  231. }
  232. return t.Tree.VerifyInclusionC(
  233. p.Bitmap,
  234. asmt.Hasher(index),
  235. p.Value,
  236. p.Siblings,
  237. p.Length), nil
  238. }
  239. // Root returns the current root hash of the merkle tree
  240. func (t *Tree) Root() []byte {
  241. t.updateAccessTime()
  242. if t.snapshotRoot != nil {
  243. return t.snapshotRoot
  244. }
  245. return t.Tree.Root
  246. }
  247. // Dump returns the whole merkle tree serialized in a format that can be used on Import.
  248. // Byte seralization is performed using bare message protocol, it is a 40% size win over JSON.
  249. func (t *Tree) Dump(root []byte) ([]byte, error) {
  250. t.updateAccessTime()
  251. if root == nil && t.snapshotRoot != nil {
  252. root = t.snapshotRoot
  253. }
  254. dump := exportData{}
  255. t.iterateWithRoot(root, nil, func(k, v []byte) bool {
  256. ee := exportElement{Key: make([]byte, len(k)), Value: make([]byte, len(v))}
  257. // Copy elements since it's not safe to hold on to the []byte values from Iterate
  258. copy(ee.Key, k[:])
  259. copy(ee.Value, v[:])
  260. dump.Elements = append(dump.Elements, ee)
  261. return false
  262. })
  263. bare.MaxArrayLength(bareMaxArrayLength)
  264. bare.MaxUnmarshalBytes(bareMaxUnmarshalBytes)
  265. return bare.Marshal(&dump)
  266. }
  267. // String returns a human readable representation of the tree.
  268. func (t *Tree) String() string {
  269. s := bytes.Buffer{}
  270. t.iterate(t.snapshotRoot, func(k, v []byte) bool {
  271. s.WriteString(fmt.Sprintf("%x => %x\n", k, v))
  272. return false
  273. })
  274. return s.String()
  275. }
  276. // Size returns the number of leaf nodes on the merkle tree.
  277. // TO-DO: root is currently ignored
  278. func (t *Tree) Size(root []byte) (int64, error) {
  279. if t.snapshotRoot != nil {
  280. return int64(t.snapshotSize), nil
  281. }
  282. return int64(t.count()), nil
  283. }
  284. // DumpPlain returns the entire list of added claims for a specific root hash.
  285. // First return parametre are the indexes and second the values.
  286. // If root is not specified, the last one is used.
  287. func (t *Tree) DumpPlain(root []byte) ([][]byte, [][]byte, error) {
  288. var indexes, values [][]byte
  289. var err error
  290. t.updateAccessTime()
  291. t.iterateWithRoot(root, nil, func(k, v []byte) bool {
  292. indexes = append(indexes, k)
  293. values = append(values, v)
  294. return false
  295. })
  296. return indexes, values, err
  297. }
  298. // ImportDump imports a partial or whole tree previously exported with Dump()
  299. func (t *Tree) ImportDump(data []byte) error {
  300. t.updateAccessTime()
  301. if t.snapshotRoot != nil {
  302. return fmt.Errorf("cannot import to a snapshot")
  303. }
  304. census := new(exportData)
  305. bare.MaxArrayLength(bareMaxArrayLength)
  306. bare.MaxUnmarshalBytes(bareMaxUnmarshalBytes)
  307. if err := bare.Unmarshal(data, census); err != nil {
  308. return fmt.Errorf("importdump cannot unmarshal data: %w", err)
  309. }
  310. keys := [][]byte{}
  311. values := [][]byte{}
  312. for _, ee := range census.Elements {
  313. keys = append(keys, ee.Key)
  314. values = append(values, ee.Value)
  315. }
  316. _, err := t.Tree.Update(keys, values)
  317. if err != nil {
  318. return err
  319. }
  320. atomic.StoreUint64(&t.size, 0) // TBD: improve this
  321. return t.Commit()
  322. }
  323. // Snapshot returns a Tree instance of a exiting merkle root.
  324. // A Snapshot cannot be modified.
  325. func (t *Tree) Snapshot(root []byte) (censustree.Tree, error) {
  326. exist, err := t.HashExists(root)
  327. if err != nil {
  328. return nil, err
  329. }
  330. if !exist {
  331. return nil, fmt.Errorf("root %x does not exist, cannot build snapshot", root)
  332. }
  333. return &Tree{Tree: t.Tree, public: t.public, snapshotRoot: root, snapshotSize: t.count()}, nil
  334. }
  335. func (t *Tree) Close() error {
  336. t.db.Close()
  337. return nil
  338. }
  339. // HashExists checks if a hash exists as a node in the merkle tree
  340. func (t *Tree) HashExists(hash []byte) (bool, error) {
  341. t.updateAccessTime()
  342. return t.Tree.TrieRootExists(hash), nil
  343. }
  344. func (t *Tree) count() uint64 {
  345. if v := atomic.LoadUint64(&t.size); v != 0 {
  346. return v
  347. }
  348. counter := uint64(0)
  349. if err := t.Tree.Walk(t.snapshotRoot, func(*asmt.WalkResult) int32 {
  350. counter++
  351. return 0
  352. }); err != nil {
  353. return 0
  354. }
  355. atomic.StoreUint64(&t.size, counter)
  356. return counter
  357. }
  358. func (t *Tree) iterate(prefix []byte, callback func(key, value []byte) bool) {
  359. t.Tree.Walk(t.snapshotRoot, func(v *asmt.WalkResult) int32 {
  360. if callback(v.Key, v.Value) {
  361. return 1
  362. } else {
  363. return 0
  364. }
  365. })
  366. }
  367. func (t *Tree) iterateWithRoot(root, prefix []byte, callback func(key, value []byte) bool) {
  368. t.Tree.Walk(root, func(v *asmt.WalkResult) int32 {
  369. if callback(v.Key, v.Value) {
  370. return 1
  371. } else {
  372. return 0
  373. }
  374. })
  375. }