mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 03:26:46 +01:00
@@ -52,7 +52,7 @@ var (
|
||||
// write function is called
|
||||
ErrNotWritable = errors.New("Merkle Tree not writable")
|
||||
|
||||
rootNodeValue = []byte("currentroot")
|
||||
dbKeyRootNode = []byte("currentroot")
|
||||
// HashZero is used at Empty nodes
|
||||
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
)
|
||||
@@ -175,14 +175,14 @@ type MerkleTree struct {
|
||||
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
||||
mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
|
||||
|
||||
v, err := mt.db.Get(rootNodeValue)
|
||||
if err != nil {
|
||||
root, err := mt.dbGetRoot()
|
||||
if err == db.ErrNotFound {
|
||||
tx, err := mt.db.NewTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mt.rootKey = &HashZero
|
||||
err = tx.Put(rootNodeValue, mt.rootKey[:])
|
||||
err = tx.Put(dbKeyRootNode, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -191,12 +191,24 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
||||
return nil, err
|
||||
}
|
||||
return &mt, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mt.rootKey = &Hash{}
|
||||
copy(mt.rootKey[:], v)
|
||||
mt.rootKey = root
|
||||
return &mt, nil
|
||||
}
|
||||
|
||||
func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
|
||||
v, err := mt.db.Get(dbKeyRootNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var root Hash
|
||||
// Skip first byte which contains the NodeType
|
||||
copy(root[:], v[1:])
|
||||
return &root, nil
|
||||
}
|
||||
|
||||
// DB returns the MerkleTree.DB()
|
||||
func (mt *MerkleTree) DB() db.Storage {
|
||||
return mt.db
|
||||
@@ -256,7 +268,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
|
||||
return err
|
||||
}
|
||||
mt.rootKey = newRootKey
|
||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -534,7 +546,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
|
||||
return nil, err
|
||||
}
|
||||
mt.rootKey = newRootKey
|
||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -630,7 +642,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error {
|
||||
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
|
||||
if len(siblings) == 0 {
|
||||
mt.rootKey = &HashZero
|
||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -640,7 +652,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
||||
toUpload := siblings[len(siblings)-1]
|
||||
if len(siblings) < 2 { //nolint:gomnd
|
||||
mt.rootKey = siblings[0]
|
||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -664,7 +676,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
||||
return err
|
||||
}
|
||||
mt.rootKey = newRootKey
|
||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -673,7 +685,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
||||
// if i==0 (root position), stop and store the sibling of the deleted leaf as root
|
||||
if i == 0 {
|
||||
mt.rootKey = toUpload
|
||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
||||
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -108,6 +108,10 @@ func TestNewTree(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
|
||||
|
||||
dbRoot, err := mt.dbGetRoot()
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, mt.Root(), dbRoot)
|
||||
|
||||
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, big.NewInt(44), v)
|
||||
@@ -205,6 +209,10 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
|
||||
assert.Equal(t, ErrKeyNotFound, err)
|
||||
|
||||
dbRoot, err := mt.dbGetRoot()
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, mt.Root(), dbRoot)
|
||||
}
|
||||
|
||||
func TestUpdate2(t *testing.T) {
|
||||
@@ -449,6 +457,10 @@ func TestDelete(t *testing.T) {
|
||||
err = mt.Delete(big.NewInt(1))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "0", mt.Root().String())
|
||||
|
||||
dbRoot, err := mt.dbGetRoot()
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, mt.Root(), dbRoot)
|
||||
}
|
||||
|
||||
func TestDelete2(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user