diff --git a/c/zqfield.cpp b/c/zqfield.cpp index 764337e..0a95524 100644 --- a/c/zqfield.cpp +++ b/c/zqfield.cpp @@ -33,6 +33,51 @@ void ZqField::lt(PBigInt r, PBigInt a, PBigInt b) { } } +void ZqField::eq(PBigInt r, PBigInt a, PBigInt b) { + int c = mpz_cmp(*a, *b); + if (c==0) { + mpz_set(*r, one); + } else { + mpz_set(*r, zero); + } +} + +void ZqField::gt(PBigInt r, PBigInt a, PBigInt b) { + int c = mpz_cmp(*a, *b); + if (c>0) { + mpz_set(*r, one); + } else { + mpz_set(*r, zero); + } +} + +void ZqField::leq(PBigInt r, PBigInt a, PBigInt b) { + int c = mpz_cmp(*a, *b); + if (c<=0) { + mpz_set(*r, one); + } else { + mpz_set(*r, zero); + } +} + +void ZqField::geq(PBigInt r, PBigInt a, PBigInt b) { + int c = mpz_cmp(*a, *b); + if (c>=0) { + mpz_set(*r, one); + } else { + mpz_set(*r, zero); + } +} + +void ZqField::neq(PBigInt r, PBigInt a, PBigInt b) { + int c = mpz_cmp(*a, *b); + if (c!=0) { + mpz_set(*r, one); + } else { + mpz_set(*r, zero); + } +} + int ZqField::isTrue(PBigInt a) { return mpz_sgn(*a); } diff --git a/c/zqfield.h b/c/zqfield.h index fcb8e7b..ee67ac0 100644 --- a/c/zqfield.h +++ b/c/zqfield.h @@ -17,6 +17,11 @@ public: void add(PBigInt r,PBigInt a, PBigInt b); void mul(PBigInt r,PBigInt a, PBigInt b); void lt(PBigInt r, PBigInt a, PBigInt b); + void eq(PBigInt r, PBigInt a, PBigInt b); + void gt(PBigInt r, PBigInt a, PBigInt b); + void leq(PBigInt r, PBigInt a, PBigInt b); + void geq(PBigInt r, PBigInt a, PBigInt b); + void neq(PBigInt r, PBigInt a, PBigInt b); int isTrue(PBigInt a); }; diff --git a/src/c_gen.js b/src/c_gen.js index 3dfe055..c948a86 100644 --- a/src/c_gen.js +++ b/src/c_gen.js @@ -190,15 +190,15 @@ function gen(ctx, ast) { } else if (ast.op == "<") { return genBinaryOp(ctx, ast, "lt"); } else if (ast.op == ">") { - return genGt(ctx, ast); + return genBinaryOp(ctx, ast, "gt"); } else if (ast.op == "<=") { - return genLte(ctx, ast); + return genBinaryOp(ctx, ast, "leq"); } else if (ast.op == ">=") { - return genGte(ctx, ast); + return genBinaryOp(ctx, ast, "geq"); } else if (ast.op == "==") { - return genEq(ctx, ast); + return genBinaryOp(ctx, ast, "eq"); } else if (ast.op == "!=") { - return genNeq(ctx, ast); + return genBinaryOp(ctx, ast, "neq"); } else if (ast.op == "?") { return genTerCon(ctx, ast); } else { @@ -282,6 +282,11 @@ function genForSrcComment(ctx, ast) { ctx.code += `\n/* for (${init},${condition},${step}) */\n`; } +function genIfSrcComment(ctx, ast) { + const condition = getSource(ctx, ast.condition); + ctx.code += `\n/* if (${condition}) */\n`; +} + function genDeclareComponent(ctx, ast) { return ast.refId; @@ -835,6 +840,8 @@ function genFor(ctx, ast) { const condRef = gen(ctx, ast.condition); + if (ctx.error) return; + const cond = ctx.refs[condRef]; if (!utils.sameSizes(cond.sizes, [1,0])) return ctx.throwError(ast.condition, "Operation cannot be done on an array"); @@ -867,6 +874,8 @@ function genFor(ctx, ast) { if (ctx.error) return; const condRef2 = gen(ctx, ast.condition); + if (ctx.error) return; + const cond2 = ctx.refs[condRef2]; if (!inLoop) { @@ -917,16 +926,45 @@ function genCompute(ctx, ast) { } function genIf(ctx, ast) { - const condition = gen(ctx, ast.condition); - if (ctx.error) return; - const thenBody = gen(ctx, ast.then); + genIfSrcComment(ctx, ast); + const condRef = gen(ctx, ast.condition); if (ctx.error) return; - if (ast.else) { - const elseBody = gen(ctx, ast.else); + const cond = ctx.refs[condRef]; + if (!utils.sameSizes(cond.sizes, [1,0])) return ctx.throwError(ast.condition, "Operation cannot be done on an array"); + + if (cond.used) { + enterConditionalCode(ctx, ast); + + ctx.code += `if (ctx->field->isTrue(${cond.label})) {\n`; + + const oldCode = ctx.code; + ctx.code = ""; + + gen(ctx, ast.then); if (ctx.error) return; - return `if (bigInt(${condition}).neq(bigInt(0))) {\n${thenBody}\n} else {\n${elseBody}\n}\n`; + + ctx.code = oldCode + utils.ident(ctx.code); + + if (ast.else) { + ctx.code += "} else {\n"; + const oldCode = ctx.code; + ctx.code = ""; + gen(ctx, ast.else); + if (ctx.error) return; + ctx.code = oldCode + utils.ident(ctx.code); + } + + ctx.code += "}\n"; + } else { - return `if (bigInt(${condition}).neq(bigInt(0))) {\n${thenBody}\n}\n`; + if (!utils.isDefined(cond.value)) return ctx.throwError(ast, "condition value not assigned"); + if (!cond.value[0].isZero()) { + gen(ctx, ast.then); + } else { + if (ast.else) { + gen(ctx, ast.else); + } + } } } diff --git a/src/zqfield.js b/src/zqfield.js index fac7f4c..24c373f 100644 --- a/src/zqfield.js +++ b/src/zqfield.js @@ -17,5 +17,25 @@ module.exports = class ZqField { return a.lt(b) ? bigInt(1) : bigInt(0); } + eq(a, b) { + return a.eq(b) ? bigInt(1) : bigInt(0); + } + + gt(a, b) { + return a.gt(b) ? bigInt(1) : bigInt(0); + } + + leq(a, b) { + return a.leq(b) ? bigInt(1) : bigInt(0); + } + + geq(a, b) { + return a.geq(b) ? bigInt(1) : bigInt(0); + } + + neq(a, b) { + return a.neq(b) ? bigInt(1) : bigInt(0); + } + }; diff --git a/test/basiccases.js b/test/basiccases.js index c19c4ea..6d48d0b 100644 --- a/test/basiccases.js +++ b/test/basiccases.js @@ -107,4 +107,26 @@ describe("basic cases", function () { ] ); }); + it("if unrolled", async () => { + await doTest( + "ifunrolled.circom", + [ + [{in: 0}, {out: [1, 3, 6]}], + [{in: 10}, {out: [11, 13, 16]}], + [{in: __P__.minus(2)}, {out: [__P__.minus(1), 1, 4]}], + ] + ); + }); + it("if rolled", async () => { + await doTest( + "ifrolled.circom", + [ + [{in: 0}, {out: [1, 0, 0]}], + [{in: 1}, {out: [0, 1, 0]}], + [{in: 2}, {out: [0, 0, 1]}], + [{in: 3}, {out: [0, 0, 0]}], + [{in: __P__.minus(2)}, {out: [0,0,0]}], + ] + ); + }); }); diff --git a/test/circuits/ifrolled.circom b/test/circuits/ifrolled.circom new file mode 100644 index 0000000..5f4b2aa --- /dev/null +++ b/test/circuits/ifrolled.circom @@ -0,0 +1,26 @@ +template Main() { + signal input in; + signal output out[3]; + + if (in == 0) { + out[0] <-- 1; // TRUE + } + if (in != 0) { + out[0] <-- 0; + } + + if (in == 1) { + out[1] <-- 1; // TRUE + } else { + out[1] <-- 0; + } + + if (in == 2) { + out[2] <-- 1; + } else { + out[2] <-- 0; // TRUE + } + +} + +component main = Main(); diff --git a/test/circuits/ifunrolled.circom b/test/circuits/ifunrolled.circom new file mode 100644 index 0000000..00903ea --- /dev/null +++ b/test/circuits/ifunrolled.circom @@ -0,0 +1,31 @@ + + +template Main() { + signal input in; + signal output out[3]; + + var c = 1; + if (c == 1) { + out[0] <== in +1; // TRUE + } + if (c == 0) { + out[0] <== in +2; + } + + c = c +1; + if (c == 2) { + out[1] <== in + 3; // TRUE + } else { + out[1] <== in + 4; + } + + c = c +1; + if (c == 2) { + out[2] <== in + 5; + } else { + out[2] <== in + 6; // TRUE + } + +} + +component main = Main();