From 43cb6041c9802097607b8bf867804399339e8f2c Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 30 Mar 2021 22:37:56 +0200 Subject: [PATCH] Add Tree.Add compatible with circomlib --- go.mod | 1 + tree.go | 352 +++++++++++++++++++++++++++++++++++++++++++++++++++ tree_test.go | 176 ++++++++++++++++++++++++++ 3 files changed, 529 insertions(+) create mode 100644 tree.go create mode 100644 tree_test.go diff --git a/go.mod b/go.mod index acc0a1b..f2b66dc 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.14 require ( github.com/iden3/go-iden3-crypto v0.0.6-0.20210308142348-8f85683b2cef + github.com/iden3/go-merkletree v0.0.0-20210308143313-8b63ca866189 github.com/stretchr/testify v1.7.0 ) diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..7df1f91 --- /dev/null +++ b/tree.go @@ -0,0 +1,352 @@ +/* +Package arbo implements a Merkle Tree compatible with the circomlib +implementation of the MerkleTree (when using the Poseidon hash function), +following the specification from +https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf and +https://eprint.iacr.org/2018/955. + +It also allows choosing which hash function to use. For example, when working +with zkSNARKs the Poseidon hash function can be used; otherwise the Blake3 hash +function can be used instead, which improves the computation time. 
+*/
+package arbo
+
+import (
+	"bytes"
+	"fmt"
+	"sync/atomic"
+	"time"
+
+	"github.com/iden3/go-merkletree/db"
+)
+
+const (
+	// PrefixValueLen defines the bytes-prefix length used for the Value
+	// bytes representation stored in the db. The first prefix byte holds
+	// the value type (one of the PrefixValue* constants below); the second
+	// holds a length: the key length for a leaf, the left child length for
+	// an intermediate node.
+	PrefixValueLen = 2
+
+	// PrefixValueEmpty is used for the first byte of a Value to indicate
+	// that is an Empty value
+	PrefixValueEmpty = 0
+	// PrefixValueLeaf is used for the first byte of a Value to indicate
+	// that is a Leaf value
+	PrefixValueLeaf = 1
+	// PrefixValueIntermediate is used for the first byte of a Value to
+	// indicate that is a Intermediate value
+	PrefixValueIntermediate = 2
+)
+
+var (
+	// dbKeyRoot is the db key under which NewTree reads and initialises
+	// the root of the Tree
+	dbKeyRoot = []byte("root")
+	// emptyValue is the Value reported for an empty node: a single
+	// PrefixValueEmpty byte
+	emptyValue = []byte{0}
+)
+
+// Tree defines the struct that implements the MerkleTree functionalities
+type Tree struct {
+	// db is the key-value storage holding the nodes and the root
+	db db.Storage
+	// lastAccess is read and written through sync/atomic (see
+	// updateAccessTime and LastAccess)
+	lastAccess int64 // in unix time
+	// maxLevels bounds the depth of the tree; it is also the number of key
+	// bits used to build a path
+	maxLevels int
+	// root is the key of the current root node; an all-zero key of
+	// hashFunction.Len() bytes denotes the empty tree
+	root []byte
+
+	// hashFunction computes the node keys and fixes their byte length
+	hashFunction HashFunction
+}
+
+// NewTree returns a new Tree, if there is a Tree still in the given storage, it
+// will load it.
+func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) {
+	// Add builds the key path with getPath, which reads one bit per level
+	// out of a buffer of hash.Len() bytes. With more levels than bits in
+	// that buffer every Add would panic with an index out of range, so
+	// that configuration is rejected here instead.
+	if maxBits := hash.Len() * 8; maxLevels > maxBits {
+		return nil, fmt.Errorf("maxLevels (%d) is bigger than the hash length in bits (%d)",
+			maxLevels, maxBits)
+	}
+
+	t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash}
+
+	t.updateAccessTime()
+	root, err := t.db.Get(dbKeyRoot)
+	if err == db.ErrNotFound {
+		// no root in the storage yet: this is a new Tree, so store the
+		// empty root (all zeroes) under dbKeyRoot
+		tx, err := t.db.NewTx()
+		if err != nil {
+			return nil, err
+		}
+		t.root = make([]byte, t.hashFunction.Len()) // empty
+		err = tx.Put(dbKeyRoot, t.root)
+		if err != nil {
+			return nil, err
+		}
+		err = tx.Commit()
+		if err != nil {
+			return nil, err
+		}
+		return &t, nil
+	} else if err != nil {
+		return nil, err
+	}
+	// a root was found in the storage: continue from the existing Tree
+	t.root = root
+	return &t, nil
+}
+
+// updateAccessTime sets lastAccess to the current Unix time (in seconds),
+// using an atomic store so that it can be read concurrently by LastAccess.
+func (t *Tree) updateAccessTime() {
+	atomic.StoreInt64(&t.lastAccess, time.Now().Unix())
+}
+
+// LastAccess returns the last access timestamp in Unixtime
+func (t *Tree) LastAccess() int64 {
+	return atomic.LoadInt64(&t.lastAccess)
+}
+
+// Root returns the root of the Tree. The returned slice is the one held by
+// the Tree, not a copy, so callers must not modify it.
+func (t *Tree) Root() []byte {
+	return t.root
+}
+
+// AddBatch adds a batch of key-values to the Tree. This method is optimized to
+// do some internal parallelization. Returns an array containing the indexes of
+// the keys failed to add.
+//
+// Not implemented yet: it always returns an "unimplemented" error.
+func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) {
+	return nil, fmt.Errorf("unimplemented")
+}
+
+// Add inserts the key-value into the Tree.
+// If the inputs come from a *big.Int, is expected that are represented by a
+// Little-Endian byte array (for circom compatibility).
+func (t *Tree) Add(k, v []byte) error {
+	// TODO check validity of key & value (for the Tree.HashFunction type)
+
+	// the path is taken from the key padded (or truncated) to the hash length
+	keyPath := make([]byte, t.hashFunction.Len())
+	copy(keyPath[:], k)
+
+	path := getPath(t.maxLevels, keyPath)
+	// go down to the leaf
+	var siblings [][]byte
+	_, _, siblings, err := t.down(k, t.root, siblings, path, 0)
+	if err != nil {
+		return err
+	}
+
+	leafKey, leafValue, err := t.newLeafValue(k, v)
+	if err != nil {
+		return err
+	}
+
+	tx, err := t.db.NewTx()
+	if err != nil {
+		return err
+	}
+	if err := tx.Put(leafKey, leafValue); err != nil {
+		return err
+	}
+
+	// go up to the root; with no siblings the new leaf is the root itself
+	root := leafKey
+	if len(siblings) > 0 {
+		root, err = t.up(tx, leafKey, siblings, path, len(siblings)-1)
+		if err != nil {
+			return err
+		}
+	}
+
+	// store root to db, in the same tx as the nodes, so that NewTree finds
+	// the updated root when the storage is opened again
+	if err := tx.Put(dbKeyRoot, root); err != nil {
+		return err
+	}
+	if err := tx.Commit(); err != nil {
+		return err
+	}
+	// only move the in-memory root once the tx has been committed
+	t.root = root
+	return nil
+}
+
+// down goes down to the leaf recursively. newKey is the key being added,
+// currKey the key of the node at level l, and path the bits of newKey. It
+// returns the key and value of the node where the descent stopped, together
+// with the siblings collected on the way (one per level, from the root down).
+func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, l int) (
+	[]byte, []byte, [][]byte, error) {
+	if l > t.maxLevels-1 {
+		return nil, nil, nil, fmt.Errorf("max level")
+	}
+	var err error
+	var currValue []byte
+	emptyKey := make([]byte, t.hashFunction.Len())
+	if bytes.Equal(currKey, emptyKey) {
+		// empty value
+		return currKey, emptyValue, siblings, nil
+	}
+	currValue, err = t.db.Get(currKey)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	if len(currValue) == 0 {
+		// a stored value always starts with its type prefix; without it
+		// the switch below would index out of range
+		return nil, nil, nil, fmt.Errorf("invalid value")
+	}
+
+	switch currValue[0] {
+	case PrefixValueEmpty: // empty
+		// TODO WIP WARNING should not be reached, as the 'if' above should avoid
+		// reaching this point
+		// return currKey, empty, siblings, nil
+		panic("should not be reached, as the 'if' above should avoid reaching this point") // TMP
+	case PrefixValueLeaf: // leaf
+		// NOTE(review): newKey is the raw key received by Add, while
+		// currKey is a node key (a hash), so this comparison does not
+		// look able to detect a repeated key. A repeated key is instead
+		// rejected by downVirtually with a "max virtual level" error, as
+		// TestAddRepeatedIndex expects. Comparing against oldLeafKey
+		// would need that test to be updated too.
+		if bytes.Equal(newKey, currKey) {
+			return nil, nil, nil, fmt.Errorf("key already exists")
+		}
+
+		if !bytes.Equal(currValue, emptyValue) {
+			oldLeafKey, _ := readLeafValue(currValue)
+			oldLeafKeyFull := make([]byte, t.hashFunction.Len())
+			copy(oldLeafKeyFull[:], oldLeafKey)
+
+			// if currKey is already used, go down until paths diverge
+			oldPath := getPath(t.maxLevels, oldLeafKeyFull)
+			siblings, err = t.downVirtually(siblings, currKey, newKey, oldPath, path, l)
+			if err != nil {
+				return nil, nil, nil, err
+			}
+		}
+		return currKey, currValue, siblings, nil
+	case PrefixValueIntermediate: // intermediate
+		if len(currValue) != PrefixValueLen+t.hashFunction.Len()*2 {
+			return nil, nil, nil,
+				fmt.Errorf("intermediate value invalid length (expected: %d, actual: %d)",
+					PrefixValueLen+t.hashFunction.Len()*2, len(currValue))
+		}
+		// collect siblings while going down
+		if path[l] {
+			// right
+			lChild, rChild := readIntermediateChilds(currValue)
+			siblings = append(siblings, lChild)
+			return t.down(newKey, rChild, siblings, path, l+1)
+		}
+		// left
+		lChild, rChild := readIntermediateChilds(currValue)
+		siblings = append(siblings, rChild)
+		return t.down(newKey, lChild, siblings, path, l+1)
+	default:
+		return nil, nil, nil, fmt.Errorf("invalid value")
+	}
+}
+
+// downVirtually is used when in a leaf already exists, and a new leaf which
+// shares the path until the existing leaf is being added. For every level
+// where both paths agree it appends an empty sibling; at the first level where
+// they differ it appends oldKey (the key of the existing leaf) and stops. If
+// the paths never differ within maxLevels it returns a "max virtual level"
+// error.
+func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath,
+	newPath []bool, l int) ([][]byte, error) {
+	var err error
+	if l > t.maxLevels-1 {
+		return nil, fmt.Errorf("max virtual level %d", l)
+	}
+
+	if oldPath[l] == newPath[l] {
+		emptyKey := make([]byte, t.hashFunction.Len()) // empty
+		siblings = append(siblings, emptyKey)
+
+		siblings, err = t.downVirtually(siblings, oldKey, newKey, oldPath, newPath, l+1)
+		if err != nil {
+			return nil, err
+		}
+		return siblings, nil
+	}
+	// reached the divergence
+	siblings = append(siblings, oldKey)
+
+	return siblings, nil
+}
+
+// up goes up recursively updating the intermediate nodes. Starting at level l
+// it combines key with siblings[l] (on the side given by path[l]), stores the
+// resulting intermediate node in tx and repeats with the new node key until
+// level 0, whose node key is returned as the new root.
+func (t *Tree) up(tx db.Tx, key []byte, siblings [][]byte, path []bool, l int) ([]byte, error) {
+	var k, v []byte
+	var err error
+	if path[l] {
+		// key is the right child
+		k, v, err = t.newIntermediate(siblings[l], key)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		// key is the left child
+		k, v, err = t.newIntermediate(key, siblings[l])
+		if err != nil {
+			return nil, err
+		}
+	}
+	// store k-v to db
+	err = tx.Put(k, v)
+	if err != nil {
+		return nil, err
+	}
+
+	if l == 0 {
+		// reached the root
+		return k, nil
+	}
+
+	return t.up(tx, k, siblings, path, l-1)
+}
+
+// newLeafValue returns the db key and value of the leaf for the given
+// key-value. The db key is Hash(k, v, 1); the value is the byte array
+// [PrefixValueLeaf | len(k) | k | v]. As len(k) is stored in one byte, keys
+// longer than 255 bytes are rejected.
+func (t *Tree) newLeafValue(k, v []byte) ([]byte, []byte, error) {
+	const maxKeyLen = 0xff // biggest length that fits in the 1-byte prefix
+	if len(k) > maxKeyLen {
+		return nil, nil, fmt.Errorf("key too long (max: %d, actual: %d)", maxKeyLen, len(k))
+	}
+	leafKey, err := t.hashFunction.Hash(k, v, []byte{1})
+	if err != nil {
+		return nil, nil, err
+	}
+	var leafValue []byte
+	leafValue = append(leafValue, byte(1))
+	leafValue = append(leafValue, byte(len(k)))
+	leafValue = append(leafValue, k...)
+	leafValue = append(leafValue, v...)
+	return leafKey, leafValue, nil
+}
+
+// readLeafValue splits a leaf value built by newLeafValue into its key and
+// value. It returns two empty slices when b is shorter than its prefix plus
+// the key length announced in b[1].
+func readLeafValue(b []byte) ([]byte, []byte) {
+	if len(b) < PrefixValueLen {
+		return []byte{}, []byte{}
+	}
+
+	// convert to int before adding the prefix length: in byte arithmetic
+	// PrefixValueLen+kLen wraps around for kLen >= 254
+	kLen := int(b[1])
+	if len(b) < PrefixValueLen+kLen {
+		return []byte{}, []byte{}
+	}
+	k := b[PrefixValueLen : PrefixValueLen+kLen]
+	v := b[PrefixValueLen+kLen:]
+	return k, v
+}
+
+// newIntermediate returns the db key and value of the intermediate node with
+// the given left and right childs. The db key is Hash(l, r); the value is the
+// byte array [PrefixValueIntermediate | len(l) | l | r], where l and r each
+// get hashFunction.Len() bytes.
+func (t *Tree) newIntermediate(l, r []byte) ([]byte, []byte, error) {
+	b := make([]byte, PrefixValueLen+t.hashFunction.Len()*2)
+	b[0] = 2
+	b[1] = byte(len(l))
+	copy(b[PrefixValueLen:PrefixValueLen+t.hashFunction.Len()], l)
+	copy(b[PrefixValueLen+t.hashFunction.Len():], r)
+
+	key, err := t.hashFunction.Hash(l, r)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return key, b, nil
+}
+
+// readIntermediateChilds splits an intermediate value built by
+// newIntermediate into its left and right childs. It returns two empty slices
+// when b is shorter than its prefix plus the left child length announced in
+// b[1].
+func readIntermediateChilds(b []byte) ([]byte, []byte) {
+	if len(b) < PrefixValueLen {
+		return []byte{}, []byte{}
+	}
+
+	// convert to int before adding the prefix length: in byte arithmetic
+	// PrefixValueLen+lLen wraps around for lLen >= 254
+	lLen := int(b[1])
+	if len(b) < PrefixValueLen+lLen {
+		return []byte{}, []byte{}
+	}
+	l := b[PrefixValueLen : PrefixValueLen+lLen]
+	r := b[PrefixValueLen+lLen:]
+	return l, r
+}
+
+// getPath returns the path of the key k along numLevels levels: path[n] is
+// true when bit n of k is set, counting the bits of each byte from the least
+// significant one. k must hold at least numLevels bits.
+func getPath(numLevels int, k []byte) []bool {
+	path := make([]bool, numLevels)
+	for n := 0; n < numLevels; n++ {
+		path[n] = k[n/8]&(1<<(n%8)) != 0
+	}
+	return path
+}
+
+// GenProof generates a MerkleTree proof for the given key.
If the key exists in
+// the Tree, the proof will be of existence, if the key does not exist in the
+// tree, the proof will be of non-existence.
+//
+// Not implemented yet: it always returns an "unimplemented" error.
+func (t *Tree) GenProof(k, v []byte) ([]byte, error) {
+	// unimplemented
+	return nil, fmt.Errorf("unimplemented")
+}
+
+// Get returns the value for a given key
+//
+// Not implemented yet: it always returns an "unimplemented" error.
+func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
+	// unimplemented
+	return nil, nil, fmt.Errorf("unimplemented")
+}
+
+// CheckProof verifies the given proof
+//
+// Not implemented yet: it always returns false and an "unimplemented" error.
+func CheckProof(k, v, root, mproof []byte) (bool, error) {
+	// unimplemented
+	return false, fmt.Errorf("unimplemented")
+}
+
+// TODO method to export & import the full Tree without values
diff --git a/tree_test.go b/tree_test.go
new file mode 100644
index 0000000..672a674
--- /dev/null
+++ b/tree_test.go
@@ -0,0 +1,176 @@
+package arbo
+
+import (
+	"encoding/hex"
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/iden3/go-merkletree/db/memory"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestAddTestVectors checks the roots obtained after each Add against fixed
+// test vectors, for both the Poseidon and the Sha256 hash functions.
+func TestAddTestVectors(t *testing.T) {
+	// Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
+	testVectorsPoseidon := []string{
+		"0000000000000000000000000000000000000000000000000000000000000000",
+		"13578938674299138072471463694055224830892726234048532520316387704878000008795",
+		"5412393676474193513566895793055462193090331607895808993925969873307089394741",
+		"14204494359367183802864593755198662203838502594566452929175967972147978322084",
+	}
+	testAdd(t, HashFunctionPoseidon, testVectorsPoseidon)
+
+	testVectorsSha256 := []string{
+		"0000000000000000000000000000000000000000000000000000000000000000",
+		"46910109172468462938850740851377282682950237270676610513794735904325820156367",
+		"59481735341404520835410489183267411392292882901306595567679529387376287440550",
+		"20573794434149960984975763118181266662429997821552560184909083010514790081771",
+	}
+	testAdd(t, HashFunctionSha256, testVectorsSha256)
+}
+
+// testAdd creates a 10-level Tree with the given hash function and adds the
+// key-values (1, 2), (33, 44) and (1234, 9876) in that order. testVectors[0]
+// is the expected root of the empty Tree as a hex string; testVectors[1..3]
+// are the expected roots after each Add, as decimal strings.
+func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
+	tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
+	assert.Nil(t, err)
+	defer tree.db.Close()
+	assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root()))
+
+	err = tree.Add(
+		BigIntToBytes(big.NewInt(1)),
+		BigIntToBytes(big.NewInt(2)))
+	assert.Nil(t, err)
+	rootBI := BytesToBigInt(tree.Root())
+	assert.Equal(t, testVectors[1], rootBI.String())
+
+	err = tree.Add(
+		BigIntToBytes(big.NewInt(33)),
+		BigIntToBytes(big.NewInt(44)))
+	assert.Nil(t, err)
+	rootBI = BytesToBigInt(tree.Root())
+	assert.Equal(t, testVectors[2], rootBI.String())
+
+	err = tree.Add(
+		BigIntToBytes(big.NewInt(1234)),
+		BigIntToBytes(big.NewInt(9876)))
+	assert.Nil(t, err)
+	rootBI = BytesToBigInt(tree.Root())
+	assert.Equal(t, testVectors[3], rootBI.String())
+}
+
+// TestAdd1000 adds the keys 0..999 (all with value 0) to a 100-level Poseidon
+// Tree and checks the resulting root against a fixed value.
+func TestAdd1000(t *testing.T) {
+	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	require.Nil(t, err)
+
+	defer tree.db.Close()
+	for i := 0; i < 1000; i++ {
+		k := BigIntToBytes(big.NewInt(int64(i)))
+		v := BigIntToBytes(big.NewInt(0))
+		if err := tree.Add(k, v); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	rootBI := BytesToBigInt(tree.Root())
+	assert.Equal(t,
+		"296519252211642170490407814696803112091039265640052570497930797516015811235",
+		rootBI.String())
+}
+
+// TestAddDifferentOrder adds the keys 0..15 in ascending order to one Tree and
+// in descending order to another, and checks that both end with the same root.
+func TestAddDifferentOrder(t *testing.T) {
+	tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	require.Nil(t, err)
+
+	defer tree1.db.Close()
+	for i := 0; i < 16; i++ {
+		// NOTE(review): tree1 passes the bytes through SwapEndianness and
+		// tree2 does not; presumably equivalent here because every key is
+		// at most one byte long. Confirm against SwapEndianness.
+		k := SwapEndianness(big.NewInt(int64(i)).Bytes())
+		v := SwapEndianness(big.NewInt(0).Bytes())
+		if err := tree1.Add(k, v); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	require.Nil(t, err)
+	defer tree2.db.Close()
+	for i := 16 - 1; i >= 0; i-- {
+		k := big.NewInt(int64(i)).Bytes()
+		v := big.NewInt(0).Bytes()
+		if err := tree2.Add(k, v); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root()))
+	assert.Equal(t,
+		"3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
+		hex.EncodeToString(tree1.Root()))
+}
+
+// TestAddRepeatedIndex checks that adding the same key twice fails. The
+// expected error is "max virtual level 100": the paths of the existing and the
+// new leaf never diverge, so downVirtually runs until it exceeds the 100
+// levels of the Tree.
+func TestAddRepeatedIndex(t *testing.T) {
+	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	require.Nil(t, err)
+	defer tree.db.Close()
+	k := big.NewInt(int64(3)).Bytes()
+	v := big.NewInt(int64(12)).Bytes()
+	if err := tree.Add(k, v); err != nil {
+		t.Fatal(err)
+	}
+	err = tree.Add(k, v)
+	assert.NotNil(t, err)
+	assert.Equal(t, fmt.Errorf("max virtual level 100"), err)
+}
+
+// TestAux adds the keys 1, 256, 257, 515 and 770 (all with value 0) to a
+// 100-level Poseidon Tree and only checks that no Add returns an error.
+func TestAux(t *testing.T) {
+	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
+	require.Nil(t, err)
+	defer tree.db.Close()
+	k := BigIntToBytes(big.NewInt(int64(1)))
+	v := BigIntToBytes(big.NewInt(int64(0)))
+	err = tree.Add(k, v)
+	assert.Nil(t, err)
+	k = BigIntToBytes(big.NewInt(int64(256)))
+	err = tree.Add(k, v)
+	assert.Nil(t, err)
+
+	k = BigIntToBytes(big.NewInt(int64(257)))
+	err = tree.Add(k, v)
+	assert.Nil(t, err)
+
+	k = BigIntToBytes(big.NewInt(int64(515)))
+	err = tree.Add(k, v)
+	assert.Nil(t, err)
+	k = BigIntToBytes(big.NewInt(int64(770)))
+	err = tree.Add(k, v)
+	assert.Nil(t, err)
+}
+
+// BenchmarkAdd runs benchmarkAdd with the Poseidon and the Sha256 hash
+// functions over the same 1000 key-values (i, i).
+func BenchmarkAdd(b *testing.B) {
+	// prepare inputs
+	var ks, vs [][]byte
+	for i := 0; i < 1000; i++ {
+		k := BigIntToBytes(big.NewInt(int64(i)))
+		v := BigIntToBytes(big.NewInt(int64(i)))
+		ks = append(ks, k)
+		vs = append(vs, v)
+	}
+
+	b.Run("Poseidon", func(b *testing.B) {
+		benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
+	})
+	b.Run("Sha256", func(b *testing.B) {
+		benchmarkAdd(b, HashFunctionSha256, ks, vs)
+	})
+}
+
+// benchmarkAdd creates a 140-level Tree with the given hash function and adds
+// all the given key-values to it.
+//
+// NOTE(review): the loop does not depend on b.N, so every run performs the
+// same len(ks) Adds and the reported ns/op is not a per-Add figure.
+func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
+	tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
+	require.Nil(b, err)
+
+	defer tree.db.Close()
+	for i := 0; i < len(ks); i++ {
+		if err := tree.Add(ks[i], vs[i]); err != nil {
+			b.Fatal(err)
+		}
+	}
+}