From 2b4090bb7d025513bd0f64579807ded6f9d5d761 Mon Sep 17 00:00:00 2001 From: Eduard S Date: Tue, 15 Dec 2020 15:12:18 +0100 Subject: [PATCH] Fix loading root key --- merkletree.go | 36 ++++++++++++++++++++++++------------ merkletree_test.go | 12 ++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/merkletree.go b/merkletree.go index 0dddf81..c57921f 100644 --- a/merkletree.go +++ b/merkletree.go @@ -52,7 +52,7 @@ var ( // write function is called ErrNotWritable = errors.New("Merkle Tree not writable") - rootNodeValue = []byte("currentroot") + dbKeyRootNode = []byte("currentroot") // HashZero is used at Empty nodes HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) @@ -175,14 +175,14 @@ type MerkleTree struct { func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true} - v, err := mt.db.Get(rootNodeValue) - if err != nil { + root, err := mt.dbGetRoot() + if err == db.ErrNotFound { tx, err := mt.db.NewTx() if err != nil { return nil, err } mt.rootKey = &HashZero - err = tx.Put(rootNodeValue, mt.rootKey[:]) + err = tx.Put(dbKeyRootNode, mt.rootKey[:]) if err != nil { return nil, err } @@ -191,12 +191,24 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) { return nil, err } return &mt, nil + } else if err != nil { + return nil, err } - mt.rootKey = &Hash{} - copy(mt.rootKey[:], v) + mt.rootKey = root return &mt, nil } +func (mt *MerkleTree) dbGetRoot() (*Hash, error) { + v, err := mt.db.Get(dbKeyRootNode) + if err != nil { + return nil, err + } + var root Hash + // Skip first byte which contains the NodeType + copy(root[:], v[1:]) + return &root, nil +} + // DB returns the MerkleTree.DB() func (mt *MerkleTree) DB() db.Storage { return mt.db @@ -256,7 +268,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error { return err } mt.rootKey = newRootKey - err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return err } @@ -534,7 +546,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) { return nil, err } mt.rootKey = newRootKey - err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return nil, err } @@ -630,7 +642,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error { func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error { if len(siblings) == 0 { mt.rootKey = &HashZero - err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return err } @@ -640,7 +652,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [ toUpload := siblings[len(siblings)-1] if len(siblings) < 2 { //nolint:gomnd mt.rootKey = siblings[0] - err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return err } @@ -664,7 +676,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [ return err } mt.rootKey = newRootKey - err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return err } @@ -673,7 +685,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [ // if i==0 (root position), stop and store the sibling of the deleted leaf as root if i == 0 { mt.rootKey = toUpload - err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:]) + err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:]) if err != nil { return err } diff --git a/merkletree_test.go b/merkletree_test.go index 032da94..c7683cd 100644 --- a/merkletree_test.go +++ b/merkletree_test.go @@ -108,6 +108,10 @@ func TestNewTree(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String()) + dbRoot, err := mt.dbGetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) + proof, v, err := mt.GenerateProof(big.NewInt(33), nil) assert.Nil(t, err) assert.Equal(t, big.NewInt(44), v) @@ -205,6 +209,10 @@ func TestUpdate(t *testing.T) { _, err = mt.Update(big.NewInt(1000), big.NewInt(1024)) assert.Equal(t, ErrKeyNotFound, err) + + dbRoot, err := mt.dbGetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) } func TestUpdate2(t *testing.T) { @@ -449,6 +457,10 @@ func TestDelete(t *testing.T) { err = mt.Delete(big.NewInt(1)) assert.Nil(t, err) assert.Equal(t, "0", mt.Root().String()) + + dbRoot, err := mt.dbGetRoot() + require.Nil(t, err) + assert.Equal(t, mt.Root(), dbRoot) } func TestDelete2(t *testing.T) {