mirror of
https://github.com/arnaucube/go-merkletree-iden3.git
synced 2026-02-06 19:16:43 +01:00
Small fixes
This commit is contained in:
@@ -29,6 +29,12 @@ func TestExampleMerkleTree(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
fmt.Println(mt.Root().String())
|
fmt.Println(mt.Root().String())
|
||||||
|
|
||||||
|
v, err := mt.Get(key)
|
||||||
|
asseert.Equal(t, value, v)
|
||||||
|
|
||||||
|
value = big.NewInt(3)
|
||||||
|
err = mt.Update(key, value)
|
||||||
|
|
||||||
proof, err := mt.GenerateProof(key, nil)
|
proof, err := mt.GenerateProof(key, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ type LevelDbStorageTx struct {
|
|||||||
cache db.KvMap
|
cache db.KvMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLevelStorage returns a new LevelDbStorage
|
// NewLevelDbStorage returns a new LevelDbStorage
|
||||||
func NewLevelDbStorage(path string, errorIfMissing bool) (*LevelDbStorage, error) {
|
func NewLevelDbStorage(path string, errorIfMissing bool) (*LevelDbStorage, error) {
|
||||||
o := &opt.Options{
|
o := &opt.Options{
|
||||||
ErrorIfMissing: errorIfMissing,
|
ErrorIfMissing: errorIfMissing,
|
||||||
@@ -113,16 +113,16 @@ func (l *LevelDbStorage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get retreives a value from a key in the interface db.Tx
|
// Get retreives a value from a key in the interface db.Tx
|
||||||
func (l *LevelDbStorageTx) Get(key []byte) ([]byte, error) {
|
func (tx *LevelDbStorageTx) Get(key []byte) ([]byte, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
fullkey := db.Concat(l.prefix, key)
|
fullkey := db.Concat(tx.prefix, key)
|
||||||
|
|
||||||
if value, ok := l.cache.Get(fullkey); ok {
|
if value, ok := tx.cache.Get(fullkey); ok {
|
||||||
return value, nil
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
value, err := l.ldb.Get(fullkey, nil)
|
value, err := tx.ldb.Get(fullkey, nil)
|
||||||
if err == errors.ErrNotFound {
|
if err == errors.ErrNotFound {
|
||||||
return nil, db.ErrNotFound
|
return nil, db.ErrNotFound
|
||||||
}
|
}
|
||||||
@@ -130,7 +130,7 @@ func (l *LevelDbStorageTx) Get(key []byte) ([]byte, error) {
|
|||||||
return value, err
|
return value, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert saves a key:value into the db.Storage
|
// Put saves a key:value into the db.Storage
|
||||||
func (tx *LevelDbStorageTx) Put(k, v []byte) {
|
func (tx *LevelDbStorageTx) Put(k, v []byte) {
|
||||||
tx.cache.Put(db.Concat(tx.prefix, k[:]), v)
|
tx.cache.Put(db.Concat(tx.prefix, k[:]), v)
|
||||||
}
|
}
|
||||||
@@ -144,19 +144,19 @@ func (tx *LevelDbStorageTx) Add(atx db.Tx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Commit implements the method Commit of the interface db.Tx
|
// Commit implements the method Commit of the interface db.Tx
|
||||||
func (l *LevelDbStorageTx) Commit() error {
|
func (tx *LevelDbStorageTx) Commit() error {
|
||||||
var batch leveldb.Batch
|
var batch leveldb.Batch
|
||||||
for _, v := range l.cache {
|
for _, v := range tx.cache {
|
||||||
batch.Put(v.K, v.V)
|
batch.Put(v.K, v.V)
|
||||||
}
|
}
|
||||||
|
|
||||||
l.cache = nil
|
tx.cache = nil
|
||||||
return l.ldb.Write(&batch, nil)
|
return tx.ldb.Write(&batch, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close implements the method Close of the interface db.Tx
|
// Close implements the method Close of the interface db.Tx
|
||||||
func (l *LevelDbStorageTx) Close() {
|
func (tx *LevelDbStorageTx) Close() {
|
||||||
l.cache = nil
|
tx.cache = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close implements the method Close of the interface db.Storage
|
// Close implements the method Close of the interface db.Storage
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func NewMemoryStorage() *MemoryStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Info implements the method Info of the interface db.Storage
|
// Info implements the method Info of the interface db.Storage
|
||||||
func (l *MemoryStorage) Info() string {
|
func (m *MemoryStorage) Info() string {
|
||||||
return "in-memory"
|
return "in-memory"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,21 +41,21 @@ func (m *MemoryStorage) NewTx() (db.Tx, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get retreives a value from a key in the db.Storage
|
// Get retreives a value from a key in the db.Storage
|
||||||
func (l *MemoryStorage) Get(key []byte) ([]byte, error) {
|
func (m *MemoryStorage) Get(key []byte) ([]byte, error) {
|
||||||
if v, ok := l.kv.Get(db.Concat(l.prefix, key[:])); ok {
|
if v, ok := m.kv.Get(db.Concat(m.prefix, key[:])); ok {
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, db.ErrNotFound
|
return nil, db.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate implements the method Iterate of the interface db.Storage
|
// Iterate implements the method Iterate of the interface db.Storage
|
||||||
func (l *MemoryStorage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
func (m *MemoryStorage) Iterate(f func([]byte, []byte) (bool, error)) error {
|
||||||
kvs := make([]db.KV, 0)
|
kvs := make([]db.KV, 0)
|
||||||
for _, v := range l.kv {
|
for _, v := range m.kv {
|
||||||
if len(v.K) < len(l.prefix) || !bytes.Equal(v.K[:len(l.prefix)], l.prefix) {
|
if len(v.K) < len(m.prefix) || !bytes.Equal(v.K[:len(m.prefix)], m.prefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
localkey := v.K[len(l.prefix):]
|
localkey := v.K[len(m.prefix):]
|
||||||
kvs = append(kvs, db.KV{K: localkey, V: v.V})
|
kvs = append(kvs, db.KV{K: localkey, V: v.V})
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -116,9 +116,9 @@ func (m *MemoryStorage) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List implements the method List of the interface db.Storage
|
// List implements the method List of the interface db.Storage
|
||||||
func (l *MemoryStorage) List(limit int) ([]db.KV, error) {
|
func (m *MemoryStorage) List(limit int) ([]db.KV, error) {
|
||||||
ret := []db.KV{}
|
ret := []db.KV{}
|
||||||
err := l.Iterate(func(key []byte, value []byte) (bool, error) {
|
err := m.Iterate(func(key []byte, value []byte) (bool, error) {
|
||||||
ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)})
|
ret = append(ret, db.KV{K: db.Clone(key), V: db.Clone(value)})
|
||||||
if len(ret) == limit {
|
if len(ret) == limit {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ func (mt *MerkleTree) Root() *Hash {
|
|||||||
return mt.rootKey
|
return mt.rootKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Snapshot returns a read-only copy of the MerkleTree
|
||||||
func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
|
func (mt *MerkleTree) Snapshot(rootKey *Hash) (*MerkleTree, error) {
|
||||||
mt.RLock()
|
mt.RLock()
|
||||||
defer mt.RUnlock()
|
defer mt.RUnlock()
|
||||||
@@ -383,6 +384,9 @@ func (mt *MerkleTree) Update(k, v *big.Int) error {
|
|||||||
// update leaf and upload to the root
|
// update leaf and upload to the root
|
||||||
newNodeLeaf := NewNodeLeaf(kHash, vHash)
|
newNodeLeaf := NewNodeLeaf(kHash, vHash)
|
||||||
_, err := mt.addNode(tx, newNodeLeaf)
|
_, err := mt.addNode(tx, newNodeLeaf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
|
newRootKey, err := mt.recalculatePathUntilRoot(tx, path, newNodeLeaf, siblings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -676,7 +680,7 @@ func (p *Proof) Bytes() []byte {
|
|||||||
return bs
|
return bs
|
||||||
}
|
}
|
||||||
|
|
||||||
// SiblingsFromProof returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
|
// SiblingsFromProof returns all the siblings of the proof.
|
||||||
func SiblingsFromProof(proof *Proof) []*Hash {
|
func SiblingsFromProof(proof *Proof) []*Hash {
|
||||||
sibIdx := 0
|
sibIdx := 0
|
||||||
var siblings []*Hash
|
var siblings []*Hash
|
||||||
@@ -691,10 +695,12 @@ func SiblingsFromProof(proof *Proof) []*Hash {
|
|||||||
return siblings
|
return siblings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllSiblings returns all the siblings of the proof.
|
||||||
func (p *Proof) AllSiblings() []*Hash {
|
func (p *Proof) AllSiblings() []*Hash {
|
||||||
return SiblingsFromProof(p)
|
return SiblingsFromProof(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllSiblingsCircom returns all the siblings of the proof. This function is used to generate the siblings input for the circom circuits.
|
||||||
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
|
func (p *Proof) AllSiblingsCircom(levels int) []*big.Int {
|
||||||
siblings := p.AllSiblings()
|
siblings := p.AllSiblings()
|
||||||
// Add the rest of empty levels to the siblings
|
// Add the rest of empty levels to the siblings
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ func TestUpdate(t *testing.T) {
|
|||||||
assert.Equal(t, big.NewInt(20), v)
|
assert.Equal(t, big.NewInt(20), v)
|
||||||
|
|
||||||
err = mt.Update(big.NewInt(10), big.NewInt(1024))
|
err = mt.Update(big.NewInt(10), big.NewInt(1024))
|
||||||
|
assert.Nil(t, err)
|
||||||
v, err = mt.Get(big.NewInt(10))
|
v, err = mt.Get(big.NewInt(10))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, big.NewInt(1024), v)
|
assert.Equal(t, big.NewInt(1024), v)
|
||||||
|
|||||||
Reference in New Issue
Block a user