mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 19:46:43 +01:00
Fixed sql tx close, fixed unit tests. Refactoring. Added missing structs and methods.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
@@ -29,7 +30,7 @@ type Storage struct {
|
||||
type StorageTx struct {
|
||||
*Storage
|
||||
tx *sqlx.Tx
|
||||
cache merkletree.KvMap
|
||||
cache KvMap
|
||||
currentRoot *merkletree.Hash
|
||||
}
|
||||
|
||||
@@ -74,7 +75,7 @@ func (s *Storage) NewTx() (merkletree.Tx, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StorageTx{s, tx, make(merkletree.KvMap), s.currentRoot}, nil
|
||||
return &StorageTx{s, tx, make(KvMap), s.currentRoot}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from a key in the db.Storage
|
||||
@@ -167,7 +168,7 @@ func (tx *StorageTx) Get(key []byte) (*merkletree.Node, error) {
|
||||
func (tx *StorageTx) Put(k []byte, v *merkletree.Node) error {
|
||||
//fullKey := append(tx.mtId, k...)
|
||||
fullKey := k
|
||||
tx.cache.Put(fullKey, *v)
|
||||
tx.cache.Put(tx.mtId, fullKey, *v)
|
||||
fmt.Printf("tx.Put(%x, %+v)\n", fullKey, v)
|
||||
return nil
|
||||
}
|
||||
@@ -204,17 +205,13 @@ func (tx *StorageTx) SetRoot(hash *merkletree.Hash) error {
|
||||
// Add implements the method Add of the interface db.Tx
|
||||
func (tx *StorageTx) Add(atx merkletree.Tx) error {
|
||||
dbtx := atx.(*StorageTx)
|
||||
//if !bytes.Equal(tx.prefix, dbtx.prefix) {
|
||||
// // TODO: change cache to store prefix too!
|
||||
// return errors.New("adding StorageTx with different prefix is not implemented")
|
||||
//}
|
||||
if tx.mtId != dbtx.mtId {
|
||||
// TODO: change cache to store prefix too!
|
||||
return errors.New("adding StorageTx with different prefix is not implemented")
|
||||
}
|
||||
for _, v := range dbtx.cache {
|
||||
tx.cache.Put(v.K, v.V)
|
||||
tx.cache.Put(v.MTId, v.K, v.V)
|
||||
}
|
||||
// TODO: change cache to store different currentRoots for different mtIds too!
|
||||
tx.currentRoot = dbtx.currentRoot
|
||||
return nil
|
||||
}
|
||||
@@ -246,7 +243,7 @@ func (tx *StorageTx) Commit() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.tx.Exec(upsertStmt, tx.mtId, key[:], node.Type, childL, childR, entry)
|
||||
_, err = tx.tx.Exec(upsertStmt, v.MTId, key[:], node.Type, childL, childR, entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -266,7 +263,7 @@ func (tx *StorageTx) Commit() error {
|
||||
|
||||
// Close implements the method Close of the interface db.Tx
|
||||
func (tx *StorageTx) Close() {
|
||||
//tx.tx.Rollback()
|
||||
tx.tx.Rollback()
|
||||
tx.cache = nil
|
||||
}
|
||||
|
||||
@@ -313,3 +310,24 @@ func (item *NodeItem) Node() (*merkletree.Node, error) {
|
||||
}
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
// KV contains a key (K) and a value (V)
|
||||
type KV struct {
|
||||
MTId uint64
|
||||
K []byte
|
||||
V merkletree.Node
|
||||
}
|
||||
|
||||
// KvMap is a key-value map between a sha256 byte array hash, and a KV struct
|
||||
type KvMap map[[sha256.Size]byte]KV
|
||||
|
||||
// Get retrieves the value respective to a key from the KvMap
|
||||
func (m KvMap) Get(k []byte) (merkletree.Node, bool) {
|
||||
v, ok := m[sha256.Sum256(k)]
|
||||
return v.V, ok
|
||||
}
|
||||
|
||||
// Put stores a key and a value in the KvMap
|
||||
func (m KvMap) Put(mtId uint64, k []byte, v merkletree.Node) {
|
||||
m[sha256.Sum256(k)] = KV{mtId, k, v}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/iden3/go-iden3-crypto/constants"
|
||||
cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
|
||||
"github.com/iden3/go-merkletree"
|
||||
"github.com/iden3/go-merkletree/db/memory"
|
||||
"github.com/iden3/go-merkletree/db/test"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -18,7 +20,11 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
var maxMTId uint64 = 0
|
||||
var cleared = false
|
||||
|
||||
func setupDB() (*sqlx.DB, error) {
|
||||
var err error
|
||||
host := os.Getenv("PGHOST")
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
@@ -33,7 +39,7 @@ func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
}
|
||||
password := os.Getenv("PGPASSWORD")
|
||||
if password == "" {
|
||||
panic("No PGPASSWORD envvar specified")
|
||||
return nil, errors.New("No PGPASSWORD envvar specified")
|
||||
}
|
||||
dbname := os.Getenv("PGDATABASE")
|
||||
if dbname == "" {
|
||||
@@ -50,19 +56,34 @@ func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
)
|
||||
dbx, err := sqlx.Connect("postgres", psqlconn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// clear MerkleTree table
|
||||
//if !cleared {
|
||||
dbx.Exec("TRUNCATE TABLE mt_roots")
|
||||
dbx.Exec("TRUNCATE TABLE mt_nodes")
|
||||
cleared = true
|
||||
//}
|
||||
|
||||
return dbx, nil
|
||||
}
|
||||
|
||||
func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
|
||||
dbx, err := setupDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
sto, err := NewSqlStorage(dbx, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
}
|
||||
sto.mtId = maxMTId
|
||||
maxMTId++
|
||||
|
||||
t.Cleanup(func() {
|
||||
})
|
||||
@@ -70,26 +91,60 @@ func sqlStorage(t *testing.T) merkletree.Storage {
|
||||
return sto
|
||||
}
|
||||
|
||||
func TestReturnKnownErrIfNotExists(t *testing.T) {
|
||||
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
|
||||
}
|
||||
|
||||
func TestStorageInsertGet(t *testing.T) {
|
||||
test.TestStorageInsertGet(t, sqlStorage(t))
|
||||
}
|
||||
|
||||
func TestStorageWithPrefix(t *testing.T) {
|
||||
test.TestStorageWithPrefix(t, sqlStorage(t))
|
||||
}
|
||||
|
||||
func TestSql(t *testing.T) {
|
||||
//sto := sqlStorage(t)
|
||||
//t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
|
||||
// test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
|
||||
//})
|
||||
//t.Run("TestStorageInsertGet", func(t *testing.T) {
|
||||
// test.TestStorageInsertGet(t, sqlStorage(t))
|
||||
//})
|
||||
//test.TestStorageWithPrefix(t, sqlStorage(t))
|
||||
//test.TestConcatTx(t, sqlStorage(t))
|
||||
//test.TestList(t, sqlStorage(t))
|
||||
//test.TestIterate(t, sqlStorage(t))
|
||||
t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
|
||||
test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
|
||||
})
|
||||
t.Run("TestStorageInsertGet", func(t *testing.T) {
|
||||
test.TestStorageInsertGet(t, sqlStorage(t))
|
||||
})
|
||||
t.Run("TestStorageWithPrefix", func(t *testing.T) {
|
||||
test.TestStorageWithPrefix(t, sqlStorage(t))
|
||||
})
|
||||
test.TestConcatTx(t, sqlStorage(t))
|
||||
test.TestList(t, sqlStorage(t))
|
||||
test.TestIterate(t, sqlStorage(t))
|
||||
|
||||
test.TestNewTree(t, sqlStorage(t))
|
||||
test.TestAddDifferentOrder(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestAddRepeatedIndex(t, sqlStorage(t))
|
||||
test.TestGet(t, sqlStorage(t))
|
||||
test.TestUpdate(t, sqlStorage(t))
|
||||
test.TestUpdate2(t, sqlStorage(t))
|
||||
test.TestGenerateAndVerifyProof128(t, sqlStorage(t))
|
||||
test.TestTreeLimit(t, sqlStorage(t))
|
||||
test.TestSiblingsFromProof(t, sqlStorage(t))
|
||||
test.TestVerifyProofCases(t, sqlStorage(t))
|
||||
test.TestVerifyProofFalse(t, sqlStorage(t))
|
||||
test.TestGraphViz(t, sqlStorage(t))
|
||||
test.TestDelete(t, sqlStorage(t))
|
||||
test.TestDelete2(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestDelete3(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestDelete4(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestDelete5(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestDeleteNonExistingKeys(t, sqlStorage(t))
|
||||
test.TestDumpLeafsImportLeafs(t, sqlStorage(t), sqlStorage(t))
|
||||
test.TestAddAndGetCircomProof(t, sqlStorage(t))
|
||||
test.TestUpdateCircomProcessorProof(t, sqlStorage(t))
|
||||
test.TestSmtVerifier(t, sqlStorage(t))
|
||||
test.TestTypesMarshalers(t, sqlStorage(t))
|
||||
}
|
||||
|
||||
var debug = false
|
||||
|
||||
type Fatalable interface {
|
||||
Fatal(args ...interface{})
|
||||
}
|
||||
|
||||
func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
|
||||
sto := sqlStorage(f)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user