Fixed sql tx close, fixed unit tests. Refactoring. Added missing structs and methods.

This commit is contained in:
Oleksandr Brezhniev
2021-06-25 23:34:20 +03:00
parent cadeb222c6
commit 113995d6f4
13 changed files with 1531 additions and 1044 deletions

View File

@@ -1,6 +1,7 @@
package sql
import (
"crypto/sha256"
"database/sql"
"encoding/binary"
"errors"
@@ -29,7 +30,7 @@ type Storage struct {
type StorageTx struct {
*Storage
tx *sqlx.Tx
cache merkletree.KvMap
cache KvMap
currentRoot *merkletree.Hash
}
@@ -74,7 +75,7 @@ func (s *Storage) NewTx() (merkletree.Tx, error) {
if err != nil {
return nil, err
}
return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil
return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil
}
// Get retrieves a value from a key in the db.Storage
@@ -167,7 +168,7 @@ func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
//fullKey := append(tx.mtId, k...)
fullKey := k
tx.cache.Put(fullKey, *v)
tx.cache.Put(tx.mtId, fullKey, *v)
fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
return nil
}
@@ -204,17 +205,13 @@ func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
// Add implements the method Add of the interface db.Tx
func (tx *StorageTx) Add(atx merkletree.Tx) error {
dbtx := atx.(*StorageTx)
//if !bytes.Equal(tx.prefix, dbtx.prefix) {
// // TODO: change cache to store prefix too!
// return errors.New("adding StorageTx with different prefix is not implemented")
//}
if tx.mtId != dbtx.mtId {
// TODO: change cache to store prefix too!
return errors.New("adding StorageTx with different prefix is not implemented")
}
for _, v := range dbtx.cache {
tx.cache.Put(v.K, v.V)
tx.cache.Put(v.MTId, v.K, v.V)
}
// TODO: change cache to store different currentRoots for different mtIds too!
tx.currentRoot = dbtx.currentRoot
return nil
}
@@ -246,7 +243,7 @@ func (tx *StorageTx) Commit() error {
if err != nil {
return err
}
_, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry)
_, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry)
if err != nil {
return err
}
@@ -266,7 +263,7 @@ func (tx *StorageTx) Commit() error {
// Close implements the method Close of the interface db.Tx
func (tx *StorageTx) Close() {
//tx.tx.Rollback()
tx.tx.Rollback()
tx.cache = nil
}
@@ -313,3 +310,24 @@ func (item *NodeItem) Node() (*merkletree.Node, error) {
}
return &node, nil
}
// KV contains a key (K) and a value (V)
type KV struct {
MTId uint64
K []byte
V merkletree.Node
}
// KvMap is a key-value map between a sha256 byte array hash, and a KV struct
type KvMap map[[sha256.Size]byte]KV
// Get retrieves the value respective to a key from the KvMap
func (m KvMap) Get(k []byte) (merkletree.Node, bool) {
v, ok := m[sha256.Sum256(k)]
return v.V, ok
}
// Put stores a key and a value in the KvMap
func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) {
m[sha256.Sum256(k)] = KV{mtId, k, v}
}

View File

@@ -4,11 +4,13 @@ import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/iden3/go-iden3-crypto/constants"
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
"github.com/iden3/go-merkletree"
"github.com/iden3/go-merkletree/db/memory"
"github.com/iden3/go-merkletree/db/test"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -18,7 +20,11 @@ import (
"testing"
)
func sqlStorage(t *testing.T) merkletree.Storage {
var maxMTId uint64 = 0
var cleared = false
func setupDB() (*sqlx.DB, error) {
var err error
host := os.Getenv("PGHOST")
if host == "" {
host = "localhost"
@@ -33,7 +39,7 @@ func sqlStorage(t *testing.T) merkletree.Storage {
}
password := os.Getenv("PGPASSWORD")
if password == "" {
panic("No PGPASSWORD envvar specified")
return nil, errors.New("No PGPASSWORD envvar specified")
}
dbname := os.Getenv("PGDATABASE")
if dbname == "" {
@@ -50,19 +56,34 @@ func sqlStorage(t *testing.T) merkletree.Storage {
)
dbx, err := sqlx.Connect("postgres", psqlconn)
if err != nil {
t.Fatal(err)
return nil
return nil, err
}
// clear MerkleTree table
//if !cleared {
dbx.Exec("TRUNCATE TABLE mt_roots")
dbx.Exec("TRUNCATE TABLE mt_nodes")
cleared = true
//}
return dbx, nil
}
func sqlStorage(t *testing.T) merkletree.Storage {
dbx, err := setupDB()
if err != nil {
t.Fatal(err)
return nil
}
sto, err := NewSqlStorage(dbx, false)
if err != nil {
t.Fatal(err)
return nil
}
sto.mtId = maxMTId
maxMTId++
t.Cleanup(func() {
})
@@ -70,26 +91,60 @@ func sqlStorage(t *testing.T) merkletree.Storage {
return sto
}
func TestReturnKnownErrIfNotExists(t *testing.T) {
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
}
func TestStorageInsertGet(t *testing.T) {
test.TestStorageInsertGet(t, sqlStorage(t))
}
func TestStorageWithPrefix(t *testing.T) {
test.TestStorageWithPrefix(t, sqlStorage(t))
}
func TestSql(t *testing.T) {
//sto := sqlStorage(t)
//t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
// test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
//})
//t.Run("TestStorageInsertGet", func(t *testing.T) {
// test.TestStorageInsertGet(t, sqlStorage(t))
//})
//test.TestStorageWithPrefix(t, sqlStorage(t))
//test.TestConcatTx(t, sqlStorage(t))
//test.TestList(t, sqlStorage(t))
//test.TestIterate(t, sqlStorage(t))
t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
})
t.Run("TestStorageInsertGet", func(t *testing.T) {
test.TestStorageInsertGet(t, sqlStorage(t))
})
t.Run("TestStorageWithPrefix", func(t *testing.T) {
test.TestStorageWithPrefix(t, sqlStorage(t))
})
test.TestConcatTx(t, sqlStorage(t))
test.TestList(t, sqlStorage(t))
test.TestIterate(t, sqlStorage(t))
test.TestNewTree(t, sqlStorage(t))
test.TestAddDifferentOrder(t, sqlStorage(t), sqlStorage(t))
test.TestAddRepeatedIndex(t, sqlStorage(t))
test.TestGet(t, sqlStorage(t))
test.TestUpdate(t, sqlStorage(t))
test.TestUpdate2(t, sqlStorage(t))
test.TestGenerateAndVerifyProof128(t, sqlStorage(t))
test.TestTreeLimit(t, sqlStorage(t))
test.TestSiblingsFromProof(t, sqlStorage(t))
test.TestVerifyProofCases(t, sqlStorage(t))
test.TestVerifyProofFalse(t, sqlStorage(t))
test.TestGraphViz(t, sqlStorage(t))
test.TestDelete(t, sqlStorage(t))
test.TestDelete2(t, sqlStorage(t), sqlStorage(t))
test.TestDelete3(t, sqlStorage(t), sqlStorage(t))
test.TestDelete4(t, sqlStorage(t), sqlStorage(t))
test.TestDelete5(t, sqlStorage(t), sqlStorage(t))
test.TestDeleteNonExistingKeys(t, sqlStorage(t))
test.TestDumpLeafsImportLeafs(t, sqlStorage(t), sqlStorage(t))
test.TestAddAndGetCircomProof(t, sqlStorage(t))
test.TestUpdateCircomProcessorProof(t, sqlStorage(t))
test.TestSmtVerifier(t, sqlStorage(t))
test.TestTypesMarshalers(t, sqlStorage(t))
}
var debug = false
type Fatalable interface {
Fatal(args ...interface{})
}
func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
sto := sqlStorage(f)