From afa8201c2c7d92ce8c89ca2f6ea876c67c7d2f3f Mon Sep 17 00:00:00 2001 From: Jordi Baylina Date: Sun, 8 Dec 2019 16:20:15 +0100 Subject: [PATCH] All operators finished --- c/main.cpp | 9 +- src/c_gen.js | 153 ++++++++++++++++------------- src/c_tester.js | 3 + src/zqfield.js | 1 + test/basiccases.js | 41 ++++++++ test/circuits/condternary.circom | 15 +++ test/circuits/whilerolled.circom | 16 +++ test/circuits/whileunrolled.circom | 12 +++ 8 files changed, 178 insertions(+), 72 deletions(-) create mode 100644 test/circuits/condternary.circom create mode 100644 test/circuits/whilerolled.circom create mode 100644 test/circuits/whileunrolled.circom diff --git a/c/main.cpp b/c/main.cpp index 1d9f700..26d3167 100644 --- a/c/main.cpp +++ b/c/main.cpp @@ -100,7 +100,14 @@ void loadJson(Circom_CalcWit *ctx, std::string filename) { for (json::iterator it = j.begin(); it != j.end(); ++it) { // std::cout << it.key() << " => " << it.value() << '\n'; u64 h = fnv1a(it.key()); - int o = ctx->getSignalOffset(0, h); + int o; + try { + o = ctx->getSignalOffset(0, h); + } catch (std::runtime_error e) { + std::ostringstream errStrStream; + errStrStream << "Error loadin variable: " << it.key() << "\n" << e.what(); + throw std::runtime_error(errStrStream.str() ); + } Circom_Sizes sizes = ctx->getSignalSizes(0, h); iterateArr(ctx, o, sizes, it.value(), itFunc); } diff --git a/src/c_gen.js b/src/c_gen.js index 67a2755..bb845de 100644 --- a/src/c_gen.js +++ b/src/c_gen.js @@ -229,11 +229,11 @@ function gen(ctx, ast) { } else if (ast.type == "BLOCK") { return genBlock(ctx, ast); } else if (ast.type == "COMPUTE") { - return genCompute(ctx, ast); + return gen(ctx, ast.body); } else if (ast.type == "FOR") { - return genFor(ctx, ast); + return genLoop(ctx, ast); } else if (ast.type == "WHILE") { - return genWhile(ctx, ast); + return genLoop(ctx, ast); } else if (ast.type == "IF") { return genIf(ctx, ast); } else if (ast.type == "RETURN") { @@ -283,11 +283,18 @@ function genSrcComment(ctx, ast) { ctx.code += "\n/* "+code+" */\n"; } -function genForSrcComment(ctx, ast) { - const init = getSource(ctx, ast.init); - const condition = getSource(ctx, ast.condition); - const step = getSource(ctx, ast.step); - ctx.code += `\n/* for (${init},${condition},${step}) */\n`; +function genLoopSrcComment(ctx, ast) { + if (ast.type == "FOR") { + const init = getSource(ctx, ast.init); + const condition = getSource(ctx, ast.condition); + const step = getSource(ctx, ast.step); + ctx.code += `\n/* for (${init},${condition},${step}) */\n`; + } else if (ast.type == "WHILE") { + const condition = getSource(ctx, ast.condition); + ctx.code += `\n/* while (${condition}) */\n`; + } else { + assert(false, "Invalid loop type: "+ ast.type); + } } function genIfSrcComment(ctx, ast) { @@ -835,12 +842,14 @@ function leaveConditionalCode(ctx) { } } -function genFor(ctx, ast) { - genForSrcComment(ctx, ast); +function genLoop(ctx, ast) { + genLoopSrcComment(ctx, ast); let inLoop = false; - ctx.scopes.push({}); - gen(ctx, ast.init); - if (ctx.error) return; + + if (ast.init) { + gen(ctx, ast.init); + if (ctx.error) return; + } let end=false; let condVarRef; @@ -878,8 +887,10 @@ function genFor(ctx, ast) { gen(ctx, ast.body); if (ctx.error) return; - gen(ctx, ast.step); - if (ctx.error) return; + if (ast.step) { + gen(ctx, ast.step); + if (ctx.error) return; + } const condRef2 = gen(ctx, ast.condition); if (ctx.error) return; @@ -919,20 +930,6 @@ function genFor(ctx, ast) { ctx.scopes.pop(); } -function genWhile(ctx, ast) { - const condition = gen(ctx, ast.condition); - if (ctx.error) return; - const body = gen(ctx, ast.body); - if (ctx.error) return; - return `while (bigInt(${condition}).neq(bigInt(0))) {\n${body}\n}\n`; -} - -function genCompute(ctx, ast) { - const body = gen(ctx, ast.body); - if (ctx.error) return; - return `{\n${body}\n}\n`; -} - function genIf(ctx, ast) { genIfSrcComment(ctx, ast); const condRef = gen(ctx, ast.condition); @@ -1095,54 +1092,68 @@ function genOp(ctx, ast, op, nOps) { return rRef; } -function genBAnd(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${a}).and(bigInt(${b})).and(__MASK__)`; -} -function genAnd(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); +function genTerCon(ctx, ast) { + const condRef = gen(ctx, ast.values[0]); if (ctx.error) return; - return `((bigInt(${a}).neq(bigInt(0)) && bigInt(${b}).neq(bigInt(0))) ? bigInt(1) : bigInt(0))`; -} + const cond = ctx.refs[condRef]; + if (!utils.sameSizes(cond.sizes, [1,0])) return ctx.throwError(ast.condition, "Operation cannot be done on an array"); -function genOr(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `((bigInt(${a}).neq(bigInt(0)) || bigInt(${b}).neq(bigInt(0))) ? bigInt(1) : bigInt(0))`; -} + let oldCode; + if (cond.used) { + enterConditionalCode(ctx, ast); -function genShl(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${b}).greater(bigInt(256)) ? 0 : bigInt(${a}).shl(bigInt(${b})).and(__MASK__)`; -} + const rLabel = ctx.getUniqueName("_ter"); -function genShr(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${b}).greater(bigInt(256)) ? 0 : bigInt(${a}).shr(bigInt(${b})).and(__MASK__)`; -} + ctx.codeHeader += `PBigInt ${rLabel};\n`; + ctx.code += `if (ctx->field->isTrue(${cond.label})) {\n`; -function genTerCon(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - const c = gen(ctx, ast.values[2]); - if (ctx.error) return; - return `bigInt(${a}).neq(bigInt(0)) ? (${b}) : (${c})`; + oldCode = ctx.code; + ctx.code = ""; + + const thenRef = gen(ctx, ast.values[1]); + if (ctx.error) return; + const then = ctx.refs[thenRef]; + + ctx.code = oldCode + utils.ident(ctx.code); + + ctx.code += `${rLabel} = ${then.label};\n`; + + ctx.code += "} else {\n"; + + oldCode = ctx.code; + ctx.code = ""; + const elseRef = gen(ctx, ast.values[2]); + if (ctx.error) return; + const els = ctx.refs[elseRef]; + + ctx.code = oldCode + utils.ident(ctx.code); + + ctx.code += `${rLabel} = ${els.label};\n`; + + ctx.code += "}\n"; + + + if (!utils.sameSizes(then.sizes, els.sizes)) return ctx.throwError(ast, "Ternary op must return the same sizes"); + + const refId = ctx.refs.length; + ctx.refs.push({ + type: "BIGINT", + sizes: then.sizes, + used: true, + label: rLabel + }); + + return refId; + + } else { + if (!utils.isDefined(cond.value)) return ctx.throwError(ast, "condition value not assigned"); + if (!cond.value[0].isZero()) { + return gen(ctx, ast.values[1]); + } else { + return gen(ctx, ast.values[2]); + } + } } function genInclude(ctx, ast) { diff --git a/src/c_tester.js b/src/c_tester.js index e8dc1d8..1f85b58 100644 --- a/src/c_tester.js +++ b/src/c_tester.js @@ -108,6 +108,9 @@ class CTester { checkObject(prefix + "."+k, eOut[k]); } } else { + if (typeof self.symbols[prefix] == "undefined") { + assert(false, "Output variable not defined: "+ prefix); + } const ba = bigInt(actualOut[self.symbols[prefix].idxWit]).toString(); const be = bigInt(eOut).toString(); assert.strictEqual(ba, be, prefix); diff --git a/src/zqfield.js b/src/zqfield.js index 86f64c1..565788f 100644 --- a/src/zqfield.js +++ b/src/zqfield.js @@ -58,6 +58,7 @@ module.exports = class ZqField { } div(a, b) { + assert(!b.isZero(), "Division by zero"); return a.mul(b.modInv(this.p)).mod(this.p); } diff --git a/test/basiccases.js b/test/basiccases.js index b79de94..3baeecf 100644 --- a/test/basiccases.js +++ b/test/basiccases.js @@ -92,6 +92,25 @@ describe("basic cases", function () { ] ); }); + it("while unrolled", async () => { + await doTest( + "whileunrolled.circom", + [ + [{in: 0}, {out: [0,1,2]}], + [{in: 10}, {out: [10, 11, 12]}], + [{in: __P__.minus(2)}, {out: [__P__.minus(2), __P__.minus(1), 0]}], + ] + ); + }); + it("while rolled", async () => { + await doTest( + "whilerolled.circom", + [ + [{in: 0}, {out: 0}], + [{in: 10}, {out: 10}], + ] + ); + }); it("function1", async () => { await doTest( "function1.circom", @@ -246,4 +265,26 @@ describe("basic cases", function () { ] ); }); + it("Conditional Ternary operator", async () => { + await doTest( + "condternary.circom", + [ + [{in: 0}, {out: 21}], + [{in: 1}, {out: 1}], + [{in: 2}, {out: 23}], + [{in:-1}, {out: 20}], + ] + ); + }); + it("Compute block", async () => { + await doTest( + "compute.circom", + [ + [{x: 1}, {y: 7}], + [{x: 2}, {y: 7}], + [{x: 3}, {y: 11}], + [{x:-1}, {y: -5}], + ] + ); + }); }); diff --git a/test/circuits/condternary.circom b/test/circuits/condternary.circom new file mode 100644 index 0000000..476783f --- /dev/null +++ b/test/circuits/condternary.circom @@ -0,0 +1,15 @@ +template CondTernary() { + signal input in; + signal output out; + + var a = 3; + var b = a==3 ? 1 : 2; // b is 1 + var c = a!=3 ? 10 : 20; // c is 20 + var d = b+c; // d is 21 + + + out <-- ((in & 1) != 1) ? in + d : in; // Add 21 if in is pair + +} + +component main = CondTernary() diff --git a/test/circuits/whilerolled.circom b/test/circuits/whilerolled.circom new file mode 100644 index 0000000..e8e7d6e --- /dev/null +++ b/test/circuits/whilerolled.circom @@ -0,0 +1,16 @@ +template WhileRolled() { + signal input in; + signal output out; + + var acc = 0; + + var i=0; + while (i