You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

384 lines
6.8 KiB

/**
* @file
* @copyright defined in aergo/LICENSE.txt
*/
package db
import (
"bytes"
"container/list"
"encoding/gob"
"os"
"path"
"sort"
"sync"
)
// This function is always called first
func init() {
dbConstructor := func(dir string) (DB, error) {
return newMemoryDB(dir)
}
registorDBConstructor(MemoryImpl, dbConstructor)
}
// newMemoryDB creates a map-backed DB whose contents are persisted to the
// file "database" inside dir.
//
// If that file can be opened, its gob-encoded map is loaded as the initial
// contents, and a decode failure is returned as an error. If the file cannot
// be opened (for example, because it does not exist yet), the database starts
// empty. memorydb.Close writes the contents back to the same path.
func newMemoryDB(dir string) (DB, error) {
	var db map[string][]byte

	filePath := path.Join(dir, "database")
	if file, err := os.Open(filePath); err == nil {
		// Close on every path, including a decode failure. Previously the
		// error return skipped the Close and leaked the descriptor.
		defer file.Close()
		if err := gob.NewDecoder(file).Decode(&db); err != nil {
			return nil, err
		}
	}

	// No persisted file, or it decoded to a nil map: start empty.
	if db == nil {
		db = make(map[string][]byte)
	}

	return &memorydb{
		db:  db,
		dir: filePath,
	}, nil
}
//=========================================================
// DB Implementation
//=========================================================
// Enforce database and transaction implements interfaces
// Compile-time check that *memorydb satisfies the DB interface.
var _ DB = (*memorydb)(nil)
// memorydb is a DB whose contents live in a Go map. newMemoryDB loads the
// map from a gob-encoded file and Close writes it back to that file.
type memorydb struct {
// lock guards db; the methods that read or write db hold it while doing so.
lock sync.Mutex
// db maps string-converted keys to their stored values.
db map[string][]byte
// dir holds the full path of the persistence file (not a directory),
// despite its name.
dir string
}
// Type returns the name of this backend, "memorydb".
func (db *memorydb) Type() string { return "memorydb" }
// Set stores value under key, replacing any previous value. Both key and
// value are first passed through convNilToBytes. The value slice is stored
// as given, not copied.
func (db *memorydb) Set(key, value []byte) {
	db.lock.Lock()
	defer db.lock.Unlock()

	db.db[string(convNilToBytes(key))] = convNilToBytes(value)
}
// Delete removes key (after convNilToBytes) from the database. Deleting a
// key that is not present is a no-op.
func (db *memorydb) Delete(key []byte) {
	db.lock.Lock()
	defer db.lock.Unlock()

	delete(db.db, string(convNilToBytes(key)))
}
// Get returns the value stored under key, or nil when the key is absent.
// The returned slice is the stored one, not a copy.
func (db *memorydb) Get(key []byte) []byte {
	db.lock.Lock()
	defer db.lock.Unlock()

	stored := db.db[string(convNilToBytes(key))]
	return stored
}
// Exist reports whether key is present, even if its stored value is empty.
func (db *memorydb) Exist(key []byte) bool {
	db.lock.Lock()
	defer db.lock.Unlock()

	_, found := db.db[string(convNilToBytes(key))]
	return found
}
// Close writes the whole map, gob-encoded, to the persistence file and
// truncates any previous contents. Persistence is best-effort: Close has no
// error return, so failures to open, encode, or close the file are dropped.
func (db *memorydb) Close() {
	db.lock.Lock()
	defer db.lock.Unlock()

	out, err := os.OpenFile(db.dir, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		return
	}
	defer out.Close()

	_ = gob.NewEncoder(out).Encode(db.db)
}
// NewTx returns a transaction that queues Set/Delete calls and applies them
// to db only when Commit is called. Its discard/commit flags start false.
func (db *memorydb) NewTx() Transaction {
	tx := &memoryTransaction{db: db, opList: list.New()}
	return tx
}
// NewBulk returns a bulk writer that queues Set/Delete calls and applies
// them to db only when Flush is called. Its discard/commit flags start false.
func (db *memorydb) NewBulk() Bulk {
	bulk := &memoryBulk{db: db, opList: list.New()}
	return bulk
}
//=========================================================
// Transaction Implementation
//=========================================================
// memoryTransaction queues Set/Delete operations in opList and applies them
// to db in a single locked pass on Commit.
type memoryTransaction struct {
// txLock guards this transaction's own fields (not the underlying db).
txLock sync.Mutex
// db is the database that Commit writes into.
db *memorydb
// opList holds *txOp entries in the order they were queued.
opList *list.List
// isDiscard is set by Discard; Commit panics when it is true.
isDiscard bool
// isCommit is set after Commit succeeds; a second Commit panics.
isCommit bool
}
// txOp is one queued write or delete. Keep the field order stable:
// positional literals such as &txOp{true, key, value} depend on it.
type txOp struct {
// isSet is true for a Set and false for a Delete.
isSet bool
// key is the target key.
key []byte
// value is the value to store; nil for a Delete.
value []byte
}
// Set queues a write of value under key. Nothing reaches the database until
// Commit. Key and value are normalized by convNilToBytes.
func (transaction *memoryTransaction) Set(key, value []byte) {
	transaction.txLock.Lock()
	defer transaction.txLock.Unlock()

	op := &txOp{
		isSet: true,
		key:   convNilToBytes(key),
		value: convNilToBytes(value),
	}
	transaction.opList.PushBack(op)
}
// Delete queues removal of key. Nothing reaches the database until Commit.
func (transaction *memoryTransaction) Delete(key []byte) {
	transaction.txLock.Lock()
	defer transaction.txLock.Unlock()

	transaction.opList.PushBack(&txOp{isSet: false, key: convNilToBytes(key)})
}
// Commit applies every queued operation to the database in queue order,
// holding the database lock for the whole pass, and then marks the
// transaction as committed.
//
// It panics if the transaction was discarded or has already been committed.
func (transaction *memoryTransaction) Commit() {
	transaction.txLock.Lock()
	defer transaction.txLock.Unlock()

	switch {
	case transaction.isDiscard:
		panic("Commit after dicard tx is not allowed")
	case transaction.isCommit:
		panic("Commit occures two times")
	}

	target := transaction.db
	target.lock.Lock()
	defer target.lock.Unlock()

	for elem := transaction.opList.Front(); elem != nil; elem = elem.Next() {
		op := elem.Value.(*txOp)
		if !op.isSet {
			delete(target.db, string(op.key))
			continue
		}
		target.db[string(op.key)] = op.value
	}

	transaction.isCommit = true
}
// Discard marks the transaction as discarded so that a later Commit panics
// instead of applying the queued operations. It does not undo a Commit that
// already happened.
func (transaction *memoryTransaction) Discard() {
	transaction.txLock.Lock()
	transaction.isDiscard = true
	transaction.txLock.Unlock()
}
//=========================================================
// Bulk Implementation
//=========================================================
// memoryBulk queues Set/Delete operations in opList and applies them to db
// in a single locked pass on Flush.
type memoryBulk struct {
// txLock guards this bulk's own fields (not the underlying db).
txLock sync.Mutex
// db is the database that Flush writes into.
db *memorydb
// opList holds *txOp entries in the order they were queued.
opList *list.List
// isDiscard is set by DiscardLast; Flush panics when it is true.
isDiscard bool
// isCommit is set after Flush succeeds; a second Flush panics.
isCommit bool
}
// Set queues a write of value under key. Nothing reaches the database until
// Flush. Key and value are normalized by convNilToBytes.
func (bulk *memoryBulk) Set(key, value []byte) {
	bulk.txLock.Lock()
	defer bulk.txLock.Unlock()

	op := &txOp{
		isSet: true,
		key:   convNilToBytes(key),
		value: convNilToBytes(value),
	}
	bulk.opList.PushBack(op)
}
// Delete queues removal of key. Nothing reaches the database until Flush.
func (bulk *memoryBulk) Delete(key []byte) {
	bulk.txLock.Lock()
	defer bulk.txLock.Unlock()

	bulk.opList.PushBack(&txOp{isSet: false, key: convNilToBytes(key)})
}
// Flush applies every queued operation to the database in queue order,
// holding the database lock for the whole pass, and then marks the bulk as
// committed.
//
// It panics if the bulk was discarded or has already been flushed.
func (bulk *memoryBulk) Flush() {
	bulk.txLock.Lock()
	defer bulk.txLock.Unlock()

	switch {
	case bulk.isDiscard:
		panic("Commit after dicard tx is not allowed")
	case bulk.isCommit:
		panic("Commit occures two times")
	}

	target := bulk.db
	target.lock.Lock()
	defer target.lock.Unlock()

	for elem := bulk.opList.Front(); elem != nil; elem = elem.Next() {
		op := elem.Value.(*txOp)
		if !op.isSet {
			delete(target.db, string(op.key))
			continue
		}
		target.db[string(op.key)] = op.value
	}

	bulk.isCommit = true
}
// DiscardLast marks the bulk as discarded so that a later Flush panics
// instead of applying the queued operations.
// NOTE(review): despite its name, it does not remove just the last queued
// operation; it only sets isDiscard.
func (bulk *memoryBulk) DiscardLast() {
	bulk.txLock.Lock()
	bulk.isDiscard = true
	bulk.txLock.Unlock()
}
//=========================================================
// Iterator Implementation
//=========================================================
// memoryIterator walks a sorted snapshot of keys taken when the iterator was
// created. Values are looked up in db on demand.
type memoryIterator struct {
// start and end are the bounds that were passed to Iterator.
start []byte
end []byte
// reverse is true when keys are ordered descending.
reverse bool
// keys is the sorted snapshot of matching keys.
keys []string
// isInvalid, when true, makes Valid report false regardless of cursor.
isInvalid bool
// cursor is the index into keys of the current entry.
cursor int
// db is the database that Value reads from.
db *memorydb
}
// isKeyInRange reports whether key lies within the iteration bounds.
//
// Forward (reverse == false): start <= key < end. A nil end means no upper
// bound; a nil start compares as empty, which every key satisfies.
// Reverse (reverse == true): end < key <= start. A nil start means no upper
// bound and a nil end means no lower bound.
func isKeyInRange(key []byte, start []byte, end []byte, reverse bool) bool {
	if !reverse {
		return bytes.Compare(key, start) >= 0 &&
			(end == nil || bytes.Compare(key, end) < 0)
	}

	withinUpper := start == nil || bytes.Compare(key, start) <= 0
	withinLower := end == nil || bytes.Compare(key, end) > 0
	return withinUpper && withinLower
}
// Iterator returns an iterator over a sorted snapshot of the keys in range.
//
// The direction is inferred from the bounds. When start compares greater
// than end, keys in (end, start] are visited in descending order. Otherwise
// keys in [start, end) are visited in ascending order.
// NOTE(review): a non-empty start with a nil end compares greater, so it
// produces a descending scan. Confirm that callers expect this.
//
// Keys are collected and sorted under the lock. Values are fetched later
// through Get, so Value reflects the database at the time it is called.
func (db *memorydb) Iterator(start, end []byte) Iterator {
	db.lock.Lock()
	defer db.lock.Unlock()

	reverse := bytes.Compare(start, end) == 1

	var keys sort.StringSlice
	for k := range db.db {
		if isKeyInRange([]byte(k), start, end, reverse) {
			keys = append(keys, k)
		}
	}

	var order sort.Interface = keys
	if reverse {
		order = sort.Reverse(keys)
	}
	sort.Sort(order)

	return &memoryIterator{
		start:   start,
		end:     end,
		reverse: reverse,
		keys:    keys,
		db:      db,
	}
}
// Next advances the cursor to the following key. It panics if the iterator
// is not currently valid.
func (iter *memoryIterator) Next() {
	if iter.Valid() {
		iter.cursor++
		return
	}
	panic("Iterator is Invalid")
}
// Valid reports whether the cursor points at a key in the snapshot. When
// isInvalid is set, Valid returns false regardless of the cursor.
func (iter *memoryIterator) Valid() bool {
	switch {
	case iter.isInvalid:
		return false
	case iter.cursor < 0:
		return false
	default:
		return iter.cursor < len(iter.keys)
	}
}
// Key returns the key at the cursor as a newly allocated byte slice. It
// panics if the iterator is not valid.
func (iter *memoryIterator) Key() (key []byte) {
	if !iter.Valid() {
		panic("Iterator is Invalid")
	}
	key = []byte(iter.keys[iter.cursor])
	return key
}
// Value looks up the key at the cursor through the database's Get. It
// therefore returns the current stored value, which is nil if the key was
// deleted after the iterator was created. It panics if the iterator is not
// valid.
func (iter *memoryIterator) Value() (value []byte) {
	if !iter.Valid() {
		panic("Iterator is Invalid")
	}
	current := iter.keys[iter.cursor]
	return iter.db.Get([]byte(current))
}