diff --git a/addbatch_test.go b/addbatch_test.go index d02fd10..e6e4cb2 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -39,12 +39,12 @@ func debugTime(descr string, time1, time2 time.Duration) { func testInit(c *qt.C, n int) (*Tree, *Tree) { database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) bLen := HashFunctionPoseidon.Len() @@ -70,11 +70,11 @@ func TestAddBatchTreeEmpty(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -93,7 +93,7 @@ func TestAddBatchTreeEmpty(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -120,11 +120,11 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -135,7 +135,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -167,13 +167,13 @@ func TestAddBatchTestVector1(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -207,13 +207,13 @@ func TestAddBatchTestVector1(t *testing.T) { // 2nd test vectors database1, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err = NewTree(database1, 100, HashFunctionBlake2b) + tree1, err = NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err = NewTree(database2, 100, HashFunctionBlake2b) + tree2, err = NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -255,13 +255,13 @@ func TestAddBatchTestVector2(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database, 100, HashFunctionPoseidon) + tree1, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -300,13 +300,13 @@ func TestAddBatchTestVector3(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database, 100, HashFunctionPoseidon) + tree1, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -349,13 +349,13 @@ func TestAddBatchTreeEmptyRandomKeys(t *testing.T) { database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -699,7 +699,7 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -776,7 +776,7 @@ func benchAdd(t *testing.T, ks, vs [][]byte) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 140, HashFunctionBlake2b) + tree, err := NewTree(database, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -796,7 +796,7 @@ func benchAddBatch(t *testing.T, ks, vs [][]byte) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 140, HashFunctionBlake2b) + tree, err := NewTree(database, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -829,7 +829,7 @@ func TestDbgStats(t *testing.T) { // 1 database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck @@ -843,7 +843,7 @@ func TestDbgStats(t *testing.T) { // 2 database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -856,7 +856,7 @@ func TestDbgStats(t *testing.T) { // 3 database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree3, err := NewTree(database3, 100, HashFunctionBlake2b) + tree3, err := NewTree(database3, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree3.db.Close() //nolint:errcheck @@ -891,7 +891,7 @@ func TestLoadVT(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -927,11 +927,11 @@ func TestAddKeysWithEmptyValues(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -948,7 +948,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -962,7 +962,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) { // use tree3 to add nil value array database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree3, err := NewTree(database3, 100, HashFunctionPoseidon) + tree3, err := NewTree(database3, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree3.db.Close() //nolint:errcheck diff --git a/circomproofs_test.go b/circomproofs_test.go index 39ad1e9..c8345a1 100644 --- a/circomproofs_test.go +++ b/circomproofs_test.go @@ -17,14 +17,13 @@ func TestCircomVerifierProof(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() - testVector := [][]int64{ {1, 11}, {2, 22}, {3, 33}, {4, 44}, } + bLen := 1 for i := 0; i < len(testVector); i++ { k := BigIntToBytes(bLen, big.NewInt(testVector[i][0])) v := BigIntToBytes(bLen, big.NewInt(testVector[i][1])) diff --git a/testvectors/circom/go-data-generator/generator_test.go b/testvectors/circom/go-data-generator/generator_test.go index 91e45d9..ff55129 100644 --- a/testvectors/circom/go-data-generator/generator_test.go +++ b/testvectors/circom/go-data-generator/generator_test.go @@ -18,14 +18,13 @@ func TestGenerator(t *testing.T) { tree, err := arbo.NewTree(database, 4, arbo.HashFunctionPoseidon) c.Assert(err, qt.IsNil) - bLen := tree.HashFunction().Len() - testVector := [][]int64{ {1, 11}, {2, 22}, {3, 33}, {4, 44}, } + bLen := 1 for i := 0; i < len(testVector); i++ { k := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][0])) v := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][1])) diff --git a/tree.go b/tree.go index 6f856d0..0e11b8f 100644 --- a/tree.go +++ b/tree.go @@ -229,8 +229,10 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err } // store root (from the vt) to db - if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { - return nil, err + if vt.root != nil { + if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { + return nil, err + } } // update nLeafs @@ -310,14 +312,34 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { return nil } -func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd +// keyPathFromKey returns the keyPath and checks that the key is not bigger +// than maximum key length for the tree maxLevels size. +// This is because if the key bits length is bigger than the maxLevels of the +// tree, two different keys that their difference is at the end, will collision +// in the same leaf of the tree (at the max depth). +func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) { + maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd + if len(k) > maxKeyLen { + return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+ + " len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+ + " a bigger tree depth (maxLevels>=%d) in order to input keys of length %d", + len(k), maxLevels, maxKeyLen, len(k)*8, len(k)) //nolint:gomnd + } + keyPath := make([]byte, maxKeyLen) //nolint:gomnd copy(keyPath[:], k) + return keyPath, nil +} + +func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, err + } path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false) + _, _, siblings, err = t.down(wTx, k, root, siblings, path, fromLvl, false) if err != nil { return nil, err } @@ -590,8 +612,10 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { return ErrSnapshotNotEditable } - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(wTx) @@ -647,8 +671,10 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) { // GenProofWithTx does the same than the GenProof method, but allowing to pass // the db.ReadTx that is used. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, nil, nil, false, err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) @@ -782,8 +808,10 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { // ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data // found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, nil, err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) diff --git a/tree_test.go b/tree_test.go index f6592c4..a8912c5 100644 --- a/tree_test.go +++ b/tree_test.go @@ -2,7 +2,9 @@ package arbo import ( "encoding/hex" + "math" "math/big" + "runtime" "testing" "time" @@ -60,7 +62,7 @@ func TestAddTestVectors(t *testing.T) { func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 10, hashFunc) + tree, err := NewTree(database, 256, hashFunc) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -68,7 +70,7 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { c.Assert(err, qt.IsNil) c.Check(hex.EncodeToString(root), qt.Equals, testVectors[0]) - bLen := hashFunc.Len() + bLen := 32 err = tree.Add( BigIntToBytes(bLen, big.NewInt(1)), BigIntToBytes(bLen, big.NewInt(2))) @@ -92,11 +94,11 @@ func TestAddBatch(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(0)) @@ -110,7 +112,7 @@ func TestAddBatch(t *testing.T) { database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database, 100, HashFunctionPoseidon) + tree2, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -133,11 +135,11 @@ func TestAddDifferentOrder(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck - bLen := tree1.HashFunction().Len() + bLen := 32 for i := 0; i < 16; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(0)) @@ -148,7 +150,7 @@ func TestAddDifferentOrder(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -173,11 +175,11 @@ func TestAddRepeatedIndex(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(3))) v := BigIntToBytes(bLen, big.NewInt(int64(12))) @@ -191,11 +193,11 @@ func TestUpdate(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(20))) v := BigIntToBytes(bLen, big.NewInt(int64(12))) if err := tree.Add(k, v); err != nil { @@ -244,11 +246,11 @@ func TestAux(t *testing.T) { // TODO split in proper tests c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(1))) v := BigIntToBytes(bLen, big.NewInt(int64(0))) err = tree.Add(k, v) @@ -283,11 +285,11 @@ func TestGet(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < 10; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -307,11 +309,11 @@ func TestGenProofAndVerify(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() - 1 + bLen := 32 for i := 0; i < 10; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -339,11 +341,11 @@ func TestDumpAndImportDump(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck - bLen := tree1.HashFunction().Len() + bLen := 32 for i := 0; i < 16; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -357,7 +359,7 @@ func TestDumpAndImportDump(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck err = tree2.ImportDump(e) @@ -376,11 +378,11 @@ func TestRWMutex(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -469,7 +471,7 @@ func TestAddBatchFullyUsed(t *testing.T) { var keys, values [][]byte for i := 0; i < 16; i++ { - k := BigIntToBytes(32, big.NewInt(int64(i))) + k := BigIntToBytes(1, big.NewInt(int64(i))) v := k keys = append(keys, k) @@ -492,10 +494,10 @@ func TestAddBatchFullyUsed(t *testing.T) { // get all key-values and check that are equal between both trees for i := 0; i < 16; i++ { - auxK1, auxV1, err := tree1.Get(BigIntToBytes(32, big.NewInt(int64(i)))) + auxK1, auxV1, err := tree1.Get(BigIntToBytes(1, big.NewInt(int64(i)))) c.Assert(err, qt.IsNil) - auxK2, auxV2, err := tree2.Get(BigIntToBytes(32, big.NewInt(int64(i)))) + auxK2, auxV2, err := tree2.Get(BigIntToBytes(1, big.NewInt(int64(i)))) c.Assert(err, qt.IsNil) c.Assert(auxK1, qt.DeepEquals, auxK2) @@ -504,7 +506,7 @@ func TestAddBatchFullyUsed(t *testing.T) { // try adding one more key to both trees (through Add & AddBatch) and // expect not being added due the tree is already full - k := BigIntToBytes(32, big.NewInt(int64(16))) + k := BigIntToBytes(1, big.NewInt(int64(16))) v := k err = tree1.Add(k, v) c.Assert(err, qt.Equals, ErrMaxVirtualLevel) @@ -518,13 +520,13 @@ func TestSetRoot(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) expectedRoot := "13742386369878513332697380582061714160370929283209286127733983161245560237407" // fill the tree - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -574,11 +576,11 @@ func TestSnapshot(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) // fill the tree - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -624,11 +626,11 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(3))) root, err := tree.Root() @@ -646,7 +648,7 @@ func TestKeyLen(t *testing.T) { c.Assert(err, qt.IsNil) // maxLevels is 100, keyPath length = ceil(maxLevels/8) = 13 maxLevels := 100 - tree, err := NewTree(database, maxLevels, HashFunctionPoseidon) + tree, err := NewTree(database, maxLevels, HashFunctionBlake2b) c.Assert(err, qt.IsNil) // expect no errors when adding a key of only 4 bytes (when the @@ -672,6 +674,75 @@ func TestKeyLen(t *testing.T) { invalids, err := tree.AddBatch([][]byte{k}, [][]byte{v}) c.Assert(err, qt.IsNil) c.Assert(len(invalids), qt.Equals, 0) + + // expect errors when adding a key bigger than maximum capacity of the + // tree: ceil(maxLevels/8) + maxLevels = 32 + database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + tree, err = NewTree(database, maxLevels, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + + maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd + k = BigIntToBytes(maxKeyLen+1, big.NewInt(1)) + v = BigIntToBytes(maxKeyLen+1, big.NewInt(1)) + + expectedErrMsg := "len(k) can not be bigger than ceil(maxLevels/8)," + + " where len(k): 5, maxLevels: 32, max key len=ceil(maxLevels/8): 4." + + " Might need a bigger tree depth (maxLevels>=40) in order to input" + + " keys of length 5" + + err = tree.Add(k, v) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + err = tree.Update(k, v) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + _, _, _, _, err = tree.GenProof(k) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + _, _, err = tree.Get(k) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + // check AddBatch with few key-values + invalids, err = tree.AddBatch([][]byte{k}, [][]byte{v}) + c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, 1) + + // check AddBatch with many key-values + nCPU := flp2(runtime.NumCPU()) + nKVs := nCPU + 1 + var ks, vs [][]byte + for i := 0; i < nKVs; i++ { + ks = append(ks, BigIntToBytes(maxKeyLen+1, big.NewInt(1))) + vs = append(vs, BigIntToBytes(maxKeyLen+1, big.NewInt(1))) + } + invalids, err = tree.AddBatch(ks, vs) + c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, nKVs) + + // check that with maxKeyLen it can be added + k = BigIntToBytes(maxKeyLen, big.NewInt(1)) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + + // check CheckProof check with key longer + kAux, vAux, packedSiblings, existence, err := tree.GenProof(k) + c.Assert(err, qt.IsNil) + c.Assert(existence, qt.IsTrue) + + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.HashFunction(), kAux, vAux, root, packedSiblings) + c.Assert(err, qt.IsNil) + c.Assert(verif, qt.IsTrue) + + // use a similar key but with one zero, expect that CheckProof fails on + // the verification + kAux = append(kAux, 0) + verif, err = CheckProof(tree.HashFunction(), kAux, vAux, root, packedSiblings) + c.Assert(err, qt.IsNil) + c.Assert(verif, qt.IsFalse) } func BenchmarkAdd(b *testing.B) { diff --git a/utils.go b/utils.go index d7a225f..c8fe02e 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package arbo -import "math/big" +import ( + "math/big" +) // SwapEndianness swaps the order of the bytes in the byte slice. func SwapEndianness(b []byte) []byte { diff --git a/vt.go b/vt.go index 939c6b6..55bac06 100644 --- a/vt.go +++ b/vt.go @@ -37,22 +37,32 @@ type kv struct { v []byte } -func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { +func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, []int, error) { if len(ks) != len(vs) { - return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", + return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", len(ks), len(vs)) } - kvs := make([]kv, len(ks)) + var invalids []int + var kvs []kv for i := 0; i < len(ks); i++ { - keyPath := make([]byte, int(math.Ceil(float64(p.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], ks[i]) - kvs[i].pos = i - kvs[i].keyPath = keyPath - kvs[i].k = ks[i] - kvs[i].v = vs[i] + keyPath, err := keyPathFromKey(p.maxLevels, ks[i]) + if err != nil { + // TODO in a future iteration, invalids will contain + // the reason of the error of why each index is + // invalid. + invalids = append(invalids, i) + continue + } + + var kvsI kv + kvsI.pos = i + kvsI.keyPath = keyPath + kvsI.k = ks[i] + kvsI.v = vs[i] + kvs = append(kvs, kvsI) } - return kvs, nil + return kvs, invalids, nil } // vt stands for virtual tree. It's a tree that does not have any computed hash @@ -94,9 +104,9 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { l := int(math.Log2(float64(nCPU))) - kvs, err := t.params.keysValuesToKvs(ks, vs) + kvs, invalids, err := t.params.keysValuesToKvs(ks, vs) if err != nil { - return nil, err + return invalids, err } buckets := splitInBuckets(kvs, nCPU) @@ -186,7 +196,6 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { } wg.Wait() - var invalids []int for i := 0; i < len(invalidsInBucket); i++ { invalids = append(invalids, invalidsInBucket[i]...) } @@ -284,7 +293,10 @@ func upFromNodes(ns []*node) (*node, error) { // add adds a key&value as a leaf in the VirtualTree func (t *vt) add(fromLvl int, k, v []byte) error { - leaf := newLeafNode(t.params, k, v) + leaf, err := newLeafNode(t.params, k, v) + if err != nil { + return err + } if t.root == nil { t.root = leaf return nil @@ -366,16 +378,18 @@ func (t *vt) computeHashes() ([][2][]byte, error) { return pairs, nil } -func newLeafNode(p *params, k, v []byte) *node { - keyPath := make([]byte, p.hashFunction.Len()) - copy(keyPath[:], k) +func newLeafNode(p *params, k, v []byte) (*node, error) { + keyPath, err := keyPathFromKey(p.maxLevels, k) + if err != nil { + return nil, err + } path := getPath(p.maxLevels, keyPath) n := &node{ k: k, v: v, path: path, } - return n + return n, nil } type virtualNodeType int diff --git a/vt_test.go b/vt_test.go index 2ec9f23..4f60454 100644 --- a/vt_test.go +++ b/vt_test.go @@ -2,6 +2,7 @@ package arbo import ( "encoding/hex" + "math" "math/big" "testing" @@ -9,28 +10,62 @@ import ( "go.vocdoni.io/dvote/db/badgerdb" ) +// testVirtualTree adds the given key-values and tests the vt root against the +// Tree +func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { + c.Assert(len(keys), qt.Equals, len(values)) + + // normal tree, to have an expected root value + database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + tree, err := NewTree(database, maxLevels, HashFunctionSha256) + c.Assert(err, qt.IsNil) + for i := 0; i < len(keys); i++ { + err := tree.Add(keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // virtual tree + vTree := newVT(maxLevels, HashFunctionSha256) + + c.Assert(vTree.root, qt.IsNil) + + for i := 0; i < len(keys); i++ { + err := vTree.add(0, keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // compute hashes, and check Root + _, err = vTree.computeHashes() + c.Assert(err, qt.IsNil) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) +} + func TestVirtualTreeTestVectors(t *testing.T) { c := qt.New(t) - bLen := 32 + maxLevels := 32 + keyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd keys := [][]byte{ - BigIntToBytes(bLen, big.NewInt(1)), - BigIntToBytes(bLen, big.NewInt(33)), - BigIntToBytes(bLen, big.NewInt(1234)), - BigIntToBytes(bLen, big.NewInt(123456789)), + BigIntToBytes(keyLen, big.NewInt(1)), + BigIntToBytes(keyLen, big.NewInt(33)), + BigIntToBytes(keyLen, big.NewInt(1234)), + BigIntToBytes(keyLen, big.NewInt(123456789)), } values := [][]byte{ - BigIntToBytes(bLen, big.NewInt(2)), - BigIntToBytes(bLen, big.NewInt(44)), - BigIntToBytes(bLen, big.NewInt(9876)), - BigIntToBytes(bLen, big.NewInt(987654321)), + BigIntToBytes(keyLen, big.NewInt(2)), + BigIntToBytes(keyLen, big.NewInt(44)), + BigIntToBytes(keyLen, big.NewInt(9876)), + BigIntToBytes(keyLen, big.NewInt(987654321)), } // check the root for different batches of leafs - testVirtualTree(c, 10, keys[:1], values[:1]) - testVirtualTree(c, 10, keys[:2], values[:2]) - testVirtualTree(c, 10, keys[:3], values[:3]) - testVirtualTree(c, 10, keys[:4], values[:4]) + testVirtualTree(c, maxLevels, keys[:1], values[:1]) + testVirtualTree(c, maxLevels, keys[:2], values[:2]) + testVirtualTree(c, maxLevels, keys[:3], values[:3]) + testVirtualTree(c, maxLevels, keys[:4], values[:4]) // test with hardcoded values testvectorKeys := []string{ @@ -53,8 +88,8 @@ func TestVirtualTreeTestVectors(t *testing.T) { } // check the root for different batches of leafs - testVirtualTree(c, 10, keys[:1], values[:1]) - testVirtualTree(c, 10, keys, values) + testVirtualTree(c, 256, keys[:1], values[:1]) + testVirtualTree(c, 256, keys, values) } func TestVirtualTreeRandomKeys(t *testing.T) { @@ -69,45 +104,14 @@ func TestVirtualTreeRandomKeys(t *testing.T) { values[i] = randomBytes(32) } - testVirtualTree(c, 100, keys, values) -} - -func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { - c.Assert(len(keys), qt.Equals, len(values)) - - // normal tree, to have an expected root value - database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) - c.Assert(err, qt.IsNil) - tree, err := NewTree(database, maxLevels, HashFunctionSha256) - c.Assert(err, qt.IsNil) - for i := 0; i < len(keys); i++ { - err := tree.Add(keys[i], values[i]) - c.Assert(err, qt.IsNil) - } - - // virtual tree - vTree := newVT(maxLevels, HashFunctionSha256) - - c.Assert(vTree.root, qt.IsNil) - - for i := 0; i < len(keys); i++ { - err := vTree.add(0, keys[i], values[i]) - c.Assert(err, qt.IsNil) - } - - // compute hashes, and check Root - _, err = vTree.computeHashes() - c.Assert(err, qt.IsNil) - root, err := tree.Root() - c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, root) + testVirtualTree(c, 256, keys, values) } func TestVirtualTreeAddBatch(t *testing.T) { c := qt.New(t) nLeafs := 2000 - maxLevels := 100 + maxLevels := 256 keys := make([][]byte, nLeafs) values := make([][]byte, nLeafs) @@ -151,7 +155,7 @@ func TestVirtualTreeAddBatchFullyUsed(t *testing.T) { var keys, values [][]byte for i := 0; i < 128; i++ { - k := BigIntToBytes(32, big.NewInt(int64(i))) + k := BigIntToBytes(1, big.NewInt(int64(i))) v := k keys = append(keys, k)