From 0b2c3b07edb9f6764faa7b2c5d7ba8da89ead8fe Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 26 May 2021 17:01:09 +0200 Subject: [PATCH] Update public methods signatures - Update public methods signatures - Remove 'lastAccess' param - Add option to pass root for tree.Dump, Iterate, IterateWithStop, Graphviz (and related) - Move error messages to const defined error messages for external usage --- addbatch_test.go | 82 ++++++++++++++++++--------------- hash.go | 6 +-- hash_test.go | 5 ++- tree.go | 115 ++++++++++++++++++++++++++--------------------- tree_test.go | 113 +++++++++++++++++++++++++--------------------- utils.go | 5 ++- vt.go | 9 ++-- vt_test.go | 66 +++++++++++++++------------ 8 files changed, 222 insertions(+), 179 deletions(-) diff --git a/addbatch_test.go b/addbatch_test.go index 467b233..579c613 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -46,11 +46,12 @@ func testInit(c *qt.C, n int) (*Tree, *Tree) { c.Assert(err, qt.IsNil) defer tree2.db.Close() + bLen := HashFunctionPoseidon.Len() // add the initial leafs to fill a bit the trees before calling the // AddBatch method for i := 0; i < n; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { c.Fatal(err) } @@ -70,10 +71,11 @@ func TestAddBatchTreeEmpty(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() start := time.Now() for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } @@ -87,8 +89,8 @@ func TestAddBatchTreeEmpty(t *testing.T) { var keys, values [][]byte for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -116,9 +118,10 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } @@ -130,8 +133,8 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { var keys, values [][]byte for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -271,10 +274,11 @@ func TestAddBatchTreeNotEmptyFewLeafs(t *testing.T) { tree1, tree2 := testInit(c, initialNLeafs) tree2.dbgInit() + bLen := tree1.HashFunction().Len() start := time.Now() for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -284,8 +288,8 @@ func TestAddBatchTreeNotEmptyFewLeafs(t *testing.T) { // prepare the key-values to be added var keys, values [][]byte for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -313,10 +317,11 @@ func TestAddBatchTreeNotEmptyEnoughLeafs(t *testing.T) { tree1, tree2 := testInit(c, initialNLeafs) tree2.dbgInit() + bLen := tree1.HashFunction().Len() start := time.Now() for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -326,8 +331,8 @@ func TestAddBatchTreeNotEmptyEnoughLeafs(t *testing.T) { // prepare the key-values to be added var keys, values [][]byte for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -353,18 +358,19 @@ func TestAddBatchTreeEmptyRepeatedLeafs(t *testing.T) { tree1, tree2 := testInit(c, 0) + bLen := tree1.HashFunction().Len() // prepare the key-values to be added var keys, values [][]byte for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } // add repeated key-values for i := 0; i < nRepeatedKeys; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -391,11 +397,12 @@ func TestAddBatchTreeNotEmptyFewLeafsRepeatedLeafs(t *testing.T) { tree1, tree2 := testInit(c, initialNLeafs) + bLen := tree1.HashFunction().Len() // prepare the key-values to be added var keys, values [][]byte for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -417,11 +424,12 @@ func TestAddBatchTreeNotEmptyFewLeafsRepeatedLeafs(t *testing.T) { func TestSplitInBuckets(t *testing.T) { c := qt.New(t) + bLen := HashFunctionPoseidon.Len() nLeafs := 16 kvs := make([]kv, nLeafs) for i := 0; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keyPath := make([]byte, 32) copy(keyPath[:], k) kvs[i].pos = i @@ -523,10 +531,11 @@ func TestAddBatchTreeNotEmpty(t *testing.T) { tree1, tree2 := testInit(c, initialNLeafs) tree2.dbgInit() + bLen := tree1.HashFunction().Len() start := time.Now() for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -536,8 +545,8 @@ func TestAddBatchTreeNotEmpty(t *testing.T) { // prepare the key-values to be added var keys, values [][]byte for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } @@ -563,11 +572,12 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) { initialNLeafs := 900 tree1, _ := testInit(c, initialNLeafs) + bLen := tree1.HashFunction().Len() start := time.Now() for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -583,8 +593,8 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) { // add the initial leafs to fill a bit the tree before calling the // AddBatch method for i := 0; i < initialNLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) // use only the keys of one bucket, store the not used ones for // later if i%4 != 0 { @@ -598,8 +608,8 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) { } for i := initialNLeafs; i < nLeafs; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) keys = append(keys, k) values = append(values, v) } diff --git a/hash.go b/hash.go index c17edb7..cb51468 100644 --- a/hash.go +++ b/hash.go @@ -37,8 +37,8 @@ type HashFunction interface { Type() []byte Len() int Hash(...[]byte) ([]byte, error) - // CheckInputs checks if the inputs are valid without computing the hash - // CheckInputs(...[]byte) error + // CheckInput checks if the input is valid without computing the hash + // CheckInput(...[]byte) error } // HashSha256 implements the HashFunction interface for the Sha256 hash @@ -88,7 +88,7 @@ func (f HashPoseidon) Hash(b ...[]byte) ([]byte, error) { if err != nil { return nil, err } - hB := BigIntToBytes(h) + hB := BigIntToBytes(f.Len(), h) return hB, nil } diff --git a/hash_test.go b/hash_test.go index 65ac535..da0d629 100644 --- a/hash_test.go +++ b/hash_test.go @@ -25,9 +25,10 @@ func TestHashSha256(t *testing.T) { func TestHashPoseidon(t *testing.T) { // Poseidon hash hashFunc := &HashPoseidon{} + bLen := hashFunc.Len() h, err := hashFunc.Hash( - BigIntToBytes(big.NewInt(1)), - BigIntToBytes(big.NewInt(2))) + BigIntToBytes(bLen, big.NewInt(1)), + BigIntToBytes(bLen, big.NewInt(2))) if err != nil { t.Fatal(err) } diff --git a/tree.go b/tree.go index 14083ad..d2be1ab 100644 --- a/tree.go +++ b/tree.go @@ -19,8 +19,6 @@ import ( "io" "math" "sync" - "sync/atomic" - "time" "github.com/iden3/go-merkletree/db" ) @@ -48,16 +46,30 @@ var ( dbKeyRoot = []byte("root") dbKeyNLeafs = []byte("nleafs") emptyValue = []byte{0} + + // ErrKeyAlreadyExists is used when trying to add a key as leaf to the + // tree that already exists. + ErrKeyAlreadyExists = fmt.Errorf("key already exists") + // ErrInvalidValuePrefix is used when going down into the tree, a value + // is read from the db and has an unrecognized prefix. + ErrInvalidValuePrefix = fmt.Errorf("invalid value prefix") + // ErrDBNoTx is used when trying to use Tree.dbPut but Tree.tx==nil + ErrDBNoTx = fmt.Errorf("dbPut error: no db Tx") + // ErrMaxLevel indicates when going down into the tree, the max level is + // reached + ErrMaxLevel = fmt.Errorf("max level reached") + // ErrMaxVirtualLevel indicates when going down into the tree, the max + // virtual level is reached + ErrMaxVirtualLevel = fmt.Errorf("max virtual level reached") ) // Tree defines the struct that implements the MerkleTree functionalities type Tree struct { sync.RWMutex - tx db.Tx - db db.Storage - lastAccess int64 // in unix time // TODO delete, is a feature of a upper abstraction level - maxLevels int - root []byte + tx db.Tx + db db.Storage + maxLevels int + root []byte hashFunction HashFunction // TODO in the methods that use it, check if emptyHash param is len>0 @@ -71,8 +83,6 @@ type Tree struct { // will load it. func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error) { t := Tree{db: storage, maxLevels: maxLevels, hashFunction: hash} - t.updateAccessTime() - t.emptyHash = make([]byte, t.hashFunction.Len()) // empty root, err := t.dbGet(dbKeyRoot) @@ -100,15 +110,6 @@ func NewTree(storage db.Storage, maxLevels int, hash HashFunction) (*Tree, error return &t, nil } -func (t *Tree) updateAccessTime() { - atomic.StoreInt64(&t.lastAccess, time.Now().Unix()) -} - -// LastAccess returns the last access timestamp in Unixtime -func (t *Tree) LastAccess() int64 { - return atomic.LoadInt64(&t.lastAccess) -} - // Root returns the root of the Tree func (t *Tree) Root() []byte { return t.root @@ -122,7 +123,6 @@ func (t *Tree) HashFunction() HashFunction { // AddBatch adds a batch of key-values to the Tree. Returns an array containing // the indexes of the keys failed to add. func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { - t.updateAccessTime() t.Lock() defer t.Unlock() @@ -131,7 +131,8 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { return nil, err } - // TODO check that keys & values is valid for Tree.hashFunction + // TODO check validity of keys & values for Tree.hashFunction + invalids, err := vt.addBatch(keys, values) if err != nil { return nil, err @@ -140,6 +141,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { // once the VirtualTree is build, compute the hashes pairs, err := vt.computeHashes() if err != nil { + // TODO currently invalids in computeHashes are not counted return nil, err } t.root = vt.root.h @@ -177,7 +179,7 @@ func (t *Tree) AddBatch(keys, values [][]byte) ([]int, error) { func (t *Tree) loadVT() (vt, error) { vt := newVT(t.maxLevels, t.hashFunction) vt.params.dbg = t.dbg - err := t.Iterate(func(k, v []byte) { + err := t.Iterate(nil, func(k, v []byte) { if v[0] != PrefixValueLeaf { return } @@ -194,8 +196,6 @@ func (t *Tree) loadVT() (vt, error) { // is expected that are represented by a Little-Endian byte array (for circom // compatibility). func (t *Tree) Add(k, v []byte) error { - t.updateAccessTime() - t.Lock() defer t.Unlock() @@ -205,6 +205,8 @@ func (t *Tree) Add(k, v []byte) error { return err } + // TODO check validity of key & value for Tree.hashFunction + err = t.add(0, k, v) // add from level 0 if err != nil { return err @@ -221,8 +223,6 @@ func (t *Tree) Add(k, v []byte) error { } func (t *Tree) add(fromLvl int, k, v []byte) error { - // TODO check validity of key & value (for the Tree.HashFunction type) - keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) @@ -262,7 +262,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, path []bool, currLvl int, getLeaf bool) ( []byte, []byte, [][]byte, error) { if currLvl > t.maxLevels-1 { - return nil, nil, nil, fmt.Errorf("max level") + return nil, nil, nil, ErrMaxLevel } var err error @@ -287,7 +287,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, // TODO move this error msg to const & add test that // checks that adding a repeated key this error is // returned - return nil, nil, nil, fmt.Errorf("key already exists") + return nil, nil, nil, ErrKeyAlreadyExists } if !bytes.Equal(currValue, emptyValue) { @@ -324,7 +324,7 @@ func (t *Tree) down(newKey, currKey []byte, siblings [][]byte, siblings = append(siblings, rChild) return t.down(newKey, lChild, siblings, path, currLvl+1, getLeaf) default: - return nil, nil, nil, fmt.Errorf("invalid value") + return nil, nil, nil, ErrInvalidValuePrefix } } @@ -334,7 +334,7 @@ func (t *Tree) downVirtually(siblings [][]byte, oldKey, newKey []byte, oldPath, newPath []bool, currLvl int) ([][]byte, error) { var err error if currLvl > t.maxLevels-1 { - return nil, fmt.Errorf("max virtual level %d", currLvl) + return nil, ErrMaxVirtualLevel } if oldPath[currLvl] == newPath[currLvl] { @@ -459,8 +459,6 @@ func getPath(numLevels int, k []byte) []bool { // Update updates the value for a given existing key. If the given key does not // exist, returns an error. func (t *Tree) Update(k, v []byte) error { - t.updateAccessTime() - t.Lock() defer t.Unlock() @@ -515,7 +513,6 @@ func (t *Tree) Update(k, v []byte) error { // the Tree, the proof will be of existence, if the key does not exist in the // tree, the proof will be of non-existence. func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) { - t.updateAccessTime() keyPath := make([]byte, t.hashFunction.Len()) copy(keyPath[:], k) @@ -533,7 +530,7 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, error) { fmt.Println(leafK) fmt.Println(leafV) // TODO proof of non-existence - panic(fmt.Errorf("unimplemented")) + panic("unimplemented") } s := PackSiblings(t.hashFunction, siblings) @@ -627,8 +624,8 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { } leafK, leafV := ReadLeafValue(value) if !bytes.Equal(k, leafK) { - panic(fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s", - BytesToBigInt(k), BytesToBigInt(leafK))) + return leafK, leafV, fmt.Errorf("Tree.Get error: keys doesn't match, %s != %s", + BytesToBigInt(k), BytesToBigInt(leafK)) } return leafK, leafV, nil @@ -672,7 +669,7 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, func (t *Tree) dbPut(k, v []byte) error { if t.tx == nil { - return fmt.Errorf("dbPut error: no db Tx") + return ErrDBNoTx } t.dbg.incDbPut() return t.tx.Put(k, v) @@ -729,18 +726,23 @@ func (t *Tree) GetNLeafs() (int, error) { // Iterate iterates through the full Tree, executing the given function on each // node of the Tree. -func (t *Tree) Iterate(f func([]byte, []byte)) error { - // TODO allow to define which root to use - t.updateAccessTime() - return t.iter(t.root, f) +func (t *Tree) Iterate(rootKey []byte, f func([]byte, []byte)) error { + // allow to define which root to use + if rootKey == nil { + rootKey = t.Root() + } + return t.iter(rootKey, f) } // IterateWithStop does the same than Iterate, but with int for the current // level, and a boolean parameter used by the passed function, is to indicate to // stop iterating on the branch when the method returns 'true'. -func (t *Tree) IterateWithStop(f func(int, []byte, []byte) bool) error { - t.updateAccessTime() - return t.iterWithStop(t.root, 0, f) +func (t *Tree) IterateWithStop(rootKey []byte, f func(int, []byte, []byte) bool) error { + // allow to define which root to use + if rootKey == nil { + rootKey = t.Root() + } + return t.iterWithStop(rootKey, 0, f) } func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) bool) error { @@ -768,7 +770,7 @@ func (t *Tree) iterWithStop(k []byte, currLevel int, f func(int, []byte, []byte) return err } default: - return fmt.Errorf("invalid value") + return ErrInvalidValuePrefix } return nil } @@ -786,14 +788,16 @@ func (t *Tree) iter(k []byte, f func([]byte, []byte)) error { // [ 1 byte | 1 byte | S bytes | len(v) bytes ] // [ len(k) | len(v) | key | value ] // Where S is the size of the output of the hash function used for the Tree. -func (t *Tree) Dump() ([]byte, error) { - t.updateAccessTime() - // TODO allow to define which root to use +func (t *Tree) Dump(rootKey []byte) ([]byte, error) { + // allow to define which root to use + if rootKey == nil { + rootKey = t.Root() + } // WARNING current encoding only supports key & values of 255 bytes each // (due using only 1 byte for the length headers). var b []byte - err := t.Iterate(func(k, v []byte) { + err := t.Iterate(rootKey, func(k, v []byte) { if v[0] != PrefixValueLeaf { return } @@ -811,7 +815,6 @@ func (t *Tree) Dump() ([]byte, error) { // ImportDump imports the leafs (that have been exported with the ExportLeafs // method) in the Tree. func (t *Tree) ImportDump(b []byte) error { - t.updateAccessTime() r := bytes.NewReader(b) var err error var keys, values [][]byte @@ -855,8 +858,11 @@ func (t *Tree) GraphvizFirstNLevels(w io.Writer, rootKey []byte, untilLvl int) e fmt.Fprintf(w, `digraph hierarchy { node [fontname=Monospace,fontsize=10,shape=box] `) + if rootKey == nil { + rootKey = t.Root() + } nEmpties := 0 - err := t.iterWithStop(t.root, 0, func(currLvl int, k, v []byte) bool { + err := t.iterWithStop(rootKey, 0, func(currLvl int, k, v []byte) bool { if currLvl == untilLvl { return true // to stop the iter from going down } @@ -901,6 +907,9 @@ node [fontname=Monospace,fontsize=10,shape=box] // PrintGraphviz prints the output of Tree.Graphviz func (t *Tree) PrintGraphviz(rootKey []byte) error { + if rootKey == nil { + rootKey = t.Root() + } return t.PrintGraphvizFirstNLevels(rootKey, t.maxLevels) } @@ -912,7 +921,7 @@ func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error { w := bytes.NewBufferString("") fmt.Fprintf(w, "--------\nGraphviz of the Tree with Root "+hex.EncodeToString(rootKey)+":\n") - err := t.GraphvizFirstNLevels(w, nil, untilLvl) + err := t.GraphvizFirstNLevels(w, rootKey, untilLvl) if err != nil { fmt.Println(w) return err @@ -924,7 +933,9 @@ func (t *Tree) PrintGraphvizFirstNLevels(rootKey []byte, untilLvl int) error { return nil } -// Purge WIP: unimplemented +// Purge WIP: unimplemented TODO func (t *Tree) Purge(keys [][]byte) error { return nil } + +// TODO circom proofs diff --git a/tree_test.go b/tree_test.go index 3a7eb14..273b15b 100644 --- a/tree_test.go +++ b/tree_test.go @@ -38,23 +38,24 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0]) + bLen := hashFunc.Len() err = tree.Add( - BigIntToBytes(big.NewInt(1)), - BigIntToBytes(big.NewInt(2))) + BigIntToBytes(bLen, big.NewInt(1)), + BigIntToBytes(bLen, big.NewInt(2))) c.Assert(err, qt.IsNil) rootBI := BytesToBigInt(tree.Root()) c.Check(rootBI.String(), qt.Equals, testVectors[1]) err = tree.Add( - BigIntToBytes(big.NewInt(33)), - BigIntToBytes(big.NewInt(44))) + BigIntToBytes(bLen, big.NewInt(33)), + BigIntToBytes(bLen, big.NewInt(44))) c.Assert(err, qt.IsNil) rootBI = BytesToBigInt(tree.Root()) c.Check(rootBI.String(), qt.Equals, testVectors[2]) err = tree.Add( - BigIntToBytes(big.NewInt(1234)), - BigIntToBytes(big.NewInt(9876))) + BigIntToBytes(bLen, big.NewInt(1234)), + BigIntToBytes(bLen, big.NewInt(9876))) c.Assert(err, qt.IsNil) rootBI = BytesToBigInt(tree.Root()) c.Check(rootBI.String(), qt.Equals, testVectors[3]) @@ -66,9 +67,10 @@ func TestAddBatch(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() for i := 0; i < 1000; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(0)) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(0)) if err := tree.Add(k, v); err != nil { t.Fatal(err) } @@ -84,8 +86,8 @@ func TestAddBatch(t *testing.T) { var keys, values [][]byte for i := 0; i < 1000; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(0)) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(0)) keys = append(keys, k) values = append(values, v) } @@ -104,9 +106,10 @@ func TestAddDifferentOrder(t *testing.T) { c.Assert(err, qt.IsNil) defer tree1.db.Close() + bLen := tree1.HashFunction().Len() for i := 0; i < 16; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(0)) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(0)) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } @@ -117,8 +120,8 @@ func TestAddDifferentOrder(t *testing.T) { defer tree2.db.Close() for i := 16 - 1; i >= 0; i-- { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(0)) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(0)) if err := tree2.Add(k, v); err != nil { t.Fatal(err) } @@ -135,14 +138,15 @@ func TestAddRepeatedIndex(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() - k := BigIntToBytes(big.NewInt(int64(3))) - v := BigIntToBytes(big.NewInt(int64(12))) + bLen := tree.HashFunction().Len() + k := BigIntToBytes(bLen, big.NewInt(int64(3))) + v := BigIntToBytes(bLen, big.NewInt(int64(12))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } err = tree.Add(k, v) c.Assert(err, qt.Not(qt.IsNil)) - c.Check(err, qt.ErrorMatches, "max virtual level 100") + c.Check(err, qt.Equals, ErrMaxVirtualLevel) } func TestUpdate(t *testing.T) { @@ -151,13 +155,14 @@ func TestUpdate(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() - k := BigIntToBytes(big.NewInt(int64(20))) - v := BigIntToBytes(big.NewInt(int64(12))) + bLen := tree.HashFunction().Len() + k := BigIntToBytes(bLen, big.NewInt(int64(20))) + v := BigIntToBytes(bLen, big.NewInt(int64(12))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } - v = BigIntToBytes(big.NewInt(int64(11))) + v = BigIntToBytes(bLen, big.NewInt(int64(11))) err = tree.Update(k, v) c.Assert(err, qt.IsNil) @@ -168,21 +173,21 @@ func TestUpdate(t *testing.T) { // add more leafs to the tree to do another test for i := 0; i < 16; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } } - k = BigIntToBytes(big.NewInt(int64(3))) - v = BigIntToBytes(big.NewInt(int64(11))) + k = BigIntToBytes(bLen, big.NewInt(int64(3))) + v = BigIntToBytes(bLen, big.NewInt(int64(11))) // check that before the Update, value for 3 is !=11 gettedKey, gettedValue, err = tree.Get(k) c.Assert(err, qt.IsNil) c.Check(gettedKey, qt.DeepEquals, k) c.Check(gettedValue, qt.Not(qt.DeepEquals), v) - c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6))) + c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(bLen, big.NewInt(6))) err = tree.Update(k, v) c.Assert(err, qt.IsNil) @@ -192,7 +197,7 @@ func TestUpdate(t *testing.T) { c.Assert(err, qt.IsNil) c.Check(gettedKey, qt.DeepEquals, k) c.Check(gettedValue, qt.DeepEquals, v) - c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11))) + c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(bLen, big.NewInt(11))) } func TestAux(t *testing.T) { // TODO split in proper tests @@ -201,29 +206,30 @@ func TestAux(t *testing.T) { // TODO split in proper tests c.Assert(err, qt.IsNil) defer tree.db.Close() - k := BigIntToBytes(big.NewInt(int64(1))) - v := BigIntToBytes(big.NewInt(int64(0))) + bLen := tree.HashFunction().Len() + k := BigIntToBytes(bLen, big.NewInt(int64(1))) + v := BigIntToBytes(bLen, big.NewInt(int64(0))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(256))) + k = BigIntToBytes(bLen, big.NewInt(int64(256))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(257))) + k = BigIntToBytes(bLen, big.NewInt(int64(257))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(515))) + k = BigIntToBytes(bLen, big.NewInt(int64(515))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(770))) + k = BigIntToBytes(bLen, big.NewInt(int64(770))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(388))) + k = BigIntToBytes(bLen, big.NewInt(int64(388))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) - k = BigIntToBytes(big.NewInt(int64(900))) + k = BigIntToBytes(bLen, big.NewInt(int64(900))) err = tree.Add(k, v) c.Assert(err, qt.IsNil) // @@ -237,19 +243,20 @@ func TestGet(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() for i := 0; i < 10; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } } - k := BigIntToBytes(big.NewInt(int64(7))) + k := BigIntToBytes(bLen, big.NewInt(int64(7))) gettedKey, gettedValue, err := tree.Get(k) c.Assert(err, qt.IsNil) c.Check(gettedKey, qt.DeepEquals, k) - c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(int64(7*2)))) + c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(bLen, big.NewInt(int64(7*2)))) } func TestGenProofAndVerify(t *testing.T) { @@ -258,16 +265,17 @@ func TestGenProofAndVerify(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() for i := 0; i < 10; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } } - k := BigIntToBytes(big.NewInt(int64(7))) - v := BigIntToBytes(big.NewInt(int64(14))) + k := BigIntToBytes(bLen, big.NewInt(int64(7))) + v := BigIntToBytes(bLen, big.NewInt(int64(14))) proofV, siblings, err := tree.GenProof(k) c.Assert(err, qt.IsNil) c.Assert(proofV, qt.DeepEquals, v) @@ -283,15 +291,16 @@ func TestDumpAndImportDump(t *testing.T) { c.Assert(err, qt.IsNil) defer tree1.db.Close() + bLen := tree1.HashFunction().Len() for i := 0; i < 16; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i * 2))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) if err := tree1.Add(k, v); err != nil { t.Fatal(err) } } - e, err := tree1.Dump() + e, err := tree1.Dump(nil) c.Assert(err, qt.IsNil) tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon) @@ -310,10 +319,11 @@ func TestRWMutex(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() + bLen := tree.HashFunction().Len() var keys, values [][]byte for i := 0; i < 1000; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(0)) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(0)) keys = append(keys, k) values = append(values, v) } @@ -325,8 +335,8 @@ func TestRWMutex(t *testing.T) { }() time.Sleep(500 * time.Millisecond) - k := BigIntToBytes(big.NewInt(int64(99999))) - v := BigIntToBytes(big.NewInt(int64(99999))) + k := BigIntToBytes(bLen, big.NewInt(int64(99999))) + v := BigIntToBytes(bLen, big.NewInt(int64(99999))) if err := tree.Add(k, v); err != nil { t.Fatal(err) } @@ -384,11 +394,12 @@ func TestSetGetNLeafs(t *testing.T) { } func BenchmarkAdd(b *testing.B) { + bLen := 32 // for both Poseidon & Sha256 // prepare inputs var ks, vs [][]byte for i := 0; i < 1000; i++ { - k := BigIntToBytes(big.NewInt(int64(i))) - v := BigIntToBytes(big.NewInt(int64(i))) + k := BigIntToBytes(bLen, big.NewInt(int64(i))) + v := BigIntToBytes(bLen, big.NewInt(int64(i))) ks = append(ks, k) vs = append(vs, v) } diff --git a/utils.go b/utils.go index de19c8f..e6e979e 100644 --- a/utils.go +++ b/utils.go @@ -12,8 +12,9 @@ func SwapEndianness(b []byte) []byte { } // BigIntToBytes converts a *big.Int into a byte array in Little-Endian -func BigIntToBytes(bi *big.Int) []byte { - var b [32]byte // TODO make the length depending on the tree.hashFunction.Len() +func BigIntToBytes(blen int, bi *big.Int) []byte { + // var b [blen]byte // TODO make the length depending on the tree.hashFunction.Len() + b := make([]byte, blen) copy(b[:], SwapEndianness(bi.Bytes())) return b[:] } diff --git a/vt.go b/vt.go index 5f9f0f8..5d13f95 100644 --- a/vt.go +++ b/vt.go @@ -378,7 +378,7 @@ func (n *node) typ() virtualNodeType { func (n *node) add(p *params, currLvl int, leaf *node) error { if currLvl > p.maxLevels-1 { - return fmt.Errorf("max virtual level %d", currLvl) + return ErrMaxVirtualLevel } if n == nil { @@ -411,8 +411,9 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { } case vtLeaf: if bytes.Equal(n.k, leaf.k) { - return fmt.Errorf("key already exists. Existing node: %s, trying to add node: %s", - hex.EncodeToString(n.k), hex.EncodeToString(leaf.k)) + return fmt.Errorf("%s. Existing node: %s, trying to add node: %s", + ErrKeyAlreadyExists, hex.EncodeToString(n.k), + hex.EncodeToString(leaf.k)) } oldLeaf := &node{ @@ -439,7 +440,7 @@ func (n *node) add(p *params, currLvl int, leaf *node) error { func (n *node) downUntilDivergence(p *params, currLvl int, oldLeaf, newLeaf *node) error { if currLvl > p.maxLevels-1 { - return fmt.Errorf("max virtual level %d", currLvl) + return ErrMaxVirtualLevel } if oldLeaf.path[currLvl] != newLeaf.path[currLvl] { diff --git a/vt_test.go b/vt_test.go index f6b4436..9d22226 100644 --- a/vt_test.go +++ b/vt_test.go @@ -12,53 +12,61 @@ import ( func TestVirtualTreeTestVectors(t *testing.T) { c := qt.New(t) + bLen := 32 keys := [][]byte{ - BigIntToBytes(big.NewInt(1)), - BigIntToBytes(big.NewInt(33)), - BigIntToBytes(big.NewInt(1234)), - BigIntToBytes(big.NewInt(123456789)), + BigIntToBytes(bLen, big.NewInt(1)), + BigIntToBytes(bLen, big.NewInt(33)), + BigIntToBytes(bLen, big.NewInt(1234)), + BigIntToBytes(bLen, big.NewInt(123456789)), } values := [][]byte{ - BigIntToBytes(big.NewInt(2)), - BigIntToBytes(big.NewInt(44)), - BigIntToBytes(big.NewInt(9876)), - BigIntToBytes(big.NewInt(987654321)), + BigIntToBytes(bLen, big.NewInt(2)), + BigIntToBytes(bLen, big.NewInt(44)), + BigIntToBytes(bLen, big.NewInt(9876)), + BigIntToBytes(bLen, big.NewInt(987654321)), } // check the root for different batches of leafs - // testVirtualTree(c, 10, keys[:1], values[:1]) - // testVirtualTree(c, 10, keys[:2], values[:2]) - // testVirtualTree(c, 10, keys[:3], values[:3]) + testVirtualTree(c, 10, keys[:1], values[:1]) + testVirtualTree(c, 10, keys[:2], values[:2]) + testVirtualTree(c, 10, keys[:3], values[:3]) testVirtualTree(c, 10, keys[:4], values[:4]) -} - -func TestVirtualTreeRandomKeys(t *testing.T) { - c := qt.New(t) // test with hardcoded values - keys := make([][]byte, 8) - values := make([][]byte, 8) - keys[0], _ = hex.DecodeString("1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642") - keys[1], _ = hex.DecodeString("2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf") - keys[2], _ = hex.DecodeString("9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e") - keys[3], _ = hex.DecodeString("9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d") - keys[4], _ = hex.DecodeString("1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5") - keys[5], _ = hex.DecodeString("d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7") - keys[6], _ = hex.DecodeString("3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c") - keys[7], _ = hex.DecodeString("5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5") + testvectorKeys := []string{ + "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642", + "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf", + "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e", + "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d", + "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5", + "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7", + "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c", + "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5", + } + keys = [][]byte{} + values = [][]byte{} + for i := 0; i < len(testvectorKeys); i++ { + key, err := hex.DecodeString(testvectorKeys[i]) + c.Assert(err, qt.IsNil) + keys = append(keys, key) + values = append(values, []byte{0}) + } // check the root for different batches of leafs testVirtualTree(c, 10, keys[:1], values[:1]) testVirtualTree(c, 10, keys, values) +} + +func TestVirtualTreeRandomKeys(t *testing.T) { + c := qt.New(t) // test with random values nLeafs := 1024 - - keys = make([][]byte, nLeafs) - values = make([][]byte, nLeafs) + keys := make([][]byte, nLeafs) + values := make([][]byte, nLeafs) for i := 0; i < nLeafs; i++ { keys[i] = randomBytes(32) - values[i] = []byte{0} + values[i] = randomBytes(32) } testVirtualTree(c, 100, keys, values)