@ -0,0 +1,352 @@ |
|||||
|
/* |
||||
|
Package arbo implements a Merkle Tree compatible with the circomlib |
||||
|
implementation of the MerkleTree (when using the Poseidon hash function), |
||||
|
following the specification from |
||||
|
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
|
||||
|
https://eprint.iacr.org/2018/955.
|
||||
|
|
||||
|
Also allows to define which hash function to use. So for example, when working |
||||
|
with zkSnarks the Poseidon hash function can be used, but when not, it can be |
||||
|
used the Blake3 hash function, which improves the computation time. |
||||
|
*/ |
||||
|
package arbo |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"fmt" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/iden3/go-merkletree/db" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
// PrefixValueLen defines the bytes-prefix length used for the Value
|
||||
|
// bytes representation stored in the db
|
||||
|
PrefixValueLen = 2 |
||||
|
|
||||
|
// PrefixValueEmpty is used for the first byte of a Value to indicate
|
||||
|
// that is an Empty value
|
||||
|
PrefixValueEmpty = 0 |
||||
|
// PrefixValueLeaf is used for the first byte of a Value to indicate
|
||||
|
// that is a Leaf value
|
||||
|
PrefixValueLeaf = 1 |
||||
|
// PrefixValueIntermediate is used for the first byte of a Value to
|
||||
|
// indicate that is a Intermediate value
|
||||
|
PrefixValueIntermediate = 2 |
||||
|
) |
||||
|
|
||||
|
var ( |
||||
|
dbKeyRoot = []byte("root") |
||||
|
emptyValue = []byte{0} |
||||
|
) |
||||
|
|
||||
|
// Tree defines the struct that implements the MerkleTree functionalities
|
||||
|
type Tree struct { |
||||
|
db db.Storage |
||||
|
lastAccess int64 // in unix time
|
||||
|
maxLevels int |
||||
|
root []byte |
||||
|
|
||||
|
hashFunction HashFunction |
||||
|
} |
||||
|
|
||||
|
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
||||
|
// will load it.
|
||||
|
func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) { |
||||
|
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} |
||||
|
|
||||
|
t.updateAccessTime() |
||||
|
root, err := t.db.Get(dbKeyRoot) |
||||
|
if err == db.ErrNotFound { |
||||
|
// store new root 0
|
||||
|
tx, err := t.db.NewTx() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
t.root = make([]byte, t.hashFunction.Len()) // empty
|
||||
|
err = tx.Put(dbKeyRoot, t.root) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
err = tx.Commit() |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return &t, err |
||||
|
} else if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
t.root = root |
||||
|
return &t, nil |
||||
|
} |
||||
|
|
||||
|
func (t *Tree) updateAccessTime() { |
||||
|
atomic.StoreInt64(&t.lastAccess, time.Now().Unix()) |
||||
|
} |
||||
|
|
||||
|
// LastAccess returns the last access timestamp in Unixtime
|
||||
|
func (t *Tree) LastAccess() int64 { |
||||
|
return atomic.LoadInt64(&t.lastAccess) |
||||
|
} |
||||
|
|
||||
|
// Root returns the root of the Tree
|
||||
|
func (t *Tree) Root() []byte { |
||||
|
return t.root |
||||
|
} |
||||
|
|
||||
|
// AddBatch adds a batch of key-values to the Tree. This method is optimized to
|
||||
|
// do some internal parallelization. Returns an array containing the indexes of
|
||||
|
// the keys failed to add.
|
||||
|
func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { |
||||
|
return nil, fmt.Errorf("unimplemented") |
||||
|
} |
||||
|
|
||||
|
// Add inserts the key-value into the Tree.
|
||||
|
// If the inputs come from a *big.Int, is expected that are represented by a
|
||||
|
// Little-Endian byte array (for circom compatibility).
|
||||
|
func (t *Tree) Add(k, v []byte) error { |
||||
|
// TODO check validity of key & value (for the Tree.HashFunction type)
|
||||
|
|
||||
|
keyPath := make([]byte, t.hashFunction.Len()) |
||||
|
copy(keyPath[:], k) |
||||
|
|
||||
|
path := getPath(t.maxLevels, keyPath) |
||||
|
// go down to the leaf
|
||||
|
var siblings [][]byte |
||||
|
_, _, siblings, err := t.down(k, t.root, siblings, path, 0) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
leafKey, leafValue, err := t.newLeafValue(k, v) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
tx, err := t.db.NewTx() |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
if err := tx.Put(leafKey, leafValue); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
// go up to the root
|
||||
|
if len(siblings) == 0 { |
||||
|
t.root = leafKey |
||||
|
return tx.Commit() |
||||
|
} |
||||
|
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
t.root = root |
||||
|
// store root to db
|
||||
|
return tx.Commit() |
||||
|
} |
||||
|
|
||||
|
// down goes down to the leaf recursively
|
||||
|
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) ( |
||||
|
[]byte, []byte, [][]byte, error) { |
||||
|
if l > t.maxLevels-1 { |
||||
|
return nil, nil, nil, fmt.Errorf("max level") |
||||
|
} |
||||
|
var err error |
||||
|
var currValue []byte |
||||
|
emptyKey := make([]byte, t.hashFunction.Len()) |
||||
|
if bytes.Equal(currKey, emptyKey) { |
||||
|
// empty value
|
||||
|
return currKey, emptyValue, siblings, nil |
||||
|
} |
||||
|
currValue, err = t.db.Get(currKey) |
||||
|
if err != nil { |
||||
|
return nil, nil, nil, err |
||||
|
} |
||||
|
|
||||
|
switch currValue[0] { |
||||
|
case PrefixValueEmpty: // empty
|
||||
|
// TODO WIP WARNING should not be reached, as the 'if' above should avoid
|
||||
|
// reaching this point
|
||||
|
// return currKey, empty, siblings, nil
|
||||
|
panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
|
||||
|
case PrefixValueLeaf: // leaf
|
||||
|
if bytes.Equal(newKey, currKey) { |
||||
|
return nil, nil, nil, fmt.Errorf("key already exists") |
||||
|
} |
||||
|
|
||||
|
if !bytes.Equal(currValue, emptyValue) { |
||||
|
oldLeafKey, _ := readLeafValue(currValue) |
||||
|
oldLeafKeyFull := make([]byte, t.hashFunction.Len()) |
||||
|
copy(oldLeafKeyFull[:], oldLeafKey) |
||||
|
|
||||
|
// if currKey is already used, go down until paths diverge
|
||||
|
oldPath := getPath(t.maxLevels, oldLeafKeyFull) |
||||
|
siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l) |
||||
|
if err != nil { |
||||
|
return nil, nil, nil, err |
||||
|
} |
||||
|
} |
||||
|
return currKey, currValue, siblings, nil |
||||
|
case PrefixValueIntermediate: // intermediate
|
||||
|
if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 { |
||||
|
return nil, nil, nil, |
||||
|
fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)", |
||||
|
PrefixValueLen+t.hashFunction.Len()*2, len(currValue)) |
||||
|
} |
||||
|
// collect siblings while going down
|
||||
|
if path[l] { |
||||
|
// right
|
||||
|
lChild, rChild := readIntermediateChilds(currValue) |
||||
|
siblings = append(siblings, lChild) |
||||
|
return t.down(newKey, rChild, siblings, path, l+1) |
||||
|
} |
||||
|
// left
|
||||
|
lChild, rChild := readIntermediateChilds(currValue) |
||||
|
siblings = append(siblings, rChild) |
||||
|
return t.down(newKey, lChild, siblings, path, l+1) |
||||
|
default: |
||||
|
return nil, nil, nil, fmt.Errorf("invalid value") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// downVirtually is used when in a leaf already exists, and a new leaf which
|
||||
|
// shares the path until the existing leaf is being added
|
||||
|
func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, |
||||
|
newPath []bool, l int) ([][]byte, error) { |
||||
|
var err error |
||||
|
if l > t.maxLevels-1 { |
||||
|
return nil, fmt.Errorf("max virtual level %d", l) |
||||
|
} |
||||
|
|
||||
|
if oldPath[l] == newPath[l] { |
||||
|
emptyKey := make([]byte, t.hashFunction.Len()) // empty
|
||||
|
siblings = append(siblings, emptyKey) |
||||
|
|
||||
|
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return siblings, nil |
||||
|
} |
||||
|
// reached the divergence
|
||||
|
siblings = append(siblings, oldKey) |
||||
|
|
||||
|
return siblings, nil |
||||
|
} |
||||
|
|
||||
|
// up goes up recursively updating the intermediate nodes
|
||||
|
func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) { |
||||
|
var k, v []byte |
||||
|
var err error |
||||
|
if path[l] { |
||||
|
k, v, err = t.newIntermediate(siblings[l], key) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
} else { |
||||
|
k, v, err = t.newIntermediate(key, siblings[l]) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
} |
||||
|
// store k-v to db
|
||||
|
err = tx.Put(k, v) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
if l == 0 { |
||||
|
// reached the root
|
||||
|
return k, nil |
||||
|
} |
||||
|
|
||||
|
return t.up(tx, k, siblings, path, l-1) |
||||
|
} |
||||
|
|
||||
|
func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) { |
||||
|
leafKey, err := t.hashFunction.Hash(k, v, []byte{1}) |
||||
|
if err != nil { |
||||
|
return nil, nil, err |
||||
|
} |
||||
|
var leafValue []byte |
||||
|
leafValue = append(leafValue, byte(1)) |
||||
|
leafValue = append(leafValue, byte(len(k))) |
||||
|
leafValue = append(leafValue, k...) |
||||
|
leafValue = append(leafValue, v...) |
||||
|
return leafKey, leafValue, nil |
||||
|
} |
||||
|
|
||||
|
func readLeafValue(b []byte) ([]byte, []byte) { |
||||
|
if len(b) < PrefixValueLen { |
||||
|
return []byte{}, []byte{} |
||||
|
} |
||||
|
|
||||
|
kLen := b[1] |
||||
|
if len(b) < PrefixValueLen+int(kLen) { |
||||
|
return []byte{}, []byte{} |
||||
|
} |
||||
|
k := b[PrefixValueLen : PrefixValueLen+kLen] |
||||
|
v := b[PrefixValueLen+kLen:] |
||||
|
return k, v |
||||
|
} |
||||
|
|
||||
|
func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) { |
||||
|
b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2) |
||||
|
b[0] = 2 |
||||
|
b[1] = byte(len(l)) |
||||
|
copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l) |
||||
|
copy(b[PrefixValueLen+t.hashFunction.Len():], r) |
||||
|
|
||||
|
key, err := t.hashFunction.Hash(l, r) |
||||
|
if err != nil { |
||||
|
return nil, nil, err |
||||
|
} |
||||
|
|
||||
|
return key, b, nil |
||||
|
} |
||||
|
|
||||
|
func readIntermediateChilds(b []byte) ([]byte, []byte) { |
||||
|
if len(b) < PrefixValueLen { |
||||
|
return []byte{}, []byte{} |
||||
|
} |
||||
|
|
||||
|
lLen := b[1] |
||||
|
if len(b) < PrefixValueLen+int(lLen) { |
||||
|
return []byte{}, []byte{} |
||||
|
} |
||||
|
l := b[PrefixValueLen : PrefixValueLen+lLen] |
||||
|
r := b[PrefixValueLen+lLen:] |
||||
|
return l, r |
||||
|
} |
||||
|
|
||||
|
func getPath(numLevels int, k []byte) []bool { |
||||
|
path := make([]bool, numLevels) |
||||
|
for n := 0; n < numLevels; n++ { |
||||
|
path[n] = k[n/8]&(1<<(n%8)) != 0 |
||||
|
} |
||||
|
return path |
||||
|
} |
||||
|
|
||||
|
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
||||
|
// the Tree, the proof will be of existence, if the key does not exist in the
|
||||
|
// tree, the proof will be of non-existence.
|
||||
|
func (t *Tree) GenProof(k, v []byte) ([]byte, error) { |
||||
|
// unimplemented
|
||||
|
return nil, fmt.Errorf("unimplemented") |
||||
|
} |
||||
|
|
||||
|
// Get returns the value for a given key
|
||||
|
func (t *Tree) Get(k []byte) ([]byte, []byte, error) { |
||||
|
// unimplemented
|
||||
|
return nil, nil, fmt.Errorf("unimplemented") |
||||
|
} |
||||
|
|
||||
|
// CheckProof verifies the given proof
|
||||
|
func CheckProof(k, v, root, mproof []byte) (bool, error) { |
||||
|
// unimplemented
|
||||
|
return false, fmt.Errorf("unimplemented") |
||||
|
} |
||||
|
|
||||
|
// TODO method to export & import the full Tree without values
|
@ -0,0 +1,176 @@ |
|||||
|
package arbo |
||||
|
|
||||
|
import ( |
||||
|
"encoding/hex" |
||||
|
"fmt" |
||||
|
"math/big" |
||||
|
"testing" |
||||
|
|
||||
|
"github.com/iden3/go-merkletree/db/memory" |
||||
|
"github.com/stretchr/testify/assert" |
||||
|
"github.com/stretchr/testify/require" |
||||
|
) |
||||
|
|
||||
|
func TestAddTestVectors(t *testing.T) { |
||||
|
// Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
|
testVectorsPoseidon := []string{ |
||||
|
"0000000000000000000000000000000000000000000000000000000000000000", |
||||
|
"13578938674299138072471463694055224830892726234048532520316387704878000008795", |
||||
|
"5412393676474193513566895793055462193090331607895808993925969873307089394741", |
||||
|
"14204494359367183802864593755198662203838502594566452929175967972147978322084", |
||||
|
} |
||||
|
testAdd(t, HashFunctionPoseidon, testVectorsPoseidon) |
||||
|
|
||||
|
testVectorsSha256 := []string{ |
||||
|
"0000000000000000000000000000000000000000000000000000000000000000", |
||||
|
"46910109172468462938850740851377282682950237270676610513794735904325820156367", |
||||
|
"59481735341404520835410489183267411392292882901306595567679529387376287440550", |
||||
|
"20573794434149960984975763118181266662429997821552560184909083010514790081771", |
||||
|
} |
||||
|
testAdd(t, HashFunctionSha256, testVectorsSha256) |
||||
|
} |
||||
|
|
||||
|
func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) { |
||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc) |
||||
|
assert.Nil(t, err) |
||||
|
defer tree.db.Close() |
||||
|
assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root())) |
||||
|
|
||||
|
err = tree.Add( |
||||
|
BigIntToBytes(big.NewInt(1)), |
||||
|
BigIntToBytes(big.NewInt(2))) |
||||
|
assert.Nil(t, err) |
||||
|
rootBI := BytesToBigInt(tree.Root()) |
||||
|
assert.Equal(t, testVectors[1], rootBI.String()) |
||||
|
|
||||
|
err = tree.Add( |
||||
|
BigIntToBytes(big.NewInt(33)), |
||||
|
BigIntToBytes(big.NewInt(44))) |
||||
|
assert.Nil(t, err) |
||||
|
rootBI = BytesToBigInt(tree.Root()) |
||||
|
assert.Equal(t, testVectors[2], rootBI.String()) |
||||
|
|
||||
|
err = tree.Add( |
||||
|
BigIntToBytes(big.NewInt(1234)), |
||||
|
BigIntToBytes(big.NewInt(9876))) |
||||
|
assert.Nil(t, err) |
||||
|
rootBI = BytesToBigInt(tree.Root()) |
||||
|
assert.Equal(t, testVectors[3], rootBI.String()) |
||||
|
} |
||||
|
|
||||
|
func TestAdd1000(t *testing.T) { |
||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) |
||||
|
require.Nil(t, err) |
||||
|
|
||||
|
defer tree.db.Close() |
||||
|
for i := 0; i < 1000; i++ { |
||||
|
k := BigIntToBytes(big.NewInt(int64(i))) |
||||
|
v := BigIntToBytes(big.NewInt(0)) |
||||
|
if err := tree.Add(k, v); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
rootBI := BytesToBigInt(tree.Root()) |
||||
|
assert.Equal(t, |
||||
|
"296519252211642170490407814696803112091039265640052570497930797516015811235", |
||||
|
rootBI.String()) |
||||
|
} |
||||
|
|
||||
|
func TestAddDifferentOrder(t *testing.T) { |
||||
|
tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) |
||||
|
require.Nil(t, err) |
||||
|
|
||||
|
defer tree1.db.Close() |
||||
|
for i := 0; i < 16; i++ { |
||||
|
k := SwapEndianness(big.NewInt(int64(i)).Bytes()) |
||||
|
v := SwapEndianness(big.NewInt(0).Bytes()) |
||||
|
if err := tree1.Add(k, v); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) |
||||
|
require.Nil(t, err) |
||||
|
defer tree2.db.Close() |
||||
|
for i := 16 - 1; i >= 0; i-- { |
||||
|
k := big.NewInt(int64(i)).Bytes() |
||||
|
v := big.NewInt(0).Bytes() |
||||
|
if err := tree2.Add(k, v); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root())) |
||||
|
assert.Equal(t, |
||||
|
"3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", |
||||
|
hex.EncodeToString(tree1.Root())) |
||||
|
} |
||||
|
|
||||
|
func TestAddRepeatedIndex(t *testing.T) { |
||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) |
||||
|
require.Nil(t, err) |
||||
|
defer tree.db.Close() |
||||
|
k := big.NewInt(int64(3)).Bytes() |
||||
|
v := big.NewInt(int64(12)).Bytes() |
||||
|
if err := tree.Add(k, v); err != nil { |
||||
|
t.Fatal(err) |
||||
|
} |
||||
|
err = tree.Add(k, v) |
||||
|
assert.NotNil(t, err) |
||||
|
assert.Equal(t, fmt.Errorf("max virtual level 100"), err) |
||||
|
} |
||||
|
|
||||
|
func TestAux(t *testing.T) { |
||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) |
||||
|
require.Nil(t, err) |
||||
|
defer tree.db.Close() |
||||
|
k := BigIntToBytes(big.NewInt(int64(1))) |
||||
|
v := BigIntToBytes(big.NewInt(int64(0))) |
||||
|
err = tree.Add(k, v) |
||||
|
assert.Nil(t, err) |
||||
|
k = BigIntToBytes(big.NewInt(int64(256))) |
||||
|
err = tree.Add(k, v) |
||||
|
assert.Nil(t, err) |
||||
|
|
||||
|
k = BigIntToBytes(big.NewInt(int64(257))) |
||||
|
err = tree.Add(k, v) |
||||
|
assert.Nil(t, err) |
||||
|
|
||||
|
k = BigIntToBytes(big.NewInt(int64(515))) |
||||
|
err = tree.Add(k, v) |
||||
|
assert.Nil(t, err) |
||||
|
k = BigIntToBytes(big.NewInt(int64(770))) |
||||
|
err = tree.Add(k, v) |
||||
|
assert.Nil(t, err) |
||||
|
} |
||||
|
|
||||
|
func BenchmarkAdd(b *testing.B) { |
||||
|
// prepare inputs
|
||||
|
var ks, vs [][]byte |
||||
|
for i := 0; i < 1000; i++ { |
||||
|
k := BigIntToBytes(big.NewInt(int64(i))) |
||||
|
v := BigIntToBytes(big.NewInt(int64(i))) |
||||
|
ks = append(ks, k) |
||||
|
vs = append(vs, v) |
||||
|
} |
||||
|
|
||||
|
b.Run("Poseidon", func(b *testing.B) { |
||||
|
benchmarkAdd(b, HashFunctionPoseidon, ks, vs) |
||||
|
}) |
||||
|
b.Run("Sha256", func(b *testing.B) { |
||||
|
benchmarkAdd(b, HashFunctionSha256, ks, vs) |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) { |
||||
|
tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc) |
||||
|
require.Nil(b, err) |
||||
|
|
||||
|
defer tree.db.Close() |
||||
|
for i := 0; i < len(ks); i++ { |
||||
|
if err := tree.Add(ks[i], vs[i]); err != nil { |
||||
|
b.Fatal(err) |
||||
|
} |
||||
|
} |
||||
|
} |