diff --git a/db/db.go b/db/db.go index 6faccab..3d20511 100644 --- a/db/db.go +++ b/db/db.go @@ -27,9 +27,14 @@ type Storage interface { // the merkletree storage. Examples of the interface implementation can be // found at db/memory and db/leveldb directories. type Tx interface { + // Get retreives the value for the given key + // looking first in the content of the Tx, and + // then into the content of the Storage Get([]byte) ([]byte, error) - Put(k, v []byte) - Add(Tx) + // Put sets the key & value into the Tx + Put(k, v []byte) error + // Add adds the given Tx into the Tx + Add(Tx) error Commit() error Close() } diff --git a/db/leveldb/leveldb.go b/db/leveldb/leveldb.go index f5d9cf6..91738f3 100644 --- a/db/leveldb/leveldb.go +++ b/db/leveldb/leveldb.go @@ -130,16 +130,18 @@ func (tx *LevelDbStorageTx) Get(key []byte) ([]byte, error) { } // Put saves a key:value into the db.Storage -func (tx *LevelDbStorageTx) Put(k, v []byte) { +func (tx *LevelDbStorageTx) Put(k, v []byte) error { tx.cache.Put(db.Concat(tx.prefix, k[:]), v) + return nil } // Add implements the method Add of the interface db.Tx -func (tx *LevelDbStorageTx) Add(atx db.Tx) { +func (tx *LevelDbStorageTx) Add(atx db.Tx) error { ldbtx := atx.(*LevelDbStorageTx) for _, v := range ldbtx.cache { tx.cache.Put(v.K, v.V) } + return nil } // Commit implements the method Commit of the interface db.Tx diff --git a/db/memory/memory.go b/db/memory/memory.go index f90a4ff..7e6d020 100644 --- a/db/memory/memory.go +++ b/db/memory/memory.go @@ -83,8 +83,9 @@ func (tx *MemoryStorageTx) Get(key []byte) ([]byte, error) { } // Put implements the method Put of the interface db.Tx -func (tx *MemoryStorageTx) Put(k, v []byte) { +func (tx *MemoryStorageTx) Put(k, v []byte) error { tx.kv.Put(db.Concat(tx.s.prefix, k), v) + return nil } // Commit implements the method Commit of the interface db.Tx @@ -97,11 +98,12 @@ func (tx *MemoryStorageTx) Commit() error { } // Add implements the method Add of the interface db.Tx -func (tx *MemoryStorageTx) Add(atx db.Tx) { +func (tx *MemoryStorageTx) Add(atx db.Tx) error { mstx := atx.(*MemoryStorageTx) for _, v := range mstx.kv { tx.kv.Put(v.K, v.V) } + return nil } // Close implements the method Close of the interface db.Tx diff --git a/db/pebble/pebble.go b/db/pebble/pebble.go index 8aec675..8babc34 100644 --- a/db/pebble/pebble.go +++ b/db/pebble/pebble.go @@ -17,10 +17,8 @@ type PebbleStorage struct { // PebbleStorageTx implements the db.Tx interface type PebbleStorageTx struct { - // FUTURE currently Tx is using the same strategy than in MemoryDB and - // LevelDB, in next iteration can be moved to Pebble Batch strategy *PebbleStorage - cache db.KvMap + batch *pebble.Batch } // NewPebbleStorage returns a new PebbleStorage @@ -72,7 +70,7 @@ func (p *PebbleStorage) WithPrefix(prefix []byte) db.Storage { // NewTx implements the method NewTx of the interface db.Storage func (p *PebbleStorage) NewTx() (db.Tx, error) { - return &PebbleStorageTx{p, make(db.KvMap)}, nil + return &PebbleStorageTx{p, p.pdb.NewIndexedBatch()}, nil } // Get retreives a value from a key in the db.Storage @@ -124,46 +122,34 @@ func (tx *PebbleStorageTx) Get(key []byte) ([]byte, error) { fullkey := db.Concat(tx.prefix, key) - if value, ok := tx.cache.Get(fullkey); ok { - return value, nil - } - - value, closer, err := tx.pdb.Get(fullkey) + v, closer, err := tx.batch.Get(fullkey) if err == pebble.ErrNotFound { return nil, db.ErrNotFound } closer.Close() - return value, err + return v, err } // Put saves a key:value into the db.Storage -func (tx *PebbleStorageTx) Put(k, v []byte) { - tx.cache.Put(db.Concat(tx.prefix, k[:]), v) +func (tx *PebbleStorageTx) Put(k, v []byte) error { + return tx.batch.Set(db.Concat(tx.prefix, k[:]), v, nil) } // Add implements the method Add of the interface db.Tx -func (tx *PebbleStorageTx) Add(atx db.Tx) { - ldbtx := atx.(*PebbleStorageTx) - for _, v := range ldbtx.cache { - tx.cache.Put(v.K, v.V) - } +func (tx *PebbleStorageTx) Add(atx db.Tx) error { + patx := atx.(*PebbleStorageTx) + return tx.batch.Apply(patx.batch, nil) } // Commit implements the method Commit of the interface db.Tx func (tx *PebbleStorageTx) Commit() error { - batch := tx.PebbleStorage.pdb.NewBatch() - for _, v := range tx.cache { - _ = batch.Set(v.K, v.V, nil) - } - - tx.cache = nil - return batch.Commit(nil) + return tx.batch.Commit(nil) } // Close implements the method Close of the interface db.Tx func (tx *PebbleStorageTx) Close() { - tx.cache = nil + _ = tx.batch.Close() } // Close implements the method Close of the interface db.Storage diff --git a/db/test/test.go b/db/test/test.go index 2587d98..d6ab7c8 100644 --- a/db/test/test.go +++ b/db/test/test.go @@ -28,7 +28,8 @@ func TestStorageInsertGet(t *testing.T, sto db.Storage) { tx, err := sto.NewTx() assert.Nil(t, err) - tx.Put(key, value) + err = tx.Put(key, value) + assert.Nil(t, err) v, err := tx.Get(key) assert.Nil(t, err) assert.Equal(t, value, v) @@ -53,7 +54,8 @@ func TestStorageWithPrefix(t *testing.T, sto db.Storage) { sto1tx, err := sto1.NewTx() assert.Nil(t, err) - sto1tx.Put(k, []byte{4, 5, 6}) + err = sto1tx.Put(k, []byte{4, 5, 6}) + assert.Nil(t, err) v1, err := sto1tx.Get(k) assert.Nil(t, err) assert.Equal(t, v1, []byte{4, 5, 6}) @@ -61,7 +63,8 @@ func TestStorageWithPrefix(t *testing.T, sto db.Storage) { sto2tx, err := sto2.NewTx() assert.Nil(t, err) - sto2tx.Put(k, []byte{8, 9}) + err = sto2tx.Put(k, []byte{8, 9}) + assert.Nil(t, err) v2, err := sto2tx.Get(k) assert.Nil(t, err) assert.Equal(t, v2, []byte{8, 9}) @@ -93,16 +96,22 @@ func TestIterate(t *testing.T, sto db.Storage) { assert.Equal(t, 0, len(r)) sto1tx, _ := sto1.NewTx() - sto1tx.Put([]byte{1}, []byte{4}) - sto1tx.Put([]byte{2}, []byte{5}) - sto1tx.Put([]byte{3}, []byte{6}) + err = sto1tx.Put([]byte{1}, []byte{4}) + assert.Nil(t, err) + err = sto1tx.Put([]byte{2}, []byte{5}) + assert.Nil(t, err) + err = sto1tx.Put([]byte{3}, []byte{6}) + assert.Nil(t, err) assert.Nil(t, sto1tx.Commit()) sto2 := sto.WithPrefix([]byte{2}) sto2tx, _ := sto2.NewTx() - sto2tx.Put([]byte{1}, []byte{7}) - sto2tx.Put([]byte{2}, []byte{8}) - sto2tx.Put([]byte{3}, []byte{9}) + err = sto2tx.Put([]byte{1}, []byte{7}) + assert.Nil(t, err) + err = sto2tx.Put([]byte{2}, []byte{8}) + assert.Nil(t, err) + err = sto2tx.Put([]byte{3}, []byte{9}) + assert.Nil(t, err) assert.Nil(t, sto2tx.Commit()) r = []db.KV{} @@ -136,14 +145,17 @@ func TestConcatTx(t *testing.T, sto db.Storage) { if err != nil { panic(err) } - sto1tx.Put(k, []byte{4, 5, 6}) + err = sto1tx.Put(k, []byte{4, 5, 6}) + assert.Nil(t, err) sto2tx, err := sto2.NewTx() if err != nil { panic(err) } - sto2tx.Put(k, []byte{8, 9}) + err = sto2tx.Put(k, []byte{8, 9}) + assert.Nil(t, err) - sto1tx.Add(sto2tx) + err = sto1tx.Add(sto2tx) + assert.Nil(t, err) assert.Nil(t, sto1tx.Commit()) // check outside tx @@ -166,16 +178,22 @@ func TestList(t *testing.T, sto db.Storage) { assert.Equal(t, 0, len(r1)) sto1tx, _ := sto1.NewTx() - sto1tx.Put([]byte{1}, []byte{4}) - sto1tx.Put([]byte{2}, []byte{5}) - sto1tx.Put([]byte{3}, []byte{6}) + err = sto1tx.Put([]byte{1}, []byte{4}) + assert.Nil(t, err) + err = sto1tx.Put([]byte{2}, []byte{5}) + assert.Nil(t, err) + err = sto1tx.Put([]byte{3}, []byte{6}) + assert.Nil(t, err) assert.Nil(t, sto1tx.Commit()) sto2 := sto.WithPrefix([]byte{2}) sto2tx, _ := sto2.NewTx() - sto2tx.Put([]byte{1}, []byte{7}) - sto2tx.Put([]byte{2}, []byte{8}) - sto2tx.Put([]byte{3}, []byte{9}) + err = sto2tx.Put([]byte{1}, []byte{7}) + assert.Nil(t, err) + err = sto2tx.Put([]byte{2}, []byte{8}) + assert.Nil(t, err) + err = sto2tx.Put([]byte{3}, []byte{9}) + assert.Nil(t, err) assert.Nil(t, sto2tx.Commit()) r, err := sto1.List(100) diff --git a/merkletree.go b/merkletree.go index f81ba19..96c3ed0 100644 --- a/merkletree.go +++ b/merkletree.go @@ -141,7 +141,10 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { return nil, err } mt.rootKey = &HashZero - tx.Put(rootNodeValue, mt.rootKey[:]) + err = tx.Put(rootNodeValue, mt.rootKey[:]) + if err != nil { + return nil, err + } err = tx.Commit() if err != nil { return nil, err @@ -212,7 +215,10 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return err } mt.rootKey = newRootKey - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return err + } if err := tx.Commit(); err != nil { return err @@ -363,8 +369,8 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) { if _, err := tx.Get(k[:]); err == nil { return nil, ErrNodeKeyAlreadyExists } - tx.Put(k[:], v) - return k, nil + err = tx.Put(k[:], v) + return k, err } // updateNode updates an existing node in the MT. Empty nodes are not stored @@ -382,8 +388,8 @@ func (mt *MerkleTree) updateNode(tx db.Tx, n *Node) (*Hash, error) { return nil, err } v := n.Value() - tx.Put(k[:], v) - return k, nil + err = tx.Put(k[:], v) + return k, err } // Get returns the value of the leaf for the given key @@ -487,7 +493,10 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) { return nil, err } mt.rootKey = newRootKey - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return nil, err + } cp.NewRoot = newRootKey if err := tx.Commit(); err != nil { return nil, err @@ -580,14 +589,20 @@ func (mt *MerkleTree) Delete(k *big.Int) error { func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error { if len(siblings) == 0 { mt.rootKey = &HashZero - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return err + } return tx.Commit() } toUpload := siblings[len(siblings)-1] if len(siblings) < 2 { //nolint:gomnd mt.rootKey = siblings[0] - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return err + } return tx.Commit() } for i := len(siblings) - 2; i >= 0; i-- { //nolint:gomnd @@ -608,13 +623,19 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [ return err } mt.rootKey = newRootKey - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return err + } break } // if i==0 (root position), stop and store the sibling of the deleted leaf as root if i == 0 { mt.rootKey = toUpload - mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + if err != nil { + return err + } break } } @@ -650,9 +671,9 @@ func (mt *MerkleTree) recalculatePathUntilRoot(tx db.Tx, path []bool, node *Node // dbInsert is a helper function to insert a node into a key in an open db // transaction. -func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) { +func (mt *MerkleTree) dbInsert(tx db.Tx, k []byte, t NodeType, data []byte) error { v := append([]byte{byte(t)}, data...) - tx.Put(k, v) + return tx.Put(k, v) } // GetNode gets a node by key from the MT. Empty nodes are not stored in the