/** * @file * @copyright defined in aergo/LICENSE.txt */ package db import ( "bytes" "container/list" "encoding/gob" "os" "path" "sort" "sync" ) // This function is always called first func init() { dbConstructor := func(dir string) (DB, error) { return newMemoryDB(dir) } registorDBConstructor(MemoryImpl, dbConstructor) } func newMemoryDB(dir string) (DB, error) { var db map[string][]byte filePath := path.Join(dir, "database") file, err := os.Open(filePath) if err == nil { decoder := gob.NewDecoder(file) // err = decoder.Decode(&db) if err != nil { return nil, err } } file.Close() if db == nil { db = make(map[string][]byte) } database := &memorydb{ db: db, dir: filePath, } return database, nil } //========================================================= // DB Implementation //========================================================= // Enforce database and transaction implements interfaces var _ DB = (*memorydb)(nil) type memorydb struct { lock sync.Mutex db map[string][]byte dir string } func (db *memorydb) Type() string { return "memorydb" } func (db *memorydb) Set(key, value []byte) { db.lock.Lock() defer db.lock.Unlock() key = convNilToBytes(key) value = convNilToBytes(value) db.db[string(key)] = value } func (db *memorydb) Delete(key []byte) { db.lock.Lock() defer db.lock.Unlock() key = convNilToBytes(key) delete(db.db, string(key)) } func (db *memorydb) Get(key []byte) []byte { db.lock.Lock() defer db.lock.Unlock() key = convNilToBytes(key) return db.db[string(key)] } func (db *memorydb) Exist(key []byte) bool { db.lock.Lock() defer db.lock.Unlock() key = convNilToBytes(key) _, ok := db.db[string(key)] return ok } func (db *memorydb) Close() { db.lock.Lock() defer db.lock.Unlock() file, err := os.OpenFile(db.dir, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) if err == nil { encoder := gob.NewEncoder(file) encoder.Encode(db.db) } file.Close() } func (db *memorydb) NewTx() Transaction { return &memoryTransaction{ db: db, opList: list.New(), isDiscard: false, isCommit: false, } } func (db *memorydb) NewBulk() Bulk { return &memoryBulk{ db: db, opList: list.New(), isDiscard: false, isCommit: false, } } //========================================================= // Transaction Implementation //========================================================= type memoryTransaction struct { txLock sync.Mutex db *memorydb opList *list.List isDiscard bool isCommit bool } type txOp struct { isSet bool key []byte value []byte } func (transaction *memoryTransaction) Set(key, value []byte) { transaction.txLock.Lock() defer transaction.txLock.Unlock() key = convNilToBytes(key) value = convNilToBytes(value) transaction.opList.PushBack(&txOp{true, key, value}) } func (transaction *memoryTransaction) Delete(key []byte) { transaction.txLock.Lock() defer transaction.txLock.Unlock() key = convNilToBytes(key) transaction.opList.PushBack(&txOp{false, key, nil}) } func (transaction *memoryTransaction) Commit() { transaction.txLock.Lock() defer transaction.txLock.Unlock() if transaction.isDiscard { panic("Commit after dicard tx is not allowed") } else if transaction.isCommit { panic("Commit occures two times") } db := transaction.db db.lock.Lock() defer db.lock.Unlock() for e := transaction.opList.Front(); e != nil; e = e.Next() { op := e.Value.(*txOp) if op.isSet { db.db[string(op.key)] = op.value } else { delete(db.db, string(op.key)) } } transaction.isCommit = true } func (transaction *memoryTransaction) Discard() { transaction.txLock.Lock() defer transaction.txLock.Unlock() transaction.isDiscard = true } //========================================================= // Bulk Implementation //========================================================= type memoryBulk struct { txLock sync.Mutex db *memorydb opList *list.List isDiscard bool isCommit bool } func (bulk *memoryBulk) Set(key, value []byte) { bulk.txLock.Lock() defer bulk.txLock.Unlock() key = convNilToBytes(key) value = convNilToBytes(value) bulk.opList.PushBack(&txOp{true, key, value}) } func (bulk *memoryBulk) Delete(key []byte) { bulk.txLock.Lock() defer bulk.txLock.Unlock() key = convNilToBytes(key) bulk.opList.PushBack(&txOp{false, key, nil}) } func (bulk *memoryBulk) Flush() { bulk.txLock.Lock() defer bulk.txLock.Unlock() if bulk.isDiscard { panic("Commit after dicard tx is not allowed") } else if bulk.isCommit { panic("Commit occures two times") } db := bulk.db db.lock.Lock() defer db.lock.Unlock() for e := bulk.opList.Front(); e != nil; e = e.Next() { op := e.Value.(*txOp) if op.isSet { db.db[string(op.key)] = op.value } else { delete(db.db, string(op.key)) } } bulk.isCommit = true } func (bulk *memoryBulk) DiscardLast() { bulk.txLock.Lock() defer bulk.txLock.Unlock() bulk.isDiscard = true } //========================================================= // Iterator Implementation //========================================================= type memoryIterator struct { start []byte end []byte reverse bool keys []string isInvalid bool cursor int db *memorydb } func isKeyInRange(key []byte, start []byte, end []byte, reverse bool) bool { if reverse { if start != nil && bytes.Compare(start, key) < 0 { return false } if end != nil && bytes.Compare(key, end) <= 0 { return false } return true } if bytes.Compare(key, start) < 0 { return false } if end != nil && bytes.Compare(end, key) <= 0 { return false } return true } func (db *memorydb) Iterator(start, end []byte) Iterator { db.lock.Lock() defer db.lock.Unlock() var reverse bool // if end is bigger then start, then reverse order if bytes.Compare(start, end) == 1 { reverse = true } else { reverse = false } var keys sort.StringSlice for key := range db.db { if isKeyInRange([]byte(key), start, end, reverse) { keys = append(keys, key) } } if reverse { sort.Sort(sort.Reverse(keys)) } else { sort.Strings(keys) } return &memoryIterator{ start: start, end: end, reverse: reverse, isInvalid: false, keys: keys, cursor: 0, db: db, } } func (iter *memoryIterator) Next() { if !iter.Valid() { panic("Iterator is Invalid") } iter.cursor++ } func (iter *memoryIterator) Valid() bool { // Once invalid, forever invalid. if iter.isInvalid { return false } return 0 <= iter.cursor && iter.cursor < len(iter.keys) } func (iter *memoryIterator) Key() (key []byte) { if !iter.Valid() { panic("Iterator is Invalid") } return []byte(iter.keys[iter.cursor]) } func (iter *memoryIterator) Value() (value []byte) { if !iter.Valid() { panic("Iterator is Invalid") } key := []byte(iter.keys[iter.cursor]) return iter.db.Get(key) }