diff --git a/go.mod b/go.mod index f2b66dc..087372a 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module arbo +module github.com/arnaucube/arbo go 1.14 diff --git a/tree.go b/tree.go index bf59e49..6e8affa 100644 --- a/tree.go +++ b/tree.go @@ -388,6 +388,29 @@ func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { return res } +// UnpackSiblings unpacks the siblings from a byte array. +func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) { + l := b[0] + bitmapBytes := b[1 : 1+l] + bitmap := bytesToBitmap(bitmapBytes) + siblingsBytes := b[1+l:] + iSibl := 0 + emptySibl := make([]byte, hashFunc.Len()) + var siblings [][]byte + for i := 0; i < len(bitmap); i++ { + if iSibl >= len(siblingsBytes) { + break + } + if bitmap[i] { + siblings = append(siblings, siblingsBytes[iSibl:iSibl+hashFunc.Len()]) + iSibl += hashFunc.Len() + } else { + siblings = append(siblings, emptySibl) + } + } + return siblings, nil +} + func bitmapToBytes(bitmap []bool) []byte { bitmapBytesLen := int(math.Ceil(float64(len(bitmap)) / 8)) //nolint:gomnd b := make([]byte, bitmapBytesLen) @@ -399,14 +422,54 @@ func bitmapToBytes(bitmap []bool) []byte { return b } +func bytesToBitmap(b []byte) []bool { + var bitmap []bool + for i := 0; i < len(b); i++ { + for j := 0; j < 8; j++ { + bitmap = append(bitmap, b[i]&(1< 0) + } + } + return bitmap +} + // Get returns the value for a given key func (t *Tree) Get(k []byte) ([]byte, []byte, error) { // unimplemented return nil, nil, fmt.Errorf("unimplemented") } -// CheckProof verifies the given proof -func CheckProof(k, v, root, mproof []byte) (bool, error) { - // unimplemented - return false, fmt.Errorf("unimplemented") +// CheckProof verifies the given proof. The proof verification depends on the +// HashFunction passed as parameter. +func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) { + siblings, err := UnpackSiblings(hashFunc, packedSiblings) + if err != nil { + return false, err + } + + keyPath := make([]byte, hashFunc.Len()) + copy(keyPath[:], k) + + key, _, err := newLeafValue(hashFunc, k, v) + if err != nil { + return false, err + } + + path := getPath(len(siblings), keyPath) + for i := len(siblings) - 1; i >= 0; i-- { + if path[i] { + key, _, err = newIntermediate(hashFunc, siblings[i], key) + if err != nil { + return false, err + } + } else { + key, _, err = newIntermediate(hashFunc, key, siblings[i]) + if err != nil { + return false, err + } + } + } + if bytes.Equal(key[:], root) { + return true, nil + } + return false, nil } diff --git a/tree_test.go b/tree_test.go index db0ac56..0a00326 100644 --- a/tree_test.go +++ b/tree_test.go @@ -145,7 +145,7 @@ func TestAux(t *testing.T) { assert.Nil(t, err) } -func TestGenProof(t *testing.T) { +func TestGenProofAndVerify(t *testing.T) { tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) require.Nil(t, err) @@ -159,8 +159,14 @@ func TestGenProof(t *testing.T) { } k := BigIntToBytes(big.NewInt(int64(7))) - _, err = tree.GenProof(k) + siblings, err := tree.GenProof(k) assert.Nil(t, err) + + k = BigIntToBytes(big.NewInt(int64(7))) + v := BigIntToBytes(big.NewInt(int64(14))) + verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings) + require.Nil(t, err) + assert.True(t, verif) } func BenchmarkAdd(b *testing.B) {