Browse Source

Add mt.Get to retrieve a value for the given key

circomproofs
arnaucube 3 years ago
parent
commit
c01a9f4e46
2 changed files with 70 additions and 8 deletions
  1. +45
    -8
      merkletree.go
  2. +25
    -0
      merkletree_test.go

+ 45
- 8
merkletree.go

@ -24,9 +24,9 @@ const (
var (
// ErrNodeKeyAlreadyExists is used when a node key already exists.
ErrNodeKeyAlreadyExists = errors.New("node already exists")
// ErrEntryIndexNotFound is used when no entry is found for an index.
ErrEntryIndexNotFound = errors.New("node index not found in the DB")
ErrNodeKeyAlreadyExists = errors.New("key already exists")
// ErrKeyNotFound is used when a key is not found in the MerkleTree.
ErrKeyNotFound = errors.New("Key not found in the MerkleTree")
// ErrNodeBytesBadSize is used when the data of a node has an incorrect
// size and can't be parsed.
ErrNodeBytesBadSize = errors.New("node data has incorrect size in the DB")
@ -47,8 +47,6 @@ var (
ErrEntryIndexAlreadyExists = errors.New("the entry index already exists in the tree")
// ErrNotWritable is used when the MerkleTree is not writable and a write function is called
ErrNotWritable = errors.New("Merkle Tree not writable")
// ErrKeyNotFound is used when a key is not found in the MerkleTree.
ErrKeyNotFound = errors.New("Key not found in the tree")
rootNodeValue = []byte("currentroot")
// HashZero is used at Empty nodes
HashZero = Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@ -158,7 +156,7 @@ func (mt *MerkleTree) Add(k, v *big.Int) error {
return ErrNotWritable
}
// verfy that the ElemBytes are valid and fit inside the Finite Field.
// verfy that k & v are valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) {
return errors.New("Key not inside the Finite Field")
}
@ -305,6 +303,45 @@ func (mt *MerkleTree) addNode(tx db.Tx, n *Node) (*Hash, error) {
return k, nil
}
// Get returns the value of the leaf for the given key
func (mt *MerkleTree) Get(k *big.Int) (*big.Int, error) {
// verfy that k is valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) {
return nil, errors.New("Key not inside the Finite Field")
}
kHash := NewHashFromBigInt(k)
path := getPath(mt.maxLevels, kHash[:])
nextKey := mt.rootKey
for i := 0; i < mt.maxLevels; i++ {
n, err := mt.GetNode(nextKey)
if err != nil {
return nil, err
}
switch n.Type {
case NodeTypeEmpty:
return nil, ErrKeyNotFound
case NodeTypeLeaf:
if bytes.Equal(kHash[:], n.Entry[0][:]) {
return n.Entry[1].BigInt(), nil
} else {
return nil, ErrKeyNotFound
}
case NodeTypeMiddle:
if path[i] {
nextKey = n.ChildR
} else {
nextKey = n.ChildL
}
default:
return nil, ErrInvalidNodeFound
}
}
return nil, ErrKeyNotFound
}
// Delete removes the specified Key from the MerkleTree and updates the path
// from the deleted key to the Root with the new values. This method removes
// the key from the MerkleTree, but does not remove the old nodes from the
@ -321,7 +358,7 @@ func (mt *MerkleTree) Delete(k *big.Int) error {
return ErrNotWritable
}
// verfy that the ElemBytes are valid and fit inside the Finite Field.
// verfy that k is valid and fit inside the Finite Field.
if !cryptoUtils.CheckBigIntInField(k) {
return errors.New("Key not inside the Finite Field")
}
@ -651,7 +688,7 @@ func (mt *MerkleTree) GenerateProof(k *big.Int, rootKey *Hash) (*Proof, error) {
p.Siblings = append(p.Siblings, siblingKey)
}
}
return nil, ErrEntryIndexNotFound
return nil, ErrKeyNotFound
}
// VerifyProof verifies the Merkle Proof for the entry and root.

+ 25
- 0
merkletree_test.go

@ -111,6 +111,31 @@ func TestAddRepeatedIndex(t *testing.T) {
assert.Equal(t, err, ErrEntryIndexAlreadyExists)
}
func TestGet(t *testing.T) {
mt := newTestingMerkle(t, 140)
defer mt.db.Close()
for i := 0; i < 16; i++ {
k := big.NewInt(int64(i))
v := big.NewInt(int64(i * 2))
if err := mt.Add(k, v); err != nil {
t.Fatal(err)
}
}
v, err := mt.Get(big.NewInt(10))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(20), v)
v, err = mt.Get(big.NewInt(15))
assert.Nil(t, err)
assert.Equal(t, big.NewInt(30), v)
v, err = mt.Get(big.NewInt(16))
assert.NotNil(t, err)
assert.Equal(t, ErrKeyNotFound, err)
assert.Nil(t, v)
}
func TestGenerateAndVerifyProof128(t *testing.T) {
mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
require.Nil(t, err)

Loading…
Cancel
Save