diff --git a/circuits/smt/smtinsert.circom b/circuits/smt/smtprocessor.circom similarity index 98% rename from circuits/smt/smtinsert.circom rename to circuits/smt/smtprocessor.circom index 4206ef9..ac45b60 100644 --- a/circuits/smt/smtinsert.circom +++ b/circuits/smt/smtprocessor.circom @@ -110,11 +110,11 @@ include "../bitify.circom"; include "../comparators.circom"; include "../switcher.circom"; include "smtlevins.circom"; -include "smtinsertlevel.circom"; -include "smtinsertsm.circom"; +include "smtprocessorlevel.circom"; +include "smtprocessorsm.circom"; include "smthash.circom"; -template SMTInsert(nLevels) { +template SMTProcessor(nLevels) { signal input oldRoot; signal input newRoot; signal input siblings[nLevels]; @@ -156,7 +156,7 @@ template SMTInsert(nLevels) { component sm[nLevels]; for (var i=0; i=0; level--) { + let oldNode, newNode; + const sibling = resFind.siblings[level]; + if (keyBits[level]) { + oldNode = [sibling, rtOld]; + newNode = [sibling, rtNew]; + } else { + oldNode = [rtOld, sibling, ]; + newNode = [rtNew, sibling, ]; + } + rtOld = smtHash(oldNode); + rtNew = smtHash(newNode); + dels.push(rtOld); + ins.push([rtNew, newNode]); + } + + res.newRoot = rtNew; + + await this.db.multiIns(ins); + await this.db.setRoot(rtNew); + this.root = rtNew; + await this.db.multiDel(dels); + + return res; + } + async delete(_key) { const key = bigInt(_key); @@ -44,7 +93,7 @@ class SMT { if (!resFind.found) throw new Error("Key does not exists"); const res = { - sibblings: [], + siblings: [], delKey: key, delValue: resFind.foundValue }; @@ -55,16 +104,15 @@ class SMT { let rtNew; dels.push(rtOld); - let mixed; - if (resFind.sibblings.length > 0) { - const record = await this.db.get(resFind.sibblings[resFind.sibblings.length - 1]); + if (resFind.siblings.length > 0) { + const record = await this.db.get(resFind.siblings[resFind.siblings.length - 1]); if ((record.length == 3)&&(record[0].equals(bigInt.one))) { mixed = false; res.oldKey = record[1]; res.oldValue = record[2]; res.isOld0 = false; - rtNew = resFind.sibblings[resFind.sibblings.length - 1]; + rtNew = resFind.siblings[resFind.siblings.length - 1]; } else if (record.length == 2) { mixed = true; res.oldKey = key; @@ -83,12 +131,12 @@ class SMT { const keyBits = this._splitBits(key); - for (let level = resFind.sibblings.length-1; level >=0; level--) { - let newSibling = resFind.sibblings[level]; - if ((level == resFind.sibblings.length-1)&&(!res.isOld0)) { + for (let level = resFind.siblings.length-1; level >=0; level--) { + let newSibling = resFind.siblings[level]; + if ((level == resFind.siblings.length-1)&&(!res.isOld0)) { newSibling = bigInt.zero; } - const oldSibling = resFind.sibblings[level]; + const oldSibling = resFind.siblings[level]; if (keyBits[level]) { rtOld = smtHash([oldSibling, rtOld]); } else { @@ -100,7 +148,7 @@ class SMT { } if (mixed) { - res.sibblings.unshift(resFind.sibblings[level]); + res.siblings.unshift(resFind.siblings[level]); let newNode; if (keyBits[level]) { newNode = [newSibling, rtNew]; @@ -137,19 +185,19 @@ class SMT { if (resFind.found) throw new Error("Key already exists"); - res.sibblings = resFind.sibblings; + res.siblings = resFind.siblings; let mixed; if (!resFind.isOld0) { const oldKeyits = this._splitBits(resFind.notFoundKey); - for (let i= res.sibblings.length; oldKeyits[i] == newKeyBits[i]; i++) { - res.sibblings.push(bigInt.zero); + for (let i= res.siblings.length; oldKeyits[i] == newKeyBits[i]; i++) { + res.siblings.push(bigInt.zero); } rtOld = smtHash([1, resFind.notFoundKey, resFind.notFoundValue]); - res.sibblings.push(rtOld); + res.siblings.push(rtOld); addedOne = true; mixed = false; - } else if (res.sibblings.length >0) { + } else if (res.siblings.length >0) { mixed = true; rtOld = bigInt.zero; } @@ -160,12 +208,12 @@ class SMT { let rt = smtHash([1, key, value]); inserts.push([rt,[1, key, value]] ); - for (let i=res.sibblings.length-1; i>=0; i--) { - if ((i=0; i--) { + if ((i0) && (res.sibblings[res.sibblings.length-1].isZero())) { - res.sibblings.pop(); + if (addedOne) res.siblings.pop(); + while ((res.siblings.length>0) && (res.siblings[res.siblings.length-1].isZero())) { + res.siblings.pop(); } res.oldKey = resFind.notFoundKey; res.oldValue = resFind.notFoundValue; @@ -216,7 +264,7 @@ class SMT { if (root.isZero()) { res = { found: false, - sibblings: [], + siblings: [], notFoundKey: key, notFoundValue: bigInt.zero, isOld0: true @@ -230,14 +278,14 @@ class SMT { if (record[1].equals(key)) { res = { found: true, - sibblings: [], + siblings: [], foundValue: record[2], isOld0: false }; } else { res = { found: false, - sibblings: [], + siblings: [], notFoundKey: record[1], notFoundValue: record[2], isOld0: false @@ -246,10 +294,10 @@ class SMT { } else { if (keyBits[level] == 0) { res = await this._find(key, keyBits, record[0], level+1); - res.sibblings.unshift(record[1]); + res.siblings.unshift(record[1]); } else { res = await this._find(key, keyBits, record[1], level+1); - res.sibblings.unshift(record[0]); + res.siblings.unshift(record[0]); } } return res; diff --git a/test/circuits/smtinsert10_test.circom b/test/circuits/smtinsert10_test.circom deleted file mode 100644 index ca866c1..0000000 --- a/test/circuits/smtinsert10_test.circom +++ /dev/null @@ -1,3 +0,0 @@ -include "../../circuits/smt/smtinsert.circom"; - -component main = SMTInsert(10); diff --git a/test/circuits/smtprocessor10_test.circom b/test/circuits/smtprocessor10_test.circom new file mode 100644 index 0000000..ecf15d0 --- /dev/null +++ b/test/circuits/smtprocessor10_test.circom @@ -0,0 +1,3 @@ +include "../../circuits/smt/smtprocessor.circom"; + +component main = SMTProcessor(10); diff --git a/test/smt.js b/test/smt.js index 4459645..a12cdd7 100644 --- a/test/smt.js +++ b/test/smt.js @@ -16,7 +16,7 @@ function print(circuit, w, s) { async function testInsert(tree, key, value, circuit, log ) { const res = await tree.insert(key,value); - let siblings = res.sibblings; + let siblings = res.siblings; while (siblings.length<10) siblings.push(bigInt(0)); const w = circuit.calculateWitness({ @@ -38,7 +38,7 @@ async function testInsert(tree, key, value, circuit, log ) { async function testDelete(tree, key, circuit) { const res = await tree.delete(key); - let siblings = res.sibblings; + let siblings = res.siblings; while (siblings.length<10) siblings.push(bigInt(0)); const w = circuit.calculateWitness({ @@ -59,6 +59,29 @@ async function testDelete(tree, key, circuit) { assert(root1.equals(res.newRoot)); } +async function testUpdate(tree, key, newValue, circuit) { + const res = await tree.update(key, newValue); + let siblings = res.siblings; + while (siblings.length<10) siblings.push(bigInt(0)); + + const w = circuit.calculateWitness({ + fnc: [0,1], + oldRoot: res.oldRoot, + newRoot: res.newRoot, + siblings: siblings, + oldKey: res.oldKey, + oldValue: res.oldValue, + isOld0: 0, + newKey: res.newKey, + newValue: res.newValue + }); + + const root1 = w[circuit.getSignalIdx("main.topSwitcher.outR")]; + + assert(circuit.checkWitness(w)); + assert(root1.equals(res.newRoot)); +} + describe("SMT test", function () { let circuit; @@ -67,11 +90,11 @@ describe("SMT test", function () { this.timeout(100000); before( async () => { - const cirDef = await compiler(path.join(__dirname, "circuits", "smtinsert10_test.circom")); + const cirDef = await compiler(path.join(__dirname, "circuits", "smtprocessor10_test.circom")); circuit = new snarkjs.Circuit(cirDef); - console.log("NConstrains SMTInsert: " + circuit.nConstraints); + console.log("NConstrains SMTProcessor: " + circuit.nConstraints); tree = await smt.newMemEmptyTrie(); }); @@ -175,7 +198,20 @@ describe("SMT test", function () { }); it("Should update an element", async () => { + const tree1 = await smt.newMemEmptyTrie(); + const tree2 = await smt.newMemEmptyTrie(); + + await testInsert(tree1,8,88, circuit); + await testInsert(tree1,9,99, circuit); + await testInsert(tree1,32,3232, circuit); + + await testInsert(tree2,8,888, circuit); + await testInsert(tree2,9,999, circuit); + await testInsert(tree2,32,323232, circuit); + await testUpdate(tree1, 8, 888, circuit); + await testUpdate(tree1, 9, 999, circuit); + await testUpdate(tree1, 32, 323232, circuit); }); it("Should verify existance of an element", async () => { diff --git a/test/smtjs.js b/test/smtjs.js index b632b0d..eb360ab 100644 --- a/test/smtjs.js +++ b/test/smtjs.js @@ -160,4 +160,23 @@ describe("SMT Javascript test", function () { assert.equal(Object.keys(tree.db.nodes).length, 0); }); + it("Should test update", async () => { + const tree1 = await smt.newMemEmptyTrie(); + const tree2 = await smt.newMemEmptyTrie(); + + await tree1.insert(8,88); + await tree1.insert(9,99,); + await tree1.insert(32,3232); + + await tree2.insert(8,888); + await tree2.insert(9,999); + await tree2.insert(32,323232); + + await tree1.update(8, 888); + await tree1.update(9, 999); + await tree1.update(32, 323232); + + assert(tree1.root.equals(tree2.root)); + }); + });