diff --git a/censustree.go b/censustree.go index c27d0d6..a83a155 100644 --- a/censustree.go +++ b/censustree.go @@ -136,8 +136,8 @@ func (t *Tree) Add(index, value []byte) error { if t.snapshotRoot != nil { return fmt.Errorf("cannot add to a snapshot trie") } - if len(index) < 4 { - return fmt.Errorf("index too small (%d), minimum size is 4 bytes", len(index)) + if len(index) < 4 || len(index) > MaxKeySize { + return fmt.Errorf("wrong key size: %d", len(index)) } if len(value) > MaxValueSize { return fmt.Errorf("index or value claim data too big") @@ -162,11 +162,9 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) { if len(values) > 0 && len(indexes) != len(values) { return wrongIndexes, fmt.Errorf("indexes and values have different size") } - var hashedIndexes [][]byte - var hashedValues [][]byte var value []byte for i, key := range indexes { - if len(key) < 4 { + if len(key) < 4 || len(key) > MaxKeySize { wrongIndexes = append(wrongIndexes, i) continue } @@ -178,12 +176,10 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) { } value = values[i] } - hashedIndexes = append(hashedIndexes, asmt.Hasher(key)) - hashedValues = append(hashedValues, asmt.Hasher(value)) - } - _, err := t.Tree.Update(hashedIndexes, hashedValues) - if err != nil { - return wrongIndexes, err + _, err := t.Tree.Update([][]byte{asmt.Hasher(key)}, [][]byte{asmt.Hasher(value)}) + if err != nil { + return wrongIndexes, err + } } atomic.StoreUint64(&t.size, 0) // TBD: improve this return wrongIndexes, t.Commit() @@ -205,28 +201,27 @@ func (t *Tree) GenProof(index, value []byte) ([]byte, error) { t.updateAccessTime() var err error var ap [][]byte - var pvalue []byte - var bitmap []byte + var pvalue, bitmap []byte var length int + var included bool + key := asmt.Hasher(index) if t.snapshotRoot != nil { - bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressedR( - asmt.Hasher(index), + bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressedR(key, t.snapshotRoot) if err != nil { return nil, err } } else { - bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressed( - asmt.Hasher(index)) + bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressed(key) if err != nil { return nil, err } } - //if !included { - // return nil, fmt.Errorf("not included") - //} + if !included { + return nil, nil + } if !bytes.Equal(pvalue, asmt.Hasher(value)) { - return nil, fmt.Errorf("incorrect value on genProof") + return nil, fmt.Errorf("incorrect value or key on genProof") } return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue}) } diff --git a/censustree_test.go b/censustree_test.go index 0746f79..b09bea4 100644 --- a/censustree_test.go +++ b/censustree_test.go @@ -2,8 +2,10 @@ package asmtree import ( "bytes" + "crypto/rand" "fmt" "testing" + "time" ) func TestTree(t *testing.T) { @@ -86,5 +88,186 @@ func TestTree(t *testing.T) { if !valid { t.Errorf("proof is invalid on tree2") } +} + +func TestProofs(t *testing.T) { + censusSize := 10000 + storage := t.TempDir() + tr1 := &Tree{} + err := tr1.Init("test1", storage) + if err != nil { + t.Fatal(err) + } + + var keys, values [][]byte + for i := 0; i < censusSize; i++ { + keys = append(keys, RandomBytes(32)) + values = append(values, RandomBytes(32)) + } + + i := 0 + for i < censusSize-200 { + if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil { + t.Fatal(err) + } else if len(fail) > 0 { + t.Fatalf("some keys failed to add on addBatch: %v", fail) + } + i += 200 + } + // Add remaining claims (if size%200 != 0) + if i < censusSize { + if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil { + t.Fatal(err) + } else if len(fail) > 0 { + t.Fatalf("some keys failed to add on addBatch: %v", fail) + } + } + + root1 := tr1.Root() + data, err := tr1.Dump(root1) + if err != nil { + t.Fatal(err) + } + t.Logf("dumped data size is: %d bytes", len(data)) + + // Get the size + s, err := tr1.Size(nil) + if err != nil { + t.Errorf("cannot get te size: %v", err) + } + if s != int64(censusSize) { + t.Errorf("size is wrong (have %d, expexted %d)", s, censusSize) + } + + // Generate a proofs + time.Sleep(5 * time.Second) + proofs := [][]byte{} + for i := 0; i < censusSize; i++ { + p, err := tr1.GenProof(keys[i], values[i]) + if err != nil { + t.Fatal(err) + } + if len(p) == 0 { + t.Fatal("proof not generated") + } + proofs = append(proofs, p) + } + + // Check proofs + for i := 0; i < censusSize; i++ { + valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i]) + if err != nil { + t.Fatal(err) + } + if !valid { + t.Errorf("proof %d is invalid", i) + } + } } + +// go test -v -run=- -bench=Tree -benchtime=30s . +func BenchmarkTree(b *testing.B) { + b.ReportAllocs() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Create websocket client + for pb.Next() { + benchProofs(b, 100000) + } + }) +} + +func benchProofs(b *testing.B, censusSize int) { + totalTimer := time.Now() + storage := b.TempDir() + tr1 := &Tree{} + err := tr1.Init("test1", storage) + if err != nil { + b.Fatal(err) + } + + var keys, values [][]byte + for i := 0; i < censusSize; i++ { + keys = append(keys, RandomBytes(32)) + values = append(values, RandomBytes(32)) + } + + timer := time.Now() + i := 0 + for i < censusSize-200 { + if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil { + b.Fatal(err) + } else if len(fail) > 0 { + b.Fatalf("some keys failed to add on addBatch: %v", fail) + } + i += 200 + } + // Add remaining claims (if size%200 != 0) + if i < censusSize { + if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil { + b.Fatal(err) + } else if len(fail) > 0 { + b.Fatalf("some keys failed to add on addBatch: %v", fail) + } + } + b.Logf("addBatch took %d ms", time.Since(timer).Milliseconds()) + + timer = time.Now() + root1 := tr1.Root() + data, err := tr1.Dump(root1) + if err != nil { + b.Fatal(err) + } + b.Logf("dumped data size is: %d bytes", len(data)) + b.Logf("dump took %d ms", time.Since(timer).Milliseconds()) + + // Get the size + s, err := tr1.Size(nil) + if err != nil { + b.Errorf("cannot get te size: %v", err) + } + if s != int64(censusSize) { + b.Errorf("size is wrong (have %d, expexted %d)", s, censusSize) + } + + // Generate a proofs + timer = time.Now() + time.Sleep(5 * time.Second) + proofs := [][]byte{} + for i := 0; i < censusSize; i++ { + p, err := tr1.GenProof(keys[i], values[i]) + if err != nil { + b.Fatal(err) + } + if len(p) == 0 { + b.Fatal("proof not generated") + } + proofs = append(proofs, p) + } + b.Logf("gen proofs took %d ms", time.Since(timer).Milliseconds()) + + // Check proofs + timer = time.Now() + for i := 0; i < censusSize; i++ { + valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i]) + if err != nil { + b.Fatal(err) + } + if !valid { + b.Errorf("proof %d is invalid", i) + } + } + b.Logf("check proofs took %d ms", time.Since(timer).Milliseconds()) + b.Logf("[finished] %d ms", time.Since(totalTimer).Milliseconds()) +} + +func RandomBytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + return b +}