mirror of
https://github.com/arnaucube/asmt.git
synced 2026-02-07 11:26:39 +01:00
@@ -136,8 +136,8 @@ func (t *Tree) Add(index, value []byte) error {
|
|||||||
if t.snapshotRoot != nil {
|
if t.snapshotRoot != nil {
|
||||||
return fmt.Errorf("cannot add to a snapshot trie")
|
return fmt.Errorf("cannot add to a snapshot trie")
|
||||||
}
|
}
|
||||||
if len(index) < 4 {
|
if len(index) < 4 || len(index) > MaxKeySize {
|
||||||
return fmt.Errorf("index too small (%d), minimum size is 4 bytes", len(index))
|
return fmt.Errorf("wrong key size: %d", len(index))
|
||||||
}
|
}
|
||||||
if len(value) > MaxValueSize {
|
if len(value) > MaxValueSize {
|
||||||
return fmt.Errorf("index or value claim data too big")
|
return fmt.Errorf("index or value claim data too big")
|
||||||
@@ -162,11 +162,9 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) {
|
|||||||
if len(values) > 0 && len(indexes) != len(values) {
|
if len(values) > 0 && len(indexes) != len(values) {
|
||||||
return wrongIndexes, fmt.Errorf("indexes and values have different size")
|
return wrongIndexes, fmt.Errorf("indexes and values have different size")
|
||||||
}
|
}
|
||||||
var hashedIndexes [][]byte
|
|
||||||
var hashedValues [][]byte
|
|
||||||
var value []byte
|
var value []byte
|
||||||
for i, key := range indexes {
|
for i, key := range indexes {
|
||||||
if len(key) < 4 {
|
if len(key) < 4 || len(key) > MaxKeySize {
|
||||||
wrongIndexes = append(wrongIndexes, i)
|
wrongIndexes = append(wrongIndexes, i)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -178,12 +176,10 @@ func (t *Tree) AddBatch(indexes, values [][]byte) ([]int, error) {
|
|||||||
}
|
}
|
||||||
value = values[i]
|
value = values[i]
|
||||||
}
|
}
|
||||||
hashedIndexes = append(hashedIndexes, asmt.Hasher(key))
|
_, err := t.Tree.Update([][]byte{asmt.Hasher(key)}, [][]byte{asmt.Hasher(value)})
|
||||||
hashedValues = append(hashedValues, asmt.Hasher(value))
|
if err != nil {
|
||||||
}
|
return wrongIndexes, err
|
||||||
_, err := t.Tree.Update(hashedIndexes, hashedValues)
|
}
|
||||||
if err != nil {
|
|
||||||
return wrongIndexes, err
|
|
||||||
}
|
}
|
||||||
atomic.StoreUint64(&t.size, 0) // TBD: improve this
|
atomic.StoreUint64(&t.size, 0) // TBD: improve this
|
||||||
return wrongIndexes, t.Commit()
|
return wrongIndexes, t.Commit()
|
||||||
@@ -205,28 +201,27 @@ func (t *Tree) GenProof(index, value []byte) ([]byte, error) {
|
|||||||
t.updateAccessTime()
|
t.updateAccessTime()
|
||||||
var err error
|
var err error
|
||||||
var ap [][]byte
|
var ap [][]byte
|
||||||
var pvalue []byte
|
var pvalue, bitmap []byte
|
||||||
var bitmap []byte
|
|
||||||
var length int
|
var length int
|
||||||
|
var included bool
|
||||||
|
key := asmt.Hasher(index)
|
||||||
if t.snapshotRoot != nil {
|
if t.snapshotRoot != nil {
|
||||||
bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressedR(
|
bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressedR(key,
|
||||||
asmt.Hasher(index),
|
|
||||||
t.snapshotRoot)
|
t.snapshotRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
bitmap, ap, length, _, _, pvalue, err = t.Tree.MerkleProofCompressed(
|
bitmap, ap, length, included, _, pvalue, err = t.Tree.MerkleProofCompressed(key)
|
||||||
asmt.Hasher(index))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//if !included {
|
if !included {
|
||||||
// return nil, fmt.Errorf("not included")
|
return nil, nil
|
||||||
//}
|
}
|
||||||
if !bytes.Equal(pvalue, asmt.Hasher(value)) {
|
if !bytes.Equal(pvalue, asmt.Hasher(value)) {
|
||||||
return nil, fmt.Errorf("incorrect value on genProof")
|
return nil, fmt.Errorf("incorrect value or key on genProof")
|
||||||
}
|
}
|
||||||
return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue})
|
return bare.Marshal(&Proof{Bitmap: bitmap, Length: length, Siblings: ap, Value: pvalue})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package asmtree
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTree(t *testing.T) {
|
func TestTree(t *testing.T) {
|
||||||
@@ -86,5 +88,186 @@ func TestTree(t *testing.T) {
|
|||||||
if !valid {
|
if !valid {
|
||||||
t.Errorf("proof is invalid on tree2")
|
t.Errorf("proof is invalid on tree2")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProofs(t *testing.T) {
|
||||||
|
censusSize := 10000
|
||||||
|
storage := t.TempDir()
|
||||||
|
tr1 := &Tree{}
|
||||||
|
err := tr1.Init("test1", storage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys, values [][]byte
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
keys = append(keys, RandomBytes(32))
|
||||||
|
values = append(values, RandomBytes(32))
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < censusSize-200 {
|
||||||
|
if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if len(fail) > 0 {
|
||||||
|
t.Fatalf("some keys failed to add on addBatch: %v", fail)
|
||||||
|
}
|
||||||
|
i += 200
|
||||||
|
}
|
||||||
|
// Add remaining claims (if size%200 != 0)
|
||||||
|
if i < censusSize {
|
||||||
|
if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if len(fail) > 0 {
|
||||||
|
t.Fatalf("some keys failed to add on addBatch: %v", fail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
root1 := tr1.Root()
|
||||||
|
data, err := tr1.Dump(root1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("dumped data size is: %d bytes", len(data))
|
||||||
|
|
||||||
|
// Get the size
|
||||||
|
s, err := tr1.Size(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot get te size: %v", err)
|
||||||
|
}
|
||||||
|
if s != int64(censusSize) {
|
||||||
|
t.Errorf("size is wrong (have %d, expexted %d)", s, censusSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a proofs
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
proofs := [][]byte{}
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
p, err := tr1.GenProof(keys[i], values[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
t.Fatal("proof not generated")
|
||||||
|
}
|
||||||
|
proofs = append(proofs, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check proofs
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
t.Errorf("proof %d is invalid", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// go test -v -run=- -bench=Tree -benchtime=30s .
|
||||||
|
func BenchmarkTree(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
// Create websocket client
|
||||||
|
for pb.Next() {
|
||||||
|
benchProofs(b, 100000)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchProofs(b *testing.B, censusSize int) {
|
||||||
|
totalTimer := time.Now()
|
||||||
|
storage := b.TempDir()
|
||||||
|
tr1 := &Tree{}
|
||||||
|
err := tr1.Init("test1", storage)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys, values [][]byte
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
keys = append(keys, RandomBytes(32))
|
||||||
|
values = append(values, RandomBytes(32))
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.Now()
|
||||||
|
i := 0
|
||||||
|
for i < censusSize-200 {
|
||||||
|
if fail, err := tr1.AddBatch(keys[i:i+200], values[i:i+200]); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
} else if len(fail) > 0 {
|
||||||
|
b.Fatalf("some keys failed to add on addBatch: %v", fail)
|
||||||
|
}
|
||||||
|
i += 200
|
||||||
|
}
|
||||||
|
// Add remaining claims (if size%200 != 0)
|
||||||
|
if i < censusSize {
|
||||||
|
if fail, err := tr1.AddBatch(keys[i:], values[i:]); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
} else if len(fail) > 0 {
|
||||||
|
b.Fatalf("some keys failed to add on addBatch: %v", fail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.Logf("addBatch took %d ms", time.Since(timer).Milliseconds())
|
||||||
|
|
||||||
|
timer = time.Now()
|
||||||
|
root1 := tr1.Root()
|
||||||
|
data, err := tr1.Dump(root1)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.Logf("dumped data size is: %d bytes", len(data))
|
||||||
|
b.Logf("dump took %d ms", time.Since(timer).Milliseconds())
|
||||||
|
|
||||||
|
// Get the size
|
||||||
|
s, err := tr1.Size(nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("cannot get te size: %v", err)
|
||||||
|
}
|
||||||
|
if s != int64(censusSize) {
|
||||||
|
b.Errorf("size is wrong (have %d, expexted %d)", s, censusSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a proofs
|
||||||
|
timer = time.Now()
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
proofs := [][]byte{}
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
p, err := tr1.GenProof(keys[i], values[i])
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
b.Fatal("proof not generated")
|
||||||
|
}
|
||||||
|
proofs = append(proofs, p)
|
||||||
|
}
|
||||||
|
b.Logf("gen proofs took %d ms", time.Since(timer).Milliseconds())
|
||||||
|
|
||||||
|
// Check proofs
|
||||||
|
timer = time.Now()
|
||||||
|
for i := 0; i < censusSize; i++ {
|
||||||
|
valid, err := tr1.CheckProof(keys[i], values[i], root1, proofs[i])
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
b.Errorf("proof %d is invalid", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.Logf("check proofs took %d ms", time.Since(timer).Milliseconds())
|
||||||
|
b.Logf("[finished] %d ms", time.Since(totalTimer).Milliseconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomBytes(n int) []byte {
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user