You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

333 lines
8.2 KiB

3 years ago
3 years ago
  1. package sql
  2. import (
  3. "crypto/sha256"
  4. "database/sql"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "github.com/iden3/go-merkletree"
  9. "github.com/jmoiron/sqlx"
  10. _ "github.com/lib/pq"
  11. )
  // SQL statements executed by StorageTx.Commit.
//
// Both are PostgreSQL upserts (INSERT ... ON CONFLICT ... DO UPDATE). A node
// row is identified by (mt_id, key) and a root row by mt_id, so writing a row
// that already exists overwrites its columns.
//
// TODO: upsert or insert?
const (
	upsertStmt     = `INSERT INTO mt_nodes (mt_id, key, type, child_l, child_r, entry) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (mt_id, key) DO UPDATE SET type = $3, child_l = $4, child_r = $5, entry = $6`
	updateRootStmt = `INSERT INTO mt_roots (mt_id, key) VALUES ($1, $2) ON CONFLICT (mt_id) DO UPDATE SET key = $2`
)
  17. // Storage implements the db.Storage interface
  18. type Storage struct {
  19. db *sqlx.DB
  20. mtId uint64
  21. currentVersion uint64
  22. currentRoot *merkletree.Hash
  23. }
  24. // StorageTx implements the db.Tx interface
  25. type StorageTx struct {
  26. *Storage
  27. tx *sqlx.Tx
  28. cache KvMap
  29. currentRoot *merkletree.Hash
  30. }
  // NodeItem is one row of the mt_nodes table as scanned by sqlx.
//
// Node queries use SELECT *, so this struct needs a db-tagged field for every
// column of the table.
type NodeItem struct {
	// MTId is the tree the node belongs to (mt_id column).
	MTId uint64 `db:"mt_id"`
	// Key is the node key; Commit stores node.Key() here.
	Key []byte `db:"key"`

	// Type is the type of node in the tree.
	Type byte `db:"type"`
	// ChildL is the left child of a middle node.
	ChildL []byte `db:"child_l"`
	// ChildR is the right child of a middle node.
	ChildR []byte `db:"child_r"`
	// Entry is the data stored in a leaf node: both entry elements
	// concatenated.
	Entry []byte `db:"entry"`

	// CreatedAt and DeletedAt map the created_at and deleted_at columns.
	// NOTE(review): the SQL column types are not visible here; confirm they
	// scan into *uint64.
	CreatedAt *uint64 `db:"created_at"`
	DeletedAt *uint64 `db:"deleted_at"`
}
  // RootItem is one row of the mt_roots table: the current root of one tree.
type RootItem struct {
	// MTId is the tree whose root this row holds (mt_id column).
	MTId uint64 `db:"mt_id"`
	// Key holds the root hash bytes.
	Key []byte `db:"key"`

	// CreatedAt and DeletedAt map the created_at and deleted_at columns.
	// NOTE(review): confirm the SQL column types scan into *uint64.
	CreatedAt *uint64 `db:"created_at"`
	DeletedAt *uint64 `db:"deleted_at"`
}
  51. // NewSqlStorage returns a new Storage
  52. func NewSqlStorage(db *sqlx.DB, errorIfMissing bool) (*Storage, error) {
  53. return &Storage{db: db}, nil
  54. }
  55. // WithPrefix implements the method WithPrefix of the interface db.Storage
  56. func (s *Storage) WithPrefix(prefix []byte) merkletree.Storage {
  57. //return &Storage{db: s.db, prefix: merkletree.Concat(s.prefix, prefix)}
  58. // TODO: remove WithPrefix method
  59. mtId, _ := binary.Uvarint(prefix)
  60. return &Storage{db: s.db, mtId: mtId}
  61. }
  62. // NewTx implements the method NewTx of the interface db.Storage
  63. func (s *Storage) NewTx() (merkletree.Tx, error) {
  64. tx, err := s.db.Beginx()
  65. if err != nil {
  66. return nil, err
  67. }
  68. return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil
  69. }
  70. // Get retrieves a value from a key in the db.Storage
  71. func (s *Storage) Get(key []byte) (*merkletree.Node, error) {
  72. item := NodeItem{}
  73. err := s.db.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", s.mtId, key)
  74. if err == sql.ErrNoRows {
  75. return nil, merkletree.ErrNotFound
  76. }
  77. if err != nil {
  78. return nil, err
  79. }
  80. node, err := item.Node()
  81. if err != nil {
  82. return nil, err
  83. }
  84. return node, nil
  85. }
  86. // GetRoot retrieves a merkle tree root hash in the interface db.Tx
  87. func (s *Storage) GetRoot() (*merkletree.Hash, error) {
  88. var root merkletree.Hash
  89. if s.currentRoot != nil {
  90. copy(root[:], s.currentRoot[:])
  91. return &root, nil
  92. }
  93. item := RootItem{}
  94. err := s.db.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", s.mtId)
  95. if err == sql.ErrNoRows {
  96. return nil, merkletree.ErrNotFound
  97. }
  98. if err != nil {
  99. return nil, err
  100. }
  101. copy(root[:], item.Key[:])
  102. return &root, nil
  103. }
  104. // Iterate implements the method Iterate of the interface db.Storage
  105. func (s *Storage) Iterate(f func([]byte, *merkletree.Node) (bool, error)) error {
  106. items := []NodeItem{}
  107. err := s.db.Select(&items, "SELECT * FROM mt_nodes WHERE key WHERE mt_id = $1", s.mtId)
  108. if err != nil {
  109. return err
  110. }
  111. for _, v := range items {
  112. k := v.Key[:]
  113. n, err := v.Node()
  114. if err != nil {
  115. return err
  116. }
  117. cont, err := f(k, n)
  118. if err != nil {
  119. return err
  120. }
  121. if !cont {
  122. break
  123. }
  124. }
  125. return nil
  126. }
  127. // Get retrieves a value from a key in the interface db.Tx
  128. func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
  129. //fullKey := append(tx.mtId, key...)
  130. fullKey := key
  131. if value, ok := tx.cache.Get(fullKey); ok {
  132. return &value, nil
  133. }
  134. item := NodeItem{}
  135. err := tx.tx.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", tx.mtId, key)
  136. if err == sql.ErrNoRows {
  137. return nil, merkletree.ErrNotFound
  138. }
  139. if err != nil {
  140. return nil, err
  141. }
  142. node, err := item.Node()
  143. if err != nil {
  144. return nil, err
  145. }
  146. return node, nil
  147. }
  148. // Put saves a key:value into the db.Storage
  149. func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
  150. //fullKey := append(tx.mtId, k...)
  151. fullKey := k
  152. tx.cache.Put(tx.mtId, fullKey, *v)
  153. fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
  154. return nil
  155. }
  156. // GetRoot retrieves a merkle tree root hash in the interface db.Tx
  157. func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
  158. var root merkletree.Hash
  159. if tx.currentRoot != nil {
  160. copy(root[:], tx.currentRoot[:])
  161. return &root, nil
  162. }
  163. item := RootItem{}
  164. err := tx.tx.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", tx.mtId)
  165. if err == sql.ErrNoRows {
  166. return nil, merkletree.ErrNotFound
  167. }
  168. if err != nil {
  169. return nil, err
  170. }
  171. copy(root[:], item.Key[:])
  172. return &root, nil
  173. }
  174. // SetRoot sets a hash of merkle tree root in the interface db.Tx
  175. func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
  176. root := &merkletree.Hash{}
  177. copy(root[:], hash[:])
  178. tx.currentRoot = root
  179. return nil
  180. }
  181. // Add implements the method Add of the interface db.Tx
  182. func (tx *StorageTx) Add(atx merkletree.Tx) error {
  183. dbtx := atx.(*StorageTx)
  184. if tx.mtId != dbtx.mtId {
  185. return errors.New("adding StorageTx with different prefix is not implemented")
  186. }
  187. for _, v := range dbtx.cache {
  188. tx.cache.Put(v.MTId, v.K, v.V)
  189. }
  190. // TODO: change cache to store different currentRoots for different mtIds too!
  191. tx.currentRoot = dbtx.currentRoot
  192. return nil
  193. }
  194. // Commit implements the method Commit of the interface db.Tx
  195. func (tx *StorageTx) Commit() error {
  196. // execute a query on the server
  197. fmt.Printf("Commit\n")
  198. for _, v := range tx.cache {
  199. fmt.Printf("key %x, value %+v\n", v.K, v.V)
  200. node := v.V
  201. var childL []byte
  202. if node.ChildL != nil {
  203. childL = append(childL, node.ChildL[:]...)
  204. }
  205. var childR []byte
  206. if node.ChildR != nil {
  207. childR = append(childR, node.ChildR[:]...)
  208. }
  209. var entry []byte
  210. if node.Entry[0] != nil && node.Entry[1] != nil {
  211. entry = append(node.Entry[0][:], node.Entry[1][:]...)
  212. }
  213. key, err := node.Key()
  214. if err != nil {
  215. return err
  216. }
  217. _, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry)
  218. if err != nil {
  219. return err
  220. }
  221. }
  222. if tx.currentRoot == nil {
  223. tx.currentRoot = &merkletree.Hash{}
  224. }
  225. _, err := tx.tx.Exec(updateRootStmt, tx.mtId, tx.currentRoot[:])
  226. if err != nil {
  227. return err
  228. }
  229. tx.cache = nil
  230. return tx.tx.Commit()
  231. }
  232. // Close implements the method Close of the interface db.Tx
  233. func (tx *StorageTx) Close() {
  234. tx.tx.Rollback()
  235. tx.cache = nil
  236. }
  237. // Close implements the method Close of the interface db.Storage
  238. func (s *Storage) Close() {
  239. err := s.db.Close()
  240. if err != nil {
  241. panic(err)
  242. }
  243. }
  244. // List implements the method List of the interface db.Storage
  245. func (s *Storage) List(limit int) ([]merkletree.KV, error) {
  246. ret := []merkletree.KV{}
  247. err := s.Iterate(func(key []byte, value *merkletree.Node) (bool, error) {
  248. ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: *value})
  249. if len(ret) == limit {
  250. return false, nil
  251. }
  252. return true, nil
  253. })
  254. return ret, err
  255. }
  256. func (item *NodeItem) Node() (*merkletree.Node, error) {
  257. node := merkletree.Node{
  258. Type: merkletree.NodeType(item.Type),
  259. }
  260. if item.ChildL != nil {
  261. node.ChildL = &merkletree.Hash{}
  262. copy(node.ChildL[:], item.ChildL[:])
  263. }
  264. if item.ChildR != nil {
  265. node.ChildR = &merkletree.Hash{}
  266. copy(node.ChildR[:], item.ChildR[:])
  267. }
  268. if len(item.Entry) > 0 {
  269. if len(item.Entry) != 2*merkletree.ElemBytesLen {
  270. return nil, merkletree.ErrNodeBytesBadSize
  271. }
  272. node.Entry = [2]*merkletree.Hash{{}, {}}
  273. copy(node.Entry[0][:], item.Entry[0:32])
  274. copy(node.Entry[1][:], item.Entry[32:64])
  275. }
  276. return &node, nil
  277. }
  278. // KV contains a key (K) and a value (V)
  279. type KV struct {
  280. MTId uint64
  281. K []byte
  282. V merkletree.Node
  283. }
  284. // KvMap is a key-value map between a sha256 byte array hash, and a KV struct
  285. type KvMap map[[sha256.Size]byte]KV
  286. // Get retrieves the value respective to a key from the KvMap
  287. func (m KvMap) Get(k []byte) (merkletree.Node, bool) {
  288. v, ok := m[sha256.Sum256(k)]
  289. return v.V, ok
  290. }
  291. // Put stores a key and a value in the KvMap
  292. func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) {
  293. m[sha256.Sum256(k)] = KV{mtId, k, v}
  294. }