Browse Source

Add Tree.Add compatible with circomlib

master
arnaucube 3 years ago
parent
commit
43cb6041c9
3 changed files with 529 additions and 0 deletions
  1. +1
    -0
      go.mod
  2. +352
    -0
      tree.go
  3. +176
    -0
      tree_test.go

+ 1
- 0
go.mod

@ -4,5 +4,6 @@ go 1.14
require ( require (
github.com/iden3/go-iden3-crypto v0.0.6-0.20210308142348-8f85683b2cef github.com/iden3/go-iden3-crypto v0.0.6-0.20210308142348-8f85683b2cef
github.com/iden3/go-merkletree v0.0.0-20210308143313-8b63ca866189
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
) )

+ 352
- 0
tree.go

@ -0,0 +1,352 @@
/*
Package arbo implements a Merkle Tree compatible with the circomlib
implementation of the MerkleTree (when using the Poseidon hash function),
following the specification from
https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and
https://eprint.iacr.org/2018/955.
Also allows to define which hash function to use. So for example, when working
with zkSnarks the Poseidon hash function can be used, but when not, it can be
used the Blake3 hash function, which improves the computation time.
*/
package arbo
import (
"bytes"
"fmt"
"sync/atomic"
"time"
"github.com/iden3/go-merkletree/db"
)
const (
// PrefixValueLen defines the bytes-prefix length used for the Value
// bytes representation stored in the db
PrefixValueLen = 2
// PrefixValueEmpty is used for the first byte of a Value to indicate
// that is an Empty value
PrefixValueEmpty = 0
// PrefixValueLeaf is used for the first byte of a Value to indicate
// that is a Leaf value
PrefixValueLeaf = 1
// PrefixValueIntermediate is used for the first byte of a Value to
// indicate that is a Intermediate value
PrefixValueIntermediate = 2
)
var (
dbKeyRoot = []byte("root")
emptyValue = []byte{0}
)
// Tree defines the struct that implements the MerkleTree functionalities
type Tree struct {
db db.Storage
lastAccess int64 // in unix time
maxLevels int
root []byte
hashFunction HashFunction
}
// NewTree returns a new Tree, if there is a Tree still in the given storage, it
// will load it.
func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
t.updateAccessTime()
root, err := t.db.Get(dbKeyRoot)
if err == db.ErrNotFound {
// store new root 0
tx, err := t.db.NewTx()
if err != nil {
return nil, err
}
t.root = make([]byte, t.hashFunction.Len()) // empty
err = tx.Put(dbKeyRoot, t.root)
if err != nil {
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, err
}
return &t, err
} else if err != nil {
return nil, err
}
t.root = root
return &t, nil
}
func (t *Tree) updateAccessTime() {
atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
}
// LastAccess returns the last access timestamp in Unixtime
func (t *Tree) LastAccess() int64 {
return atomic.LoadInt64(&t.lastAccess)
}
// Root returns the root of the Tree
func (t *Tree) Root() []byte {
return t.root
}
// AddBatch adds a batch of key-values to the Tree. This method is optimized to
// do some internal parallelization. Returns an array containing the indexes of
// the keys failed to add.
func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
return nil, fmt.Errorf("unimplemented")
}
// Add inserts the key-value into the Tree.
// If the inputs come from a *big.Int, is expected that are represented by a
// Little-Endian byte array (for circom compatibility).
func (t *Tree) Add(k, v []byte) error {
// TODO check validity of key & value (for the Tree.HashFunction type)
keyPath := make([]byte, t.hashFunction.Len())
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, _, siblings, err := t.down(k, t.root, siblings, path, 0)
if err != nil {
return err
}
leafKey, leafValue, err := t.newLeafValue(k, v)
if err != nil {
return err
}
tx, err := t.db.NewTx()
if err != nil {
return err
}
if err := tx.Put(leafKey, leafValue); err != nil {
return err
}
// go up to the root
if len(siblings) == 0 {
t.root = leafKey
return tx.Commit()
}
root, err := t.up(tx, leafKey, siblings, path, len(siblings)-1)
if err != nil {
return err
}
t.root = root
// store root to db
return tx.Commit()
}
// down goes down to the leaf recursively
func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) (
[]byte, []byte, [][]byte, error) {
if l > t.maxLevels-1 {
return nil, nil, nil, fmt.Errorf("max level")
}
var err error
var currValue []byte
emptyKey := make([]byte, t.hashFunction.Len())
if bytes.Equal(currKey, emptyKey) {
// empty value
return currKey, emptyValue, siblings, nil
}
currValue, err = t.db.Get(currKey)
if err != nil {
return nil, nil, nil, err
}
switch currValue[0] {
case PrefixValueEmpty: // empty
// TODO WIP WARNING should not be reached, as the 'if' above should avoid
// reaching this point
// return currKey, empty, siblings, nil
panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
case PrefixValueLeaf: // leaf
if bytes.Equal(newKey, currKey) {
return nil, nil, nil, fmt.Errorf("key already exists")
}
if !bytes.Equal(currValue, emptyValue) {
oldLeafKey, _ := readLeafValue(currValue)
oldLeafKeyFull := make([]byte, t.hashFunction.Len())
copy(oldLeafKeyFull[:], oldLeafKey)
// if currKey is already used, go down until paths diverge
oldPath := getPath(t.maxLevels, oldLeafKeyFull)
siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
if err != nil {
return nil, nil, nil, err
}
}
return currKey, currValue, siblings, nil
case PrefixValueIntermediate: // intermediate
if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
return nil, nil, nil,
fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
}
// collect siblings while going down
if path[l] {
// right
lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, lChild)
return t.down(newKey, rChild, siblings, path, l+1)
}
// left
lChild, rChild := readIntermediateChilds(currValue)
siblings = append(siblings, rChild)
return t.down(newKey, lChild, siblings, path, l+1)
default:
return nil, nil, nil, fmt.Errorf("invalid value")
}
}
// downVirtually is used when in a leaf already exists, and a new leaf which
// shares the path until the existing leaf is being added
func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
newPath []bool, l int) ([][]byte, error) {
var err error
if l > t.maxLevels-1 {
return nil, fmt.Errorf("max virtual level %d", l)
}
if oldPath[l] == newPath[l] {
emptyKey := make([]byte, t.hashFunction.Len()) // empty
siblings = append(siblings, emptyKey)
siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
if err != nil {
return nil, err
}
return siblings, nil
}
// reached the divergence
siblings = append(siblings, oldKey)
return siblings, nil
}
// up goes up recursively updating the intermediate nodes
func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
var k, v []byte
var err error
if path[l] {
k, v, err = t.newIntermediate(siblings[l], key)
if err != nil {
return nil, err
}
} else {
k, v, err = t.newIntermediate(key, siblings[l])
if err != nil {
return nil, err
}
}
// store k-v to db
err = tx.Put(k, v)
if err != nil {
return nil, err
}
if l == 0 {
// reached the root
return k, nil
}
return t.up(tx, k, siblings, path, l-1)
}
func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
leafKey, err := t.hashFunction.Hash(k, v, []byte{1})
if err != nil {
return nil, nil, err
}
var leafValue []byte
leafValue = append(leafValue, byte(1))
leafValue = append(leafValue, byte(len(k)))
leafValue = append(leafValue, k...)
leafValue = append(leafValue, v...)
return leafKey, leafValue, nil
}
func readLeafValue(b []byte) ([]byte, []byte) {
if len(b) < PrefixValueLen {
return []byte{}, []byte{}
}
kLen := b[1]
if len(b) < PrefixValueLen+int(kLen) {
return []byte{}, []byte{}
}
k := b[PrefixValueLen : PrefixValueLen+kLen]
v := b[PrefixValueLen+kLen:]
return k, v
}
func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2)
b[0] = 2
b[1] = byte(len(l))
copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l)
copy(b[PrefixValueLen+t.hashFunction.Len():], r)
key, err := t.hashFunction.Hash(l, r)
if err != nil {
return nil, nil, err
}
return key, b, nil
}
func readIntermediateChilds(b []byte) ([]byte, []byte) {
if len(b) < PrefixValueLen {
return []byte{}, []byte{}
}
lLen := b[1]
if len(b) < PrefixValueLen+int(lLen) {
return []byte{}, []byte{}
}
l := b[PrefixValueLen : PrefixValueLen+lLen]
r := b[PrefixValueLen+lLen:]
return l, r
}
func getPath(numLevels int, k []byte) []bool {
path := make([]bool, numLevels)
for n := 0; n < numLevels; n++ {
path[n] = k[n/8]&(1<<(n%8)) != 0
}
return path
}
// GenProof generates a MerkleTree proof for the given key. If the key exists in
// the Tree, the proof will be of existence, if the key does not exist in the
// tree, the proof will be of non-existence.
func (t *Tree) GenProof(k, v []byte) ([]byte, error) {
// unimplemented
return nil, fmt.Errorf("unimplemented")
}
// Get returns the value for a given key
func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
// unimplemented
return nil, nil, fmt.Errorf("unimplemented")
}
// CheckProof verifies the given proof
func CheckProof(k, v, root, mproof []byte) (bool, error) {
// unimplemented
return false, fmt.Errorf("unimplemented")
}
// TODO method to export & import the full Tree without values

+ 176
- 0
tree_test.go

@ -0,0 +1,176 @@
package arbo
import (
"encoding/hex"
"fmt"
"math/big"
"testing"
"github.com/iden3/go-merkletree/db/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddTestVectors(t *testing.T) {
// Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
testVectorsPoseidon := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"13578938674299138072471463694055224830892726234048532520316387704878000008795",
"5412393676474193513566895793055462193090331607895808993925969873307089394741",
"14204494359367183802864593755198662203838502594566452929175967972147978322084",
}
testAdd(t, HashFunctionPoseidon, testVectorsPoseidon)
testVectorsSha256 := []string{
"0000000000000000000000000000000000000000000000000000000000000000",
"46910109172468462938850740851377282682950237270676610513794735904325820156367",
"59481735341404520835410489183267411392292882901306595567679529387376287440550",
"20573794434149960984975763118181266662429997821552560184909083010514790081771",
}
testAdd(t, HashFunctionSha256, testVectorsSha256)
}
func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
assert.Nil(t, err)
defer tree.db.Close()
assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root()))
err = tree.Add(
BigIntToBytes(big.NewInt(1)),
BigIntToBytes(big.NewInt(2)))
assert.Nil(t, err)
rootBI := BytesToBigInt(tree.Root())
assert.Equal(t, testVectors[1], rootBI.String())
err = tree.Add(
BigIntToBytes(big.NewInt(33)),
BigIntToBytes(big.NewInt(44)))
assert.Nil(t, err)
rootBI = BytesToBigInt(tree.Root())
assert.Equal(t, testVectors[2], rootBI.String())
err = tree.Add(
BigIntToBytes(big.NewInt(1234)),
BigIntToBytes(big.NewInt(9876)))
assert.Nil(t, err)
rootBI = BytesToBigInt(tree.Root())
assert.Equal(t, testVectors[3], rootBI.String())
}
func TestAdd1000(t *testing.T) {
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree.db.Close()
for i := 0; i < 1000; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(0))
if err := tree.Add(k, v); err != nil {
t.Fatal(err)
}
}
rootBI := BytesToBigInt(tree.Root())
assert.Equal(t,
"296519252211642170490407814696803112091039265640052570497930797516015811235",
rootBI.String())
}
func TestAddDifferentOrder(t *testing.T) {
tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree1.db.Close()
for i := 0; i < 16; i++ {
k := SwapEndianness(big.NewInt(int64(i)).Bytes())
v := SwapEndianness(big.NewInt(0).Bytes())
if err := tree1.Add(k, v); err != nil {
t.Fatal(err)
}
}
tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree2.db.Close()
for i := 16 - 1; i >= 0; i-- {
k := big.NewInt(int64(i)).Bytes()
v := big.NewInt(0).Bytes()
if err := tree2.Add(k, v); err != nil {
t.Fatal(err)
}
}
assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root()))
assert.Equal(t,
"3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
hex.EncodeToString(tree1.Root()))
}
func TestAddRepeatedIndex(t *testing.T) {
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree.db.Close()
k := big.NewInt(int64(3)).Bytes()
v := big.NewInt(int64(12)).Bytes()
if err := tree.Add(k, v); err != nil {
t.Fatal(err)
}
err = tree.Add(k, v)
assert.NotNil(t, err)
assert.Equal(t, fmt.Errorf("max virtual level 100"), err)
}
func TestAux(t *testing.T) {
tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
require.Nil(t, err)
defer tree.db.Close()
k := BigIntToBytes(big.NewInt(int64(1)))
v := BigIntToBytes(big.NewInt(int64(0)))
err = tree.Add(k, v)
assert.Nil(t, err)
k = BigIntToBytes(big.NewInt(int64(256)))
err = tree.Add(k, v)
assert.Nil(t, err)
k = BigIntToBytes(big.NewInt(int64(257)))
err = tree.Add(k, v)
assert.Nil(t, err)
k = BigIntToBytes(big.NewInt(int64(515)))
err = tree.Add(k, v)
assert.Nil(t, err)
k = BigIntToBytes(big.NewInt(int64(770)))
err = tree.Add(k, v)
assert.Nil(t, err)
}
func BenchmarkAdd(b *testing.B) {
// prepare inputs
var ks, vs [][]byte
for i := 0; i < 1000; i++ {
k := BigIntToBytes(big.NewInt(int64(i)))
v := BigIntToBytes(big.NewInt(int64(i)))
ks = append(ks, k)
vs = append(vs, v)
}
b.Run("Poseidon", func(b *testing.B) {
benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
})
b.Run("Sha256", func(b *testing.B) {
benchmarkAdd(b, HashFunctionSha256, ks, vs)
})
}
func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
require.Nil(b, err)
defer tree.db.Close()
for i := 0; i < len(ks); i++ {
if err := tree.Add(ks[i], vs[i]); err != nil {
b.Fatal(err)
}
}
}

Loading…
Cancel
Save