mirror of
https://github.com/arnaucube/arbo.git
synced 2026-01-07 14:31:28 +01:00
Add Tree.Add compatible with circomlib
This commit is contained in:
1
go.mod
1
go.mod
@@ -4,5 +4,6 @@ go 1.14
|
||||
|
||||
require (
|
||||
github.com/iden3/go-iden3-crypto v0.0.6-0.20210308142348-8f85683b2cef
|
||||
github.com/iden3/go-merkletree v0.0.0-20210308143313-8b63ca866189
|
||||
github.com/stretchr/testify v1.7.0
|
||||
)
|
||||
|
||||
352
tree.go
Normal file
352
tree.go
Normal file
@@ -0,0 +1,352 @@
|
||||
/*
|
||||
Package arbo implements a Merkle Tree compatible with the circomlib
|
||||
implementation of the MerkleTree (when using the Poseidon hash function),
|
||||
following the specification from
|
||||
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
|
||||
https://eprint.iacr.org/2018/955.
|
||||
|
||||
Also allows to define which hash function to use. So for example, when working
|
||||
with zkSnarks the Poseidon hash function can be used, but when not, it can be
|
||||
used the Blake3 hash function, which improves the computation time.
|
||||
*/
|
||||
package arbo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/iden3/go-merkletree/db"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrefixValueLen defines the bytes-prefix length used for the Value
|
||||
// bytes representation stored in the db
|
||||
PrefixValueLen = 2
|
||||
|
||||
// PrefixValueEmpty is used for the first byte of a Value to indicate
|
||||
// that is an Empty value
|
||||
PrefixValueEmpty = 0
|
||||
// PrefixValueLeaf is used for the first byte of a Value to indicate
|
||||
// that is a Leaf value
|
||||
PrefixValueLeaf = 1
|
||||
// PrefixValueIntermediate is used for the first byte of a Value to
|
||||
// indicate that is a Intermediate value
|
||||
PrefixValueIntermediate = 2
|
||||
)
|
||||
|
||||
var (
|
||||
dbKeyRoot = []byte("root")
|
||||
emptyValue = []byte{0}
|
||||
)
|
||||
|
||||
// Tree defines the struct that implements the MerkleTree functionalities
|
||||
type Tree struct {
|
||||
db db.Storage
|
||||
lastAccess int64 // in unix time
|
||||
maxLevels int
|
||||
root []byte
|
||||
|
||||
hashFunction HashFunction
|
||||
}
|
||||
|
||||
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
|
||||
// will load it.
|
||||
func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
|
||||
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
|
||||
|
||||
t.updateAccessTime()
|
||||
root, err := t.db.Get(dbKeyRoot)
|
||||
if err == db.ErrNotFound {
|
||||
// store new root 0
|
||||
tx, err := t.db.NewTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.root = make([]byte, t.hashFunction.Len()) // empty
|
||||
err = tx.Put(dbKeyRoot, t.root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, err
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.root = root
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (t *Tree) updateAccessTime() {
|
||||
atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
|
||||
}
|
||||
|
||||
// LastAccess returns the last access timestamp in Unixtime
|
||||
func (t *Tree) LastAccess() int64 {
|
||||
return atomic.LoadInt64(&t.lastAccess)
|
||||
}
|
||||
|
||||
// Root returns the root of the Tree
|
||||
func (t *Tree) Root() []byte {
|
||||
return t.root
|
||||
}
|
||||
|
||||
// AddBatch adds a batch of key-values to the Tree. This method is optimized to
|
||||
// do some internal parallelization. Returns an array containing the indexes of
|
||||
// the keys failed to add.
|
||||
func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// Add inserts the key-value into the Tree.
|
||||
// If the inputs come from a *big.Int, is expected that are represented by a
|
||||
// Little-Endian byte array (for circom compatibility).
|
||||
func (t *Tree) Add(k, v []byte) error {
|
||||
// TODO check validity of key & value (for the Tree.HashFunction type)
|
||||
|
||||
keyPath := make([]byte, t.hashFunction.Len())
|
||||
copy(keyPath[:], k)
|
||||
|
||||
path := getPath(t.maxLevels, keyPath)
|
||||
// go down to the leaf
|
||||
var siblings [][]byte
|
||||
_, _, siblings, err := t.down(k, t.root, siblings, path, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
leafKey, leafValue, err := t.newLeafValue(k, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := t.db.NewTx()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Put(leafKey, leafValue); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// go up to the root
|
||||
if len(siblings) == 0 {
|
||||
t.root = leafKey
|
||||
return tx.Commit()
|
||||
}
|
||||
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.root = root
|
||||
// store root to db
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// down goes down to the leaf recursively
|
||||
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) (
|
||||
[]byte, []byte, [][]byte, error) {
|
||||
if l > t.maxLevels-1 {
|
||||
return nil, nil, nil, fmt.Errorf("max level")
|
||||
}
|
||||
var err error
|
||||
var currValue []byte
|
||||
emptyKey := make([]byte, t.hashFunction.Len())
|
||||
if bytes.Equal(currKey, emptyKey) {
|
||||
// empty value
|
||||
return currKey, emptyValue, siblings, nil
|
||||
}
|
||||
currValue, err = t.db.Get(currKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
switch currValue[0] {
|
||||
case PrefixValueEmpty: // empty
|
||||
// TODO WIP WARNING should not be reached, as the 'if' above should avoid
|
||||
// reaching this point
|
||||
// return currKey, empty, siblings, nil
|
||||
panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
|
||||
case PrefixValueLeaf: // leaf
|
||||
if bytes.Equal(newKey, currKey) {
|
||||
return nil, nil, nil, fmt.Errorf("key already exists")
|
||||
}
|
||||
|
||||
if !bytes.Equal(currValue, emptyValue) {
|
||||
oldLeafKey, _ := readLeafValue(currValue)
|
||||
oldLeafKeyFull := make([]byte, t.hashFunction.Len())
|
||||
copy(oldLeafKeyFull[:], oldLeafKey)
|
||||
|
||||
// if currKey is already used, go down until paths diverge
|
||||
oldPath := getPath(t.maxLevels, oldLeafKeyFull)
|
||||
siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
return currKey, currValue, siblings, nil
|
||||
case PrefixValueIntermediate: // intermediate
|
||||
if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
|
||||
return nil, nil, nil,
|
||||
fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
|
||||
PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
|
||||
}
|
||||
// collect siblings while going down
|
||||
if path[l] {
|
||||
// right
|
||||
lChild, rChild := readIntermediateChilds(currValue)
|
||||
siblings = append(siblings, lChild)
|
||||
return t.down(newKey, rChild, siblings, path, l+1)
|
||||
}
|
||||
// left
|
||||
lChild, rChild := readIntermediateChilds(currValue)
|
||||
siblings = append(siblings, rChild)
|
||||
return t.down(newKey, lChild, siblings, path, l+1)
|
||||
default:
|
||||
return nil, nil, nil, fmt.Errorf("invalid value")
|
||||
}
|
||||
}
|
||||
|
||||
// downVirtually is used when in a leaf already exists, and a new leaf which
|
||||
// shares the path until the existing leaf is being added
|
||||
func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
|
||||
newPath []bool, l int) ([][]byte, error) {
|
||||
var err error
|
||||
if l > t.maxLevels-1 {
|
||||
return nil, fmt.Errorf("max virtual level %d", l)
|
||||
}
|
||||
|
||||
if oldPath[l] == newPath[l] {
|
||||
emptyKey := make([]byte, t.hashFunction.Len()) // empty
|
||||
siblings = append(siblings, emptyKey)
|
||||
|
||||
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return siblings, nil
|
||||
}
|
||||
// reached the divergence
|
||||
siblings = append(siblings, oldKey)
|
||||
|
||||
return siblings, nil
|
||||
}
|
||||
|
||||
// up goes up recursively updating the intermediate nodes
|
||||
func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
|
||||
var k, v []byte
|
||||
var err error
|
||||
if path[l] {
|
||||
k, v, err = t.newIntermediate(siblings[l], key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
k, v, err = t.newIntermediate(key, siblings[l])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// store k-v to db
|
||||
err = tx.Put(k, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if l == 0 {
|
||||
// reached the root
|
||||
return k, nil
|
||||
}
|
||||
|
||||
return t.up(tx, k, siblings, path, l-1)
|
||||
}
|
||||
|
||||
func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
|
||||
leafKey, err := t.hashFunction.Hash(k, v, []byte{1})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var leafValue []byte
|
||||
leafValue = append(leafValue, byte(1))
|
||||
leafValue = append(leafValue, byte(len(k)))
|
||||
leafValue = append(leafValue, k...)
|
||||
leafValue = append(leafValue, v...)
|
||||
return leafKey, leafValue, nil
|
||||
}
|
||||
|
||||
func readLeafValue(b []byte) ([]byte, []byte) {
|
||||
if len(b) < PrefixValueLen {
|
||||
return []byte{}, []byte{}
|
||||
}
|
||||
|
||||
kLen := b[1]
|
||||
if len(b) < PrefixValueLen+int(kLen) {
|
||||
return []byte{}, []byte{}
|
||||
}
|
||||
k := b[PrefixValueLen : PrefixValueLen+kLen]
|
||||
v := b[PrefixValueLen+kLen:]
|
||||
return k, v
|
||||
}
|
||||
|
||||
func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
|
||||
b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2)
|
||||
b[0] = 2
|
||||
b[1] = byte(len(l))
|
||||
copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l)
|
||||
copy(b[PrefixValueLen+t.hashFunction.Len():], r)
|
||||
|
||||
key, err := t.hashFunction.Hash(l, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return key, b, nil
|
||||
}
|
||||
|
||||
func readIntermediateChilds(b []byte) ([]byte, []byte) {
|
||||
if len(b) < PrefixValueLen {
|
||||
return []byte{}, []byte{}
|
||||
}
|
||||
|
||||
lLen := b[1]
|
||||
if len(b) < PrefixValueLen+int(lLen) {
|
||||
return []byte{}, []byte{}
|
||||
}
|
||||
l := b[PrefixValueLen : PrefixValueLen+lLen]
|
||||
r := b[PrefixValueLen+lLen:]
|
||||
return l, r
|
||||
}
|
||||
|
||||
func getPath(numLevels int, k []byte) []bool {
|
||||
path := make([]bool, numLevels)
|
||||
for n := 0; n < numLevels; n++ {
|
||||
path[n] = k[n/8]&(1<<(n%8)) != 0
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// GenProof generates a MerkleTree proof for the given key. If the key exists in
|
||||
// the Tree, the proof will be of existence, if the key does not exist in the
|
||||
// tree, the proof will be of non-existence.
|
||||
func (t *Tree) GenProof(k, v []byte) ([]byte, error) {
|
||||
// unimplemented
|
||||
return nil, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// Get returns the value for a given key
|
||||
func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
|
||||
// unimplemented
|
||||
return nil, nil, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// CheckProof verifies the given proof
|
||||
func CheckProof(k, v, root, mproof []byte) (bool, error) {
|
||||
// unimplemented
|
||||
return false, fmt.Errorf("unimplemented")
|
||||
}
|
||||
|
||||
// TODO method to export & import the full Tree without values
|
||||
176
tree_test.go
Normal file
176
tree_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package arbo
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/iden3/go-merkletree/db/memory"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddTestVectors(t *testing.T) {
|
||||
// Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
|
||||
testVectorsPoseidon := []string{
|
||||
"0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"13578938674299138072471463694055224830892726234048532520316387704878000008795",
|
||||
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
|
||||
"14204494359367183802864593755198662203838502594566452929175967972147978322084",
|
||||
}
|
||||
testAdd(t, HashFunctionPoseidon, testVectorsPoseidon)
|
||||
|
||||
testVectorsSha256 := []string{
|
||||
"0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"46910109172468462938850740851377282682950237270676610513794735904325820156367",
|
||||
"59481735341404520835410489183267411392292882901306595567679529387376287440550",
|
||||
"20573794434149960984975763118181266662429997821552560184909083010514790081771",
|
||||
}
|
||||
testAdd(t, HashFunctionSha256, testVectorsSha256)
|
||||
}
|
||||
|
||||
func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
|
||||
assert.Nil(t, err)
|
||||
defer tree.db.Close()
|
||||
assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root()))
|
||||
|
||||
err = tree.Add(
|
||||
BigIntToBytes(big.NewInt(1)),
|
||||
BigIntToBytes(big.NewInt(2)))
|
||||
assert.Nil(t, err)
|
||||
rootBI := BytesToBigInt(tree.Root())
|
||||
assert.Equal(t, testVectors[1], rootBI.String())
|
||||
|
||||
err = tree.Add(
|
||||
BigIntToBytes(big.NewInt(33)),
|
||||
BigIntToBytes(big.NewInt(44)))
|
||||
assert.Nil(t, err)
|
||||
rootBI = BytesToBigInt(tree.Root())
|
||||
assert.Equal(t, testVectors[2], rootBI.String())
|
||||
|
||||
err = tree.Add(
|
||||
BigIntToBytes(big.NewInt(1234)),
|
||||
BigIntToBytes(big.NewInt(9876)))
|
||||
assert.Nil(t, err)
|
||||
rootBI = BytesToBigInt(tree.Root())
|
||||
assert.Equal(t, testVectors[3], rootBI.String())
|
||||
}
|
||||
|
||||
func TestAdd1000(t *testing.T) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
|
||||
defer tree.db.Close()
|
||||
for i := 0; i < 1000; i++ {
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(0))
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
rootBI := BytesToBigInt(tree.Root())
|
||||
assert.Equal(t,
|
||||
"296519252211642170490407814696803112091039265640052570497930797516015811235",
|
||||
rootBI.String())
|
||||
}
|
||||
|
||||
func TestAddDifferentOrder(t *testing.T) {
|
||||
tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
|
||||
defer tree1.db.Close()
|
||||
for i := 0; i < 16; i++ {
|
||||
k := SwapEndianness(big.NewInt(int64(i)).Bytes())
|
||||
v := SwapEndianness(big.NewInt(0).Bytes())
|
||||
if err := tree1.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
defer tree2.db.Close()
|
||||
for i := 16 - 1; i >= 0; i-- {
|
||||
k := big.NewInt(int64(i)).Bytes()
|
||||
v := big.NewInt(0).Bytes()
|
||||
if err := tree2.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root()))
|
||||
assert.Equal(t,
|
||||
"3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
|
||||
hex.EncodeToString(tree1.Root()))
|
||||
}
|
||||
|
||||
func TestAddRepeatedIndex(t *testing.T) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
defer tree.db.Close()
|
||||
k := big.NewInt(int64(3)).Bytes()
|
||||
v := big.NewInt(int64(12)).Bytes()
|
||||
if err := tree.Add(k, v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = tree.Add(k, v)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, fmt.Errorf("max virtual level 100"), err)
|
||||
}
|
||||
|
||||
func TestAux(t *testing.T) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
|
||||
require.Nil(t, err)
|
||||
defer tree.db.Close()
|
||||
k := BigIntToBytes(big.NewInt(int64(1)))
|
||||
v := BigIntToBytes(big.NewInt(int64(0)))
|
||||
err = tree.Add(k, v)
|
||||
assert.Nil(t, err)
|
||||
k = BigIntToBytes(big.NewInt(int64(256)))
|
||||
err = tree.Add(k, v)
|
||||
assert.Nil(t, err)
|
||||
|
||||
k = BigIntToBytes(big.NewInt(int64(257)))
|
||||
err = tree.Add(k, v)
|
||||
assert.Nil(t, err)
|
||||
|
||||
k = BigIntToBytes(big.NewInt(int64(515)))
|
||||
err = tree.Add(k, v)
|
||||
assert.Nil(t, err)
|
||||
k = BigIntToBytes(big.NewInt(int64(770)))
|
||||
err = tree.Add(k, v)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func BenchmarkAdd(b *testing.B) {
|
||||
// prepare inputs
|
||||
var ks, vs [][]byte
|
||||
for i := 0; i < 1000; i++ {
|
||||
k := BigIntToBytes(big.NewInt(int64(i)))
|
||||
v := BigIntToBytes(big.NewInt(int64(i)))
|
||||
ks = append(ks, k)
|
||||
vs = append(vs, v)
|
||||
}
|
||||
|
||||
b.Run("Poseidon", func(b *testing.B) {
|
||||
benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
|
||||
})
|
||||
b.Run("Sha256", func(b *testing.B) {
|
||||
benchmarkAdd(b, HashFunctionSha256, ks, vs)
|
||||
})
|
||||
}
|
||||
|
||||
func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
|
||||
tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
|
||||
require.Nil(b, err)
|
||||
|
||||
defer tree.db.Close()
|
||||
for i := 0; i < len(ks); i++ {
|
||||
if err := tree.Add(ks[i], vs[i]); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user