Browse Source

Poseidon in SMT

feature/pr-19
Jordi Baylina 5 years ago
parent
commit
c4490b2ce9
No known key found for this signature in database GPG Key ID: 7480C80C1BE43112
9 changed files with 98 additions and 20 deletions
  1. +0
    -0
      circuits/smt/smthash_mimc.circom
  2. +56
    -0
      circuits/smt/smthash_poseidon.circom
  3. +1
    -1
      circuits/smt/smtprocessor.circom
  4. +1
    -1
      circuits/smt/smtverifier.circom
  5. +15
    -15
      src/smt.js
  6. +10
    -0
      src/smt_hashes_mimc.js
  7. +12
    -0
      src/smt_hashes_poseidon.js
  8. +2
    -2
      test/poseidoncontract.js
  9. +1
    -1
      test/smtprocessor.js

circuits/smt/smthash.circom → circuits/smt/smthash_mimc.circom


+ 56
- 0
circuits/smt/smthash_poseidon.circom

@ -0,0 +1,56 @@
/*
Copyright 2018 0KIMS association.
This file is part of circom (Zero Knowledge Circuit Compiler).
circom is a free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
circom is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
License for more details.
You should have received a copy of the GNU General Public License
along with circom. If not, see <https://www.gnu.org/licenses/>.
*/
include "../poseidon.circom";
/*
Hash1 = H(1 | key | value)
*/
template SMTHash1() {
signal input key;
signal input value;
signal output out;
component h = Poseidon(3, 6, 8, 57); // Constant
h.inputs[0] <== key;
h.inputs[1] <== value;
h.inputs[2] <== 1;
out <== h.out;
}
/*
This component is used to create the 2 nodes.
Hash2 = H(Hl | Hr)
*/
template SMTHash2() {
signal input L;
signal input R;
signal output out;
component h = Poseidon(2, 6, 8, 57); // Constant
h.inputs[0] <== L;
h.inputs[1] <== R;
out <== h.out;
}

+ 1
- 1
circuits/smt/smtprocessor.circom

@ -135,7 +135,7 @@ include "../switcher.circom";
include "smtlevins.circom"; include "smtlevins.circom";
include "smtprocessorlevel.circom"; include "smtprocessorlevel.circom";
include "smtprocessorsm.circom"; include "smtprocessorsm.circom";
include "smthash.circom";
include "smthash_poseidon.circom";
template SMTProcessor(nLevels) { template SMTProcessor(nLevels) {
signal input oldRoot; signal input oldRoot;

+ 1
- 1
circuits/smt/smtverifier.circom

@ -35,7 +35,7 @@ include "../switcher.circom";
include "smtlevins.circom"; include "smtlevins.circom";
include "smtverifierlevel.circom"; include "smtverifierlevel.circom";
include "smtverifiersm.circom"; include "smtverifiersm.circom";
include "smthash.circom";
include "smthash_poseidon.circom";
template SMTVerifier(nLevels) { template SMTVerifier(nLevels) {
signal input enabled; signal input enabled;

+ 15
- 15
src/smt.js

@ -1,7 +1,7 @@
const bigInt = require("snarkjs").bigInt; const bigInt = require("snarkjs").bigInt;
const SMTMemDB = require("./smt_memdb"); const SMTMemDB = require("./smt_memdb");
const mimc7 = require("./mimc7");
const {hash0, hash1} = require("./smt_hashes_poseidon");
class SMT { class SMT {
@ -46,8 +46,8 @@ class SMT {
const ins = []; const ins = [];
const dels = []; const dels = [];
let rtOld = mimc7.multiHash([key, resFind.foundValue], bigInt.one);
let rtNew = mimc7.multiHash([key, newValue], bigInt.one);
let rtOld = hash1(key, resFind.foundValue);
let rtNew = hash1(key, newValue);
ins.push([rtNew, [1, key, newValue ]]); ins.push([rtNew, [1, key, newValue ]]);
dels.push(rtOld); dels.push(rtOld);
@ -62,8 +62,8 @@ class SMT {
oldNode = [rtOld, sibling]; oldNode = [rtOld, sibling];
newNode = [rtNew, sibling]; newNode = [rtNew, sibling];
} }
rtOld = mimc7.multiHash(oldNode, bigInt.zero);
rtNew = mimc7.multiHash(newNode, bigInt.zero);
rtOld = hash0(oldNode[0], oldNode[1]);
rtNew = hash0(newNode[0], newNode[1]);
dels.push(rtOld); dels.push(rtOld);
ins.push([rtNew, newNode]); ins.push([rtNew, newNode]);
} }
@ -92,7 +92,7 @@ class SMT {
const dels = []; const dels = [];
const ins = []; const ins = [];
let rtOld = mimc7.multiHash([key, resFind.foundValue], bigInt.one);
let rtOld = hash1(key, resFind.foundValue);
let rtNew; let rtNew;
dels.push(rtOld); dels.push(rtOld);
@ -130,9 +130,9 @@ class SMT {
} }
const oldSibling = resFind.siblings[level]; const oldSibling = resFind.siblings[level];
if (keyBits[level]) { if (keyBits[level]) {
rtOld = mimc7.multiHash([oldSibling, rtOld], bigInt.zero);
rtOld = hash0(oldSibling, rtOld);
} else { } else {
rtOld = mimc7.multiHash([rtOld, oldSibling], bigInt.zero);
rtOld = hash0(rtOld, oldSibling);
} }
dels.push(rtOld); dels.push(rtOld);
if (!newSibling.isZero()) { if (!newSibling.isZero()) {
@ -147,7 +147,7 @@ class SMT {
} else { } else {
newNode = [rtNew, newSibling]; newNode = [rtNew, newSibling];
} }
rtNew = mimc7.multiHash(newNode, bigInt.zero);
rtNew = hash0(newNode[0], newNode[1]);
ins.push([rtNew, newNode]); ins.push([rtNew, newNode]);
} }
} }
@ -185,7 +185,7 @@ class SMT {
for (let i= res.siblings.length; oldKeyits[i] == newKeyBits[i]; i++) { for (let i= res.siblings.length; oldKeyits[i] == newKeyBits[i]; i++) {
res.siblings.push(bigInt.zero); res.siblings.push(bigInt.zero);
} }
rtOld = mimc7.multiHash([resFind.notFoundKey, resFind.notFoundValue], bigInt.one);
rtOld = hash1(resFind.notFoundKey, resFind.notFoundValue);
res.siblings.push(rtOld); res.siblings.push(rtOld);
addedOne = true; addedOne = true;
mixed = false; mixed = false;
@ -197,7 +197,7 @@ class SMT {
const inserts = []; const inserts = [];
const dels = []; const dels = [];
let rt = mimc7.multiHash([key, value], bigInt.one);
let rt = hash1(key, value);
inserts.push([rt,[1, key, value]] ); inserts.push([rt,[1, key, value]] );
for (let i=res.siblings.length-1; i>=0; i--) { for (let i=res.siblings.length-1; i>=0; i--) {
@ -207,9 +207,9 @@ class SMT {
if (mixed) { if (mixed) {
const oldSibling = resFind.siblings[i]; const oldSibling = resFind.siblings[i];
if (newKeyBits[i]) { if (newKeyBits[i]) {
rtOld = mimc7.multiHash([oldSibling, rtOld], bigInt.zero);
rtOld = hash0(oldSibling, rtOld);
} else { } else {
rtOld = mimc7.multiHash([rtOld, oldSibling], bigInt.zero);
rtOld = hash0(rtOld, oldSibling);
} }
dels.push(rtOld); dels.push(rtOld);
} }
@ -217,10 +217,10 @@ class SMT {
let newRt; let newRt;
if (newKeyBits[i]) { if (newKeyBits[i]) {
newRt = mimc7.multiHash([res.siblings[i], rt], bigInt.zero);
newRt = hash0(res.siblings[i], rt);
inserts.push([newRt,[res.siblings[i], rt]] ); inserts.push([newRt,[res.siblings[i], rt]] );
} else { } else {
newRt = mimc7.multiHash([rt, res.siblings[i]], bigInt.zero);
newRt = hash0(rt, res.siblings[i]);
inserts.push([newRt,[rt, res.siblings[i]]] ); inserts.push([newRt,[rt, res.siblings[i]]] );
} }
rt = newRt; rt = newRt;

+ 10
- 0
src/smt_hashes_mimc.js

@ -0,0 +1,10 @@
const mimc7 = require("./mimc7");
const bigInt = require("snarkjs").bigInt;
exports.hash0 = function (left, right) {
return mimc7.multiHash(left, right);
};
exports.hash1 = function(key, value) {
return mimc7.multiHash([key, value], bigInt.one);
};

+ 12
- 0
src/smt_hashes_poseidon.js

@ -0,0 +1,12 @@
const Poseidon = require("./poseidon");
const bigInt = require("snarkjs").bigInt;
const hash = Poseidon.createHash(6, 8, 57);
exports.hash0 = function (left, right) {
return hash([left, right]);
};
exports.hash1 = function(key, value) {
return hash([key, value, bigInt.one]);
};

+ 2
- 2
test/poseidoncontract.js

@ -46,12 +46,12 @@ describe("Poseidon Smart contract test", () => {
const res = await mimc.methods.poseidon([1,2]).call(); const res = await mimc.methods.poseidon([1,2]).call();
console.log("Cir: " + bigInt(res.toString(16)).toString(16));
// console.log("Cir: " + bigInt(res.toString(16)).toString(16));
const hash = Poseidon.createHash(6, 8, 57); const hash = Poseidon.createHash(6, 8, 57);
const res2 = hash([1,2]); const res2 = hash([1,2]);
console.log("Ref: " + bigInt(res2).toString(16));
// console.log("Ref: " + bigInt(res2).toString(16));
assert.equal(res.toString(), res2.toString()); assert.equal(res.toString(), res2.toString());
}); });

+ 1
- 1
test/smtprocessor.js

@ -84,7 +84,7 @@ describe("SMT test", function () {
let circuit; let circuit;
let tree; let tree;
this.timeout(100000);
this.timeout(10000000);
before( async () => { before( async () => {
const cirDef = await compiler(path.join(__dirname, "circuits", "smtprocessor10_test.circom")); const cirDef = await compiler(path.join(__dirname, "circuits", "smtprocessor10_test.circom"));

Loading…
Cancel
Save