Merge pull request #12 from iden3/feature/fixroot

Fix loading root key
This commit is contained in:
arnau
2020-12-15 15:20:17 +01:00
committed by GitHub
2 changed files with 36 additions and 12 deletions

View File

@@ -52,7 +52,7 @@ var (
// write function is called // write function is called
ErrNotWritable = errors.New("Merkle Tree not writable") ErrNotWritable = errors.New("Merkle Tree not writable")
rootNodeValue = []byte("currentroot") dbKeyRootNode = []byte("currentroot")
// HashZero is used at Empty nodes // HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
) )
@@ -175,14 +175,14 @@ type MerkleTree struct {
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true} mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
v, err := mt.db.Get(rootNodeValue) root, err := mt.dbGetRoot()
if err != nil { if err == db.ErrNotFound {
tx, err := mt.db.NewTx() tx, err := mt.db.NewTx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
mt.rootKey = &HashZero mt.rootKey = &HashZero
err = tx.Put(rootNodeValue, mt.rootKey[:]) err = tx.Put(dbKeyRootNode, mt.rootKey[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -191,12 +191,24 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
return nil, err return nil, err
} }
return &mt, nil return &mt, nil
} else if err != nil {
return nil, err
} }
mt.rootKey = &Hash{} mt.rootKey = root
copy(mt.rootKey[:], v)
return &mt, nil return &mt, nil
} }
func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
v, err := mt.db.Get(dbKeyRootNode)
if err != nil {
return nil, err
}
var root Hash
// Skip first byte which contains the NodeType
copy(root[:], v[1:])
return &root, nil
}
// DB returns the MerkleTree.DB() // DB returns the MerkleTree.DB()
func (mt *MerkleTree) DB() db.Storage { func (mt *MerkleTree) DB() db.Storage {
return mt.db return mt.db
@@ -256,7 +268,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return err return err
} }
mt.rootKey = newRootKey mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return err return err
} }
@@ -534,7 +546,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
return nil, err return nil, err
} }
mt.rootKey = newRootKey mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -630,7 +642,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error {
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error { func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
if len(siblings) == 0 { if len(siblings) == 0 {
mt.rootKey = &HashZero mt.rootKey = &HashZero
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return err return err
} }
@@ -640,7 +652,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
toUpload := siblings[len(siblings)-1] toUpload := siblings[len(siblings)-1]
if len(siblings) < 2 { //nolint:gomnd if len(siblings) < 2 { //nolint:gomnd
mt.rootKey = siblings[0] mt.rootKey = siblings[0]
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return err return err
} }
@@ -664,7 +676,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
return err return err
} }
mt.rootKey = newRootKey mt.rootKey = newRootKey
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return err return err
} }
@@ -673,7 +685,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
// if i==0 (root position), stop and store the sibling of the deleted leaf as root // if i==0 (root position), stop and store the sibling of the deleted leaf as root
if i == 0 { if i == 0 {
mt.rootKey = toUpload mt.rootKey = toUpload
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
if err != nil { if err != nil {
return err return err
} }

View File

@@ -108,6 +108,10 @@ func TestNewTree(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String()) assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
proof, v, err := mt.GenerateProof(big.NewInt(33), nil) proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, big.NewInt(44), v) assert.Equal(t, big.NewInt(44), v)
@@ -205,6 +209,10 @@ func TestUpdate(t *testing.T) {
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024)) _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
assert.Equal(t, ErrKeyNotFound, err) assert.Equal(t, ErrKeyNotFound, err)
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
} }
func TestUpdate2(t *testing.T) { func TestUpdate2(t *testing.T) {
@@ -449,6 +457,10 @@ func TestDelete(t *testing.T) {
err = mt.Delete(big.NewInt(1)) err = mt.Delete(big.NewInt(1))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "0", mt.Root().String()) assert.Equal(t, "0", mt.Root().String())
dbRoot, err := mt.dbGetRoot()
require.Nil(t, err)
assert.Equal(t, mt.Root(), dbRoot)
} }
func TestDelete2(t *testing.T) { func TestDelete2(t *testing.T) {