Browse Source

fix AddBatch and add a benchmark

Signed-off-by: p4u <pau@dabax.net>
master
p4u 3 years ago
parent
commit
137219f5f0
2 changed files with 199 additions and 21 deletions
  1. +16
    -21
      censustree.go
  2. +183
    -0
      censustree_test.go

+ 16
- 21
censustree.go

@ -136,8 +136,8 @@ func (t *Tree) Add(index, value []byte) error {
if t.snapshotRoot != nil { if t.snapshotRoot != nil {
return fmt.Errorf("cannot add to a snapshot trie") return fmt.Errorf("cannot add to a snapshot trie")
} }
if len(index) < 4 {
return fmt.Errorf("index too small (%d), minimum size is 4 bytes", len(index))
if len(index) < 4 || len(index) > MaxKeySize {
return fmt.Errorf("wrong key size: %d", len(index))
} }
if len(value) > MaxValueSize { if len(value) > MaxValueSize {
return fmt.Errorf("index or value claim data too big") return fmt.Errorf("index or value claim data too big")
@ -162,11 +162,9 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) {
if len(values) > 0 && len(indexes) != len(values) { if len(values) > 0 && len(indexes) != len(values) {
return wrongIndexes, fmt.Errorf("indexes and values have different size") return wrongIndexes, fmt.Errorf("indexes and values have different size")
} }
var hashedIndexes [][]byte
var hashedValues [][]byte
var value []byte var value []byte
for i, key := range indexes { for i, key := range indexes {
if len(key) < 4 {
if len(key) < 4 || len(key) > MaxKeySize {
wrongIndexes = append(wrongIndexes, i) wrongIndexes = append(wrongIndexes, i)
continue continue
} }
@ -178,12 +176,10 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) {
} }
value = values[i] value = values[i]
} }
hashedIndexes = append(hashedIndexes, asmt.Hasher(key))
hashedValues = append(hashedValues, asmt.Hasher(value))
}
_, err := t.Tree.Update(hashedIndexes, hashedValues)
if err != nil {
return wrongIndexes, err
_, err := t.Tree.Update([][]byte{asmt.Hasher(key)}, [][]byte{asmt.Hasher(value)})
if err != nil {
return wrongIndexes, err
}
} }
atomic.StoreUint64(&t.size, 0) // TBD: improve this atomic.StoreUint64(&t.size, 0) // TBD: improve this
return wrongIndexes, t.Commit() return wrongIndexes, t.Commit()
@ -205,28 +201,27 @@ func (t *Tree) GenProof(index, value []byte) ([]byte, error) {
t.updateAccessTime() t.updateAccessTime()
var err error var err error
var ap [][]byte var ap [][]byte
var pvalue []byte
var bitmap []byte
var pvalue, bitmap []byte
var length int var length int
var included bool
key := asmt.Hasher(index)
if t.snapshotRoot != nil { if t.snapshotRoot != nil {
bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressedR(
asmt.Hasher(index),
bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressedR(key,
t.snapshotRoot) t.snapshotRoot)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressed(
asmt.Hasher(index))
bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressed(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
//if !included {
// return nil, fmt.Errorf("not included")
//}
if !included {
return nil, nil
}
if !bytes.Equal(pvalue, asmt.Hasher(value)) { if !bytes.Equal(pvalue, asmt.Hasher(value)) {
return nil, fmt.Errorf("incorrect value on genProof")
return nil, fmt.Errorf("incorrect value or key on genProof")
} }
return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue}) return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue})
} }

+ 183
- 0
censustree_test.go

@ -2,8 +2,10 @@ package asmtree
import ( import (
"bytes" "bytes"
"crypto/rand"
"fmt" "fmt"
"testing" "testing"
"time"
) )
func TestTree(t *testing.T) { func TestTree(t *testing.T) {
@ -86,5 +88,186 @@ func TestTree(t *testing.T) {
if !valid { if !valid {
t.Errorf("proof is invalid on tree2") t.Errorf("proof is invalid on tree2")
} }
}
func TestProofs(t *testing.T) {
censusSize := 10000
storage := t.TempDir()
tr1 := &Tree{}
err := tr1.Init("test1", storage)
if err != nil {
t.Fatal(err)
}
var keys, values [][]byte
for i := 0; i < censusSize; i++ {
keys = append(keys, RandomBytes(32))
values = append(values, RandomBytes(32))
}
i := 0
for i < censusSize-200 {
if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil {
t.Fatal(err)
} else if len(fail) > 0 {
t.Fatalf("some keys failed to add on addBatch: %v", fail)
}
i += 200
}
// Add remaining claims (if size%200 != 0)
if i < censusSize {
if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil {
t.Fatal(err)
} else if len(fail) > 0 {
t.Fatalf("some keys failed to add on addBatch: %v", fail)
}
}
root1 := tr1.Root()
data, err := tr1.Dump(root1)
if err != nil {
t.Fatal(err)
}
t.Logf("dumped data size is: %d bytes", len(data))
// Get the size
s, err := tr1.Size(nil)
if err != nil {
t.Errorf("cannot get te size: %v", err)
}
if s != int64(censusSize) {
t.Errorf("size is wrong (have %d, expexted %d)", s, censusSize)
}
// Generate a proofs
time.Sleep(5 * time.Second)
proofs := [][]byte{}
for i := 0; i < censusSize; i++ {
p, err := tr1.GenProof(keys[i], values[i])
if err != nil {
t.Fatal(err)
}
if len(p) == 0 {
t.Fatal("proof not generated")
}
proofs = append(proofs, p)
}
// Check proofs
for i := 0; i < censusSize; i++ {
valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i])
if err != nil {
t.Fatal(err)
}
if !valid {
t.Errorf("proof %d is invalid", i)
}
}
} }
// go test -v -run=- -bench=Tree -benchtime=30s .
func BenchmarkTree(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
// Create websocket client
for pb.Next() {
benchProofs(b, 100000)
}
})
}
func benchProofs(b *testing.B, censusSize int) {
totalTimer := time.Now()
storage := b.TempDir()
tr1 := &Tree{}
err := tr1.Init("test1", storage)
if err != nil {
b.Fatal(err)
}
var keys, values [][]byte
for i := 0; i < censusSize; i++ {
keys = append(keys, RandomBytes(32))
values = append(values, RandomBytes(32))
}
timer := time.Now()
i := 0
for i < censusSize-200 {
if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil {
b.Fatal(err)
} else if len(fail) > 0 {
b.Fatalf("some keys failed to add on addBatch: %v", fail)
}
i += 200
}
// Add remaining claims (if size%200 != 0)
if i < censusSize {
if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil {
b.Fatal(err)
} else if len(fail) > 0 {
b.Fatalf("some keys failed to add on addBatch: %v", fail)
}
}
b.Logf("addBatch took %d ms", time.Since(timer).Milliseconds())
timer = time.Now()
root1 := tr1.Root()
data, err := tr1.Dump(root1)
if err != nil {
b.Fatal(err)
}
b.Logf("dumped data size is: %d bytes", len(data))
b.Logf("dump took %d ms", time.Since(timer).Milliseconds())
// Get the size
s, err := tr1.Size(nil)
if err != nil {
b.Errorf("cannot get te size: %v", err)
}
if s != int64(censusSize) {
b.Errorf("size is wrong (have %d, expexted %d)", s, censusSize)
}
// Generate a proofs
timer = time.Now()
time.Sleep(5 * time.Second)
proofs := [][]byte{}
for i := 0; i < censusSize; i++ {
p, err := tr1.GenProof(keys[i], values[i])
if err != nil {
b.Fatal(err)
}
if len(p) == 0 {
b.Fatal("proof not generated")
}
proofs = append(proofs, p)
}
b.Logf("gen proofs took %d ms", time.Since(timer).Milliseconds())
// Check proofs
timer = time.Now()
for i := 0; i < censusSize; i++ {
valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i])
if err != nil {
b.Fatal(err)
}
if !valid {
b.Errorf("proof %d is invalid", i)
}
}
b.Logf("check proofs took %d ms", time.Since(timer).Milliseconds())
b.Logf("[finished] %d ms", time.Since(totalTimer).Milliseconds())
}
func RandomBytes(n int) []byte {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
panic(err)
}
return b
}

Loading…
Cancel
Save