You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
7.7 KiB

  1. package sql
  2. import (
  3. "database/sql"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "github.com/iden3/go-merkletree"
  8. "github.com/jmoiron/sqlx"
  9. _ "github.com/lib/pq"
  10. )
  11. // TODO: upsert or insert?
  12. const upsertStmt = `INSERT INTO mt_nodes (mt_id, key, type, child_l, child_r, entry) VALUES ($1, $2, $3, $4, $5, $6) ` +
  13. `ON CONFLICT (mt_id, key) DO UPDATE SET type = $3, child_l = $4, child_r = $5, entry = $6`
  14. const updateRootStmt = `INSERT INTO mt_roots (mt_id, key) VALUES ($1, $2) ` +
  15. `ON CONFLICT (mt_id) DO UPDATE SET key = $2`
  16. // Storage implements the db.Storage interface
  17. type Storage struct {
  18. db *sqlx.DB
  19. mtId uint64
  20. currentVersion uint64
  21. currentRoot *merkletree.Hash
  22. }
  23. // StorageTx implements the db.Tx interface
  24. type StorageTx struct {
  25. *Storage
  26. tx *sqlx.Tx
  27. cache merkletree.KvMap
  28. currentRoot *merkletree.Hash
  29. }
  30. type NodeItem struct {
  31. MTId uint64 `db:"mt_id"`
  32. Key []byte `db:"key"`
  33. // Type is the type of node in the tree.
  34. Type byte `db:"type"`
  35. // ChildL is the left child of a middle node.
  36. ChildL []byte `db:"child_l"`
  37. // ChildR is the right child of a middle node.
  38. ChildR []byte `db:"child_r"`
  39. // Entry is the data stored in a leaf node.
  40. Entry []byte `db:"entry"`
  41. CreatedAt *uint64 `db:"created_at"`
  42. DeletedAt *uint64 `db:"deleted_at"`
  43. }
  44. type RootItem struct {
  45. MTId uint64 `db:"mt_id"`
  46. Key []byte `db:"key"`
  47. CreatedAt *uint64 `db:"created_at"`
  48. DeletedAt *uint64 `db:"deleted_at"`
  49. }
  50. // NewSqlStorage returns a new Storage
  51. func NewSqlStorage(db *sqlx.DB, errorIfMissing bool) (*Storage, error) {
  52. return &Storage{db: db}, nil
  53. }
  54. // WithPrefix implements the method WithPrefix of the interface db.Storage
  55. func (s *Storage) WithPrefix(prefix []byte) merkletree.Storage {
  56. //return &Storage{db: s.db, prefix: merkletree.Concat(s.prefix, prefix)}
  57. // TODO: remove WithPrefix method
  58. mtId := s.mtId<<4 | binary.LittleEndian.Uint64(prefix)
  59. return &Storage{db: s.db, mtId: mtId}
  60. }
  61. // NewTx implements the method NewTx of the interface db.Storage
  62. func (s *Storage) NewTx() (merkletree.Tx, error) {
  63. tx, err := s.db.Beginx()
  64. if err != nil {
  65. return nil, err
  66. }
  67. return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil
  68. }
  69. // Get retrieves a value from a key in the db.Storage
  70. func (s *Storage) Get(key []byte) (*merkletree.Node, error) {
  71. item := NodeItem{}
  72. err := s.db.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", s.mtId, key)
  73. if err == sql.ErrNoRows {
  74. return nil, merkletree.ErrNotFound
  75. }
  76. if err != nil {
  77. return nil, err
  78. }
  79. node, err := item.Node()
  80. if err != nil {
  81. return nil, err
  82. }
  83. return node, nil
  84. }
  85. // GetRoot retrieves a merkle tree root hash in the interface db.Tx
  86. func (s *Storage) GetRoot() (*merkletree.Hash, error) {
  87. var root merkletree.Hash
  88. if s.currentRoot != nil {
  89. copy(root[:], s.currentRoot[:])
  90. return &root, nil
  91. }
  92. item := RootItem{}
  93. err := s.db.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", s.mtId)
  94. if err == sql.ErrNoRows {
  95. return nil, merkletree.ErrNotFound
  96. }
  97. if err != nil {
  98. return nil, err
  99. }
  100. copy(root[:], item.Key[:])
  101. return &root, nil
  102. }
  103. // Iterate implements the method Iterate of the interface db.Storage
  104. func (s *Storage) Iterate(f func([]byte, *merkletree.Node) (bool, error)) error {
  105. items := []NodeItem{}
  106. err := s.db.Select(&items, "SELECT * FROM mt_nodes WHERE key WHERE mt_id = $1", s.mtId)
  107. if err != nil {
  108. return err
  109. }
  110. for _, v := range items {
  111. k := v.Key[:]
  112. n, err := v.Node()
  113. if err != nil {
  114. return err
  115. }
  116. cont, err := f(k, n)
  117. if err != nil {
  118. return err
  119. }
  120. if !cont {
  121. break
  122. }
  123. }
  124. return nil
  125. }
  126. // Get retrieves a value from a key in the interface db.Tx
  127. func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
  128. //fullKey := append(tx.mtId, key...)
  129. fullKey := key
  130. if value, ok := tx.cache.Get(fullKey); ok {
  131. return &value, nil
  132. }
  133. item := NodeItem{}
  134. err := tx.tx.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", tx.mtId, key)
  135. if err == sql.ErrNoRows {
  136. return nil, merkletree.ErrNotFound
  137. }
  138. if err != nil {
  139. return nil, err
  140. }
  141. node, err := item.Node()
  142. if err != nil {
  143. return nil, err
  144. }
  145. return node, nil
  146. }
  147. // Put saves a key:value into the db.Storage
  148. func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
  149. //fullKey := append(tx.mtId, k...)
  150. fullKey := k
  151. tx.cache.Put(fullKey, *v)
  152. fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
  153. return nil
  154. }
  155. // GetRoot retrieves a merkle tree root hash in the interface db.Tx
  156. func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
  157. var root merkletree.Hash
  158. if tx.currentRoot != nil {
  159. copy(root[:], tx.currentRoot[:])
  160. return &root, nil
  161. }
  162. item := RootItem{}
  163. err := tx.tx.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", tx.mtId)
  164. if err == sql.ErrNoRows {
  165. return nil, merkletree.ErrNotFound
  166. }
  167. if err != nil {
  168. return nil, err
  169. }
  170. copy(root[:], item.Key[:])
  171. return &root, nil
  172. }
  173. // SetRoot sets a hash of merkle tree root in the interface db.Tx
  174. func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
  175. root := &merkletree.Hash{}
  176. copy(root[:], hash[:])
  177. tx.currentRoot = root
  178. return nil
  179. }
  180. // Add implements the method Add of the interface db.Tx
  181. func (tx *StorageTx) Add(atx merkletree.Tx) error {
  182. dbtx := atx.(*StorageTx)
  183. //if !bytes.Equal(tx.prefix, dbtx.prefix) {
  184. // // TODO: change cache to store prefix too!
  185. // return errors.New("adding StorageTx with different prefix is not implemented")
  186. //}
  187. if tx.mtId != dbtx.mtId {
  188. // TODO: change cache to store prefix too!
  189. return errors.New("adding StorageTx with different prefix is not implemented")
  190. }
  191. for _, v := range dbtx.cache {
  192. tx.cache.Put(v.K, v.V)
  193. }
  194. tx.currentRoot = dbtx.currentRoot
  195. return nil
  196. }
  197. // Commit implements the method Commit of the interface db.Tx
  198. func (tx *StorageTx) Commit() error {
  199. // execute a query on the server
  200. fmt.Printf("Commit\n")
  201. for _, v := range tx.cache {
  202. fmt.Printf("key %x, value %+v\n", v.K, v.V)
  203. node := v.V
  204. var childL []byte
  205. if node.ChildL != nil {
  206. childL = append(childL, node.ChildL[:]...)
  207. }
  208. var childR []byte
  209. if node.ChildR != nil {
  210. childR = append(childR, node.ChildR[:]...)
  211. }
  212. var entry []byte
  213. if node.Entry[0] != nil && node.Entry[1] != nil {
  214. entry = append(node.Entry[0][:], node.Entry[1][:]...)
  215. }
  216. key, err := node.Key()
  217. if err != nil {
  218. return err
  219. }
  220. _, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry)
  221. if err != nil {
  222. return err
  223. }
  224. }
  225. _, err := tx.tx.Exec(updateRootStmt, tx.mtId, tx.currentRoot[:])
  226. if err != nil {
  227. return err
  228. }
  229. tx.cache = nil
  230. return tx.tx.Commit()
  231. }
  232. // Close implements the method Close of the interface db.Tx
  233. func (tx *StorageTx) Close() {
  234. //tx.tx.Rollback()
  235. tx.cache = nil
  236. }
  237. // Close implements the method Close of the interface db.Storage
  238. func (s *Storage) Close() {
  239. err := s.db.Close()
  240. if err != nil {
  241. panic(err)
  242. }
  243. }
  244. // List implements the method List of the interface db.Storage
  245. func (s *Storage) List(limit int) ([]merkletree.KV, error) {
  246. ret := []merkletree.KV{}
  247. err := s.Iterate(func(key []byte, value *merkletree.Node) (bool, error) {
  248. ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: *value})
  249. if len(ret) == limit {
  250. return false, nil
  251. }
  252. return true, nil
  253. })
  254. return ret, err
  255. }
  256. func (item *NodeItem) Node() (*merkletree.Node, error) {
  257. node := merkletree.Node{
  258. Type: merkletree.NodeType(item.Type),
  259. }
  260. if item.ChildL != nil {
  261. node.ChildL = &merkletree.Hash{}
  262. copy(node.ChildL[:], item.ChildL[:])
  263. }
  264. if item.ChildR != nil {
  265. node.ChildR = &merkletree.Hash{}
  266. copy(node.ChildR[:], item.ChildR[:])
  267. }
  268. if len(item.Entry) > 0 {
  269. if len(item.Entry) != 2*merkletree.ElemBytesLen {
  270. return nil, merkletree.ErrNodeBytesBadSize
  271. }
  272. node.Entry = [2]*merkletree.Hash{{}, {}}
  273. copy(node.Entry[0][:], item.Entry[0:32])
  274. copy(node.Entry[1][:], item.Entry[32:64])
  275. }
  276. return &node, nil
  277. }