Browse Source

Fix loading root key

feature/fixroot
Eduard S 3 years ago
parent
commit
2b4090bb7d
2 changed files with 36 additions and 12 deletions
  1. +24
    -12
      merkletree.go
  2. +12
    -0
      merkletree_test.go

+ 24
- 12
merkletree.go

@ -52,7 +52,7 @@ var (
// write function is called
ErrNotWritable = errors.New("Merkle Tree not writable")
rootNodeValue = []byte("currentroot")
dbKeyRootNode = []byte("currentroot")
// HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
@ -175,14 +175,14 @@ type MerkleTree struct {
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
v, err := mt.db.Get(rootNodeValue)
if err != nil {
root, err := mt.dbGetRoot()
if err == db.ErrNotFound {
tx, err := mt.db.NewTx()
if err != nil {
return nil, err
}
mt.rootKey = &HashZero
err = tx.Put(rootNodeValue, mt.rootKey[:])
err = tx.Put(dbKeyRootNode, mt.rootKey[:])
if err != nil {
return nil, err
}
@ -191,12 +191,24 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
return nil, err
}
return &mt, nil
} else if err != nil {
return nil, err
}
mt.rootKey = &Hash{}
copy(mt.rootKey[:], v)
mt.rootKey = root
return &mt, nil
}
func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
v, err := mt.db.Get(dbKeyRootNode)
if err != nil {
return nil, err
}
var root Hash
// Skip first byte which contains the NodeType
copy(root[:], v[1:])
return &root, nil
}
// DB returns the MerkleTree.DB()
func (mt *MerkleTree) DB() db.Storage {
return mt.db
@ -256,7 +268,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return err
}
mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return err
}
@ -534,7 +546,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
return nil, err
}
mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return nil, err
}
@ -630,7 +642,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error {
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
if len(siblings) == 0 {
mt.rootKey = &HashZero
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return err
}
@ -640,7 +652,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
toUpload := siblings[len(siblings)-1]
if len(siblings) < 2 { //nolint:gomnd
mt.rootKey = siblings[0]
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return err
}
@ -664,7 +676,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
return err
}
mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return err
}
@ -673,7 +685,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
// if i==0 (root position), stop and store the sibling of the deleted leaf as root
if i == 0 {
mt.rootKey = toUpload
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil {
return err
}

+ 12
- 0
merkletree_test.go

@ -108,6 +108,10 @@ func TestNewTree(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v)
@ -205,6 +209,10 @@ func TestUpdate(t *testing.T) {
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
assert.Equal(t, ErrKeyNotFound, err)
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestUpdate2(t *testing.T) {
@ -449,6 +457,10 @@ func TestDelete(t *testing.T) {
err = mt.Delete(big.NewInt(1))
assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String())
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
}
func TestDelete2(t *testing.T) {

Loading…
Cancel
Save