mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-07 11:36:47 +01:00
@@ -52,7 +52,7 @@ var (
|
|||||||
// write function is called
|
// write function is called
|
||||||
ErrNotWritable = errors.New("Merkle Tree not writable")
|
ErrNotWritable = errors.New("Merkle Tree not writable")
|
||||||
|
|
||||||
rootNodeValue = []byte("currentroot")
|
dbKeyRootNode = []byte("currentroot")
|
||||||
// HashZero is used at Empty nodes
|
// HashZero is used at Empty nodes
|
||||||
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
)
|
)
|
||||||
@@ -175,14 +175,14 @@ type MerkleTree struct {
|
|||||||
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
||||||
mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
|
mt := MerkleTree{db: storage, maxLevels: maxLevels, writable: true}
|
||||||
|
|
||||||
v, err := mt.db.Get(rootNodeValue)
|
root, err := mt.dbGetRoot()
|
||||||
if err != nil {
|
if err == db.ErrNotFound {
|
||||||
tx, err := mt.db.NewTx()
|
tx, err := mt.db.NewTx()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mt.rootKey = &HashZero
|
mt.rootKey = &HashZero
|
||||||
err = tx.Put(rootNodeValue, mt.rootKey[:])
|
err = tx.Put(dbKeyRootNode, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -191,12 +191,24 @@ func NewMerkleTree(storage db.Storage, maxLevels int) (*MerkleTree, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &mt, nil
|
return &mt, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
mt.rootKey = &Hash{}
|
mt.rootKey = root
|
||||||
copy(mt.rootKey[:], v)
|
|
||||||
return &mt, nil
|
return &mt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mt *MerkleTree) dbGetRoot() (*Hash, error) {
|
||||||
|
v, err := mt.db.Get(dbKeyRootNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var root Hash
|
||||||
|
// Skip first byte which contains the NodeType
|
||||||
|
copy(root[:], v[1:])
|
||||||
|
return &root, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DB returns the MerkleTree.DB()
|
// DB returns the MerkleTree.DB()
|
||||||
func (mt *MerkleTree) DB() db.Storage {
|
func (mt *MerkleTree) DB() db.Storage {
|
||||||
return mt.db
|
return mt.db
|
||||||
@@ -256,7 +268,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mt.rootKey = newRootKey
|
mt.rootKey = newRootKey
|
||||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -534,7 +546,7 @@ func (mt *MerkleTree) Update(k, v *big.Int) (*CircomProcessorProof, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mt.rootKey = newRootKey
|
mt.rootKey = newRootKey
|
||||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -630,7 +642,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error {
|
|||||||
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
|
func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings []*Hash) error {
|
||||||
if len(siblings) == 0 {
|
if len(siblings) == 0 {
|
||||||
mt.rootKey = &HashZero
|
mt.rootKey = &HashZero
|
||||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -640,7 +652,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
|||||||
toUpload := siblings[len(siblings)-1]
|
toUpload := siblings[len(siblings)-1]
|
||||||
if len(siblings) < 2 { //nolint:gomnd
|
if len(siblings) < 2 { //nolint:gomnd
|
||||||
mt.rootKey = siblings[0]
|
mt.rootKey = siblings[0]
|
||||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -664,7 +676,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mt.rootKey = newRootKey
|
mt.rootKey = newRootKey
|
||||||
err = mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err = mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -673,7 +685,7 @@ func (mt *MerkleTree) rmAndUpload(tx db.Tx, path []bool, kHash *Hash, siblings [
|
|||||||
// if i==0 (root position), stop and store the sibling of the deleted leaf as root
|
// if i==0 (root position), stop and store the sibling of the deleted leaf as root
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
mt.rootKey = toUpload
|
mt.rootKey = toUpload
|
||||||
err := mt.dbInsert(tx, rootNodeValue, DBEntryTypeRoot, mt.rootKey[:])
|
err := mt.dbInsert(tx, dbKeyRootNode, DBEntryTypeRoot, mt.rootKey[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -108,6 +108,10 @@ func TestNewTree(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
|
assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
|
||||||
|
|
||||||
|
dbRoot, err := mt.dbGetRoot()
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, mt.Root(), dbRoot)
|
||||||
|
|
||||||
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
|
proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, big.NewInt(44), v)
|
assert.Equal(t, big.NewInt(44), v)
|
||||||
@@ -205,6 +209,10 @@ func TestUpdate(t *testing.T) {
|
|||||||
|
|
||||||
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
|
_, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
|
||||||
assert.Equal(t, ErrKeyNotFound, err)
|
assert.Equal(t, ErrKeyNotFound, err)
|
||||||
|
|
||||||
|
dbRoot, err := mt.dbGetRoot()
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, mt.Root(), dbRoot)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdate2(t *testing.T) {
|
func TestUpdate2(t *testing.T) {
|
||||||
@@ -449,6 +457,10 @@ func TestDelete(t *testing.T) {
|
|||||||
err = mt.Delete(big.NewInt(1))
|
err = mt.Delete(big.NewInt(1))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "0", mt.Root().String())
|
assert.Equal(t, "0", mt.Root().String())
|
||||||
|
|
||||||
|
dbRoot, err := mt.dbGetRoot()
|
||||||
|
require.Nil(t, err)
|
||||||
|
assert.Equal(t, mt.Root(), dbRoot)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDelete2(t *testing.T) {
|
func TestDelete2(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user