package sql
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/iden3/go-merkletree"
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// TODO: upsert or insert?

// SQL statements executed by StorageTx.Commit. Both use positional
// parameters and the INSERT ... ON CONFLICT ... DO UPDATE form, so a row that
// already exists is overwritten instead of causing a conflict error.
const (
	// upsertStmt writes one node row, identified by (mt_id, key).
	upsertStmt = `INSERT INTO mt_nodes (mt_id, key, type, child_l, child_r, entry) VALUES ($1, $2, $3, $4, $5, $6) ` +
		`ON CONFLICT (mt_id, key) DO UPDATE SET type = $3, child_l = $4, child_r = $5, entry = $6`

	// updateRootStmt writes the root key of a tree, identified by mt_id.
	updateRootStmt = `INSERT INTO mt_roots (mt_id, key) VALUES ($1, $2) ` +
		`ON CONFLICT (mt_id) DO UPDATE SET key = $2`
)
|
|
|
|
// Storage implements the db.Storage interface
|
|
type Storage struct {
|
|
db *sqlx.DB
|
|
mtId uint64
|
|
currentVersion uint64
|
|
currentRoot *merkletree.Hash
|
|
}
|
|
|
|
// StorageTx implements the db.Tx interface
|
|
type StorageTx struct {
|
|
*Storage
|
|
tx *sqlx.Tx
|
|
cache KvMap
|
|
currentRoot *merkletree.Hash
|
|
}
|
|
|
|
// NodeItem is the shape of one row of the mt_nodes table as scanned by sqlx.
// Use Node to convert it into a merkletree.Node.
type NodeItem struct {
	// MTId is the id of the tree the node belongs to.
	MTId uint64 `db:"mt_id"`
	// Key is the node key; rows are looked up by (mt_id, key).
	Key []byte `db:"key"`
	// Type is the type of node in the tree.
	Type byte `db:"type"`
	// ChildL is the left child of a middle node.
	ChildL []byte `db:"child_l"`
	// ChildR is the right child of a middle node.
	ChildR []byte `db:"child_r"`
	// Entry is the data stored in a leaf node.
	Entry []byte `db:"entry"`

	// CreatedAt and DeletedAt map to the created_at and deleted_at columns.
	// They are pointers so that a NULL column scans as nil; nothing in this
	// file writes them.
	CreatedAt *uint64 `db:"created_at"`
	DeletedAt *uint64 `db:"deleted_at"`
}
|
|
|
|
// RootItem is the shape of one row of the mt_roots table as scanned by sqlx.
type RootItem struct {
	// MTId is the id of the tree; the root statement treats it as unique.
	MTId uint64 `db:"mt_id"`
	// Key is the root key of the tree, stored as raw bytes.
	Key []byte `db:"key"`

	// CreatedAt and DeletedAt map to the created_at and deleted_at columns.
	// They are pointers so that a NULL column scans as nil; nothing in this
	// file writes them.
	CreatedAt *uint64 `db:"created_at"`
	DeletedAt *uint64 `db:"deleted_at"`
}
|
|
|
|
// NewSqlStorage returns a new Storage
|
|
func NewSqlStorage(db *sqlx.DB, errorIfMissing bool) (*Storage, error) {
|
|
return &Storage{db: db}, nil
|
|
}
|
|
|
|
// WithPrefix implements the method WithPrefix of the interface db.Storage
|
|
func (s *Storage) WithPrefix(prefix []byte) merkletree.Storage {
|
|
//return &Storage{db: s.db, prefix: merkletree.Concat(s.prefix, prefix)}
|
|
// TODO: remove WithPrefix method
|
|
mtId, _ := binary.Uvarint(prefix)
|
|
return &Storage{db: s.db, mtId: mtId}
|
|
}
|
|
|
|
// NewTx implements the method NewTx of the interface db.Storage
|
|
func (s *Storage) NewTx() (merkletree.Tx, error) {
|
|
tx, err := s.db.Beginx()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil
|
|
}
|
|
|
|
// Get retrieves a value from a key in the db.Storage
|
|
func (s *Storage) Get(key []byte) (*merkletree.Node, error) {
|
|
item := NodeItem{}
|
|
err := s.db.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", s.mtId, key)
|
|
if err == sql.ErrNoRows {
|
|
return nil, merkletree.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
node, err := item.Node()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
// GetRoot retrieves a merkle tree root hash in the interface db.Tx
|
|
func (s *Storage) GetRoot() (*merkletree.Hash, error) {
|
|
var root merkletree.Hash
|
|
|
|
if s.currentRoot != nil {
|
|
copy(root[:], s.currentRoot[:])
|
|
return &root, nil
|
|
}
|
|
|
|
item := RootItem{}
|
|
err := s.db.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", s.mtId)
|
|
if err == sql.ErrNoRows {
|
|
return nil, merkletree.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
copy(root[:], item.Key[:])
|
|
return &root, nil
|
|
}
|
|
|
|
// Iterate implements the method Iterate of the interface db.Storage
|
|
func (s *Storage) Iterate(f func([]byte, *merkletree.Node) (bool, error)) error {
|
|
items := []NodeItem{}
|
|
|
|
err := s.db.Select(&items, "SELECT * FROM mt_nodes WHERE key WHERE mt_id = $1", s.mtId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, v := range items {
|
|
k := v.Key[:]
|
|
n, err := v.Node()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cont, err := f(k, n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !cont {
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a value from a key in the interface db.Tx
|
|
func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
|
|
//fullKey := append(tx.mtId, key...)
|
|
fullKey := key
|
|
if value, ok := tx.cache.Get(fullKey); ok {
|
|
return &value, nil
|
|
}
|
|
|
|
item := NodeItem{}
|
|
err := tx.tx.Get(&item, "SELECT * FROM mt_nodes WHERE mt_id = $1 AND key = $2", tx.mtId, key)
|
|
if err == sql.ErrNoRows {
|
|
return nil, merkletree.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
node, err := item.Node()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
// Put saves a key:value into the db.Storage
|
|
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
|
|
//fullKey := append(tx.mtId, k...)
|
|
fullKey := k
|
|
tx.cache.Put(tx.mtId, fullKey, *v)
|
|
fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
|
|
return nil
|
|
}
|
|
|
|
// GetRoot retrieves a merkle tree root hash in the interface db.Tx
|
|
func (tx *StorageTx) GetRoot() (*merkletree.Hash, error) {
|
|
var root merkletree.Hash
|
|
|
|
if tx.currentRoot != nil {
|
|
copy(root[:], tx.currentRoot[:])
|
|
return &root, nil
|
|
}
|
|
|
|
item := RootItem{}
|
|
err := tx.tx.Get(&item, "SELECT * FROM mt_roots WHERE mt_id = $1", tx.mtId)
|
|
if err == sql.ErrNoRows {
|
|
return nil, merkletree.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
copy(root[:], item.Key[:])
|
|
return &root, nil
|
|
}
|
|
|
|
// SetRoot sets a hash of merkle tree root in the interface db.Tx
|
|
func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
|
|
root := &merkletree.Hash{}
|
|
copy(root[:], hash[:])
|
|
tx.currentRoot = root
|
|
return nil
|
|
}
|
|
|
|
// Add implements the method Add of the interface db.Tx
|
|
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
|
dbtx := atx.(*StorageTx)
|
|
if tx.mtId != dbtx.mtId {
|
|
return errors.New("adding StorageTx with different prefix is not implemented")
|
|
}
|
|
for _, v := range dbtx.cache {
|
|
tx.cache.Put(v.MTId, v.K, v.V)
|
|
}
|
|
// TODO: change cache to store different currentRoots for different mtIds too!
|
|
tx.currentRoot = dbtx.currentRoot
|
|
return nil
|
|
}
|
|
|
|
// Commit implements the method Commit of the interface db.Tx
|
|
func (tx *StorageTx) Commit() error {
|
|
// execute a query on the server
|
|
fmt.Printf("Commit\n")
|
|
for _, v := range tx.cache {
|
|
fmt.Printf("key %x, value %+v\n", v.K, v.V)
|
|
node := v.V
|
|
|
|
var childL []byte
|
|
if node.ChildL != nil {
|
|
childL = append(childL, node.ChildL[:]...)
|
|
}
|
|
|
|
var childR []byte
|
|
if node.ChildR != nil {
|
|
childR = append(childR, node.ChildR[:]...)
|
|
}
|
|
|
|
var entry []byte
|
|
if node.Entry[0] != nil && node.Entry[1] != nil {
|
|
entry = append(node.Entry[0][:], node.Entry[1][:]...)
|
|
}
|
|
|
|
key, err := node.Key()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if tx.currentRoot == nil {
|
|
tx.currentRoot = &merkletree.Hash{}
|
|
}
|
|
_, err := tx.tx.Exec(updateRootStmt, tx.mtId, tx.currentRoot[:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx.cache = nil
|
|
return tx.tx.Commit()
|
|
}
|
|
|
|
// Close implements the method Close of the interface db.Tx
|
|
func (tx *StorageTx) Close() {
|
|
tx.tx.Rollback()
|
|
tx.cache = nil
|
|
}
|
|
|
|
// Close implements the method Close of the interface db.Storage
|
|
func (s *Storage) Close() {
|
|
err := s.db.Close()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// List implements the method List of the interface db.Storage
|
|
func (s *Storage) List(limit int) ([]merkletree.KV, error) {
|
|
ret := []merkletree.KV{}
|
|
err := s.Iterate(func(key []byte, value *merkletree.Node) (bool, error) {
|
|
ret = append(ret, merkletree.KV{K: merkletree.Clone(key), V: *value})
|
|
if len(ret) == limit {
|
|
return false, nil
|
|
}
|
|
return true, nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
func (item *NodeItem) Node() (*merkletree.Node, error) {
|
|
node := merkletree.Node{
|
|
Type: merkletree.NodeType(item.Type),
|
|
}
|
|
if item.ChildL != nil {
|
|
node.ChildL = &merkletree.Hash{}
|
|
copy(node.ChildL[:], item.ChildL[:])
|
|
}
|
|
if item.ChildR != nil {
|
|
node.ChildR = &merkletree.Hash{}
|
|
copy(node.ChildR[:], item.ChildR[:])
|
|
}
|
|
if len(item.Entry) > 0 {
|
|
if len(item.Entry) != 2*merkletree.ElemBytesLen {
|
|
return nil, merkletree.ErrNodeBytesBadSize
|
|
}
|
|
node.Entry = [2]*merkletree.Hash{{}, {}}
|
|
copy(node.Entry[0][:], item.Entry[0:32])
|
|
copy(node.Entry[1][:], item.Entry[32:64])
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
// KV contains a key (K) and a value (V)
|
|
type KV struct {
|
|
MTId uint64
|
|
K []byte
|
|
V merkletree.Node
|
|
}
|
|
|
|
// KvMap is a key-value map between a sha256 byte array hash, and a KV struct
|
|
type KvMap map[[sha256.Size]byte]KV
|
|
|
|
// Get retrieves the value respective to a key from the KvMap
|
|
func (m KvMap) Get(k []byte) (merkletree.Node, bool) {
|
|
v, ok := m[sha256.Sum256(k)]
|
|
return v.V, ok
|
|
}
|
|
|
|
// Put stores a key and a value in the KvMap
|
|
func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) {
|
|
m[sha256.Sum256(k)] = KV{mtId, k, v}
|
|
}
|