From eaf4396cb315c55020488ffda759e33018df858c Mon Sep 17 00:00:00 2001 From: Jordi Baylina Date: Sat, 7 Dec 2019 21:47:00 +0100 Subject: [PATCH] div operators --- c/zqfield.cpp | 29 ++++++++++++- c/zqfield.h | 4 ++ src/c_gen.js | 90 ++++++++++++++------------------------- src/zqfield.js | 16 +++++++ test/basiccases.js | 77 ++++++++++++++++++++++++++++++++- test/circuits/dec.circom | 23 ++++++++++ test/circuits/inc.circom | 24 +++++++++++ test/circuits/ops.circom | 12 ++++++ test/circuits/ops2.circom | 12 ++++++ 9 files changed, 224 insertions(+), 63 deletions(-) create mode 100644 test/circuits/dec.circom create mode 100644 test/circuits/inc.circom create mode 100644 test/circuits/ops.circom create mode 100644 test/circuits/ops2.circom diff --git a/c/zqfield.cpp b/c/zqfield.cpp index 0a95524..c779b23 100644 --- a/c/zqfield.cpp +++ b/c/zqfield.cpp @@ -15,8 +15,19 @@ ZqField::~ZqField() { } void ZqField::add(PBigInt r, PBigInt a, PBigInt b) { - mpz_add(tmp,*a,*b); - mpz_fdiv_r(*r, tmp, p); + mpz_add(*r,*a,*b); + if (mpz_cmp(*r, p) >= 0) { + mpz_sub(*r, *r, p); + } +} + +void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) { + if (mpz_cmp(*a, *b) >= 0) { + mpz_sub(*r, *a, *b); + } else { + mpz_sub(tmp, *b, *a); + mpz_sub(*r, p, tmp); + } } void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) { @@ -24,6 +35,20 @@ void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) { mpz_fdiv_r(*r, tmp, p); } +void ZqField::div(PBigInt r, PBigInt a, PBigInt b) { + mpz_invert(tmp, *b, p); + mpz_mul(tmp,*a,tmp); + mpz_fdiv_r(*r, tmp, p); +} + +void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) { + mpz_fdiv_q(*r, *a, *b); +} + +void ZqField::mod(PBigInt r, PBigInt a, PBigInt b) { + mpz_fdiv_r(*r, *a, *b); +} + void ZqField::lt(PBigInt r, PBigInt a, PBigInt b) { int c = mpz_cmp(*a, *b); if (c<0) { diff --git a/c/zqfield.h b/c/zqfield.h index ee67ac0..ca11d0b 100644 --- a/c/zqfield.h +++ b/c/zqfield.h @@ -15,7 +15,11 @@ public: void copyn(PBigInt a, PBigInt b, int n); void add(PBigInt r,PBigInt a, PBigInt b); + void sub(PBigInt r,PBigInt a, PBigInt b); void mul(PBigInt r,PBigInt a, PBigInt b); + void div(PBigInt r,PBigInt a, PBigInt b); + void idiv(PBigInt r,PBigInt a, PBigInt b); + void mod(PBigInt r,PBigInt a, PBigInt b); void lt(PBigInt r, PBigInt a, PBigInt b); void eq(PBigInt r, PBigInt a, PBigInt b); void gt(PBigInt r, PBigInt a, PBigInt b); diff --git a/src/c_gen.js b/src/c_gen.js index c948a86..b62e3ad 100644 --- a/src/c_gen.js +++ b/src/c_gen.js @@ -156,27 +156,27 @@ function gen(ctx, ast) { } else if (ast.op == "+") { return genBinaryOp(ctx, ast, "add"); } else if (ast.op == "-") { - return genSub(ctx, ast); + return genBinaryOp(ctx, ast, "sub"); } else if (ast.op == "UMINUS") { return genUMinus(ctx, ast); } else if (ast.op == "*") { return genBinaryOp(ctx, ast, "mul"); } else if (ast.op == "%") { - return genMod(ctx, ast); + return genBinaryOp(ctx, ast, "mod"); } else if (ast.op == "PLUSPLUSRIGHT") { - return genPlusPlusRight(ctx, ast); + return genOpOp(ctx, ast, "add", "RIGHT"); } else if (ast.op == "PLUSPLUSLEFT") { - return genPlusPlusLeft(ctx, ast); + return genOpOp(ctx, ast, "add", "LEFT"); } else if (ast.op == "MINUSMINUSRIGHT") { - return genMinusMinusRight(ctx, ast); + return genOpOp(ctx, ast, "sub", "RIGHT"); } else if (ast.op == "MINUSMINUSLEFT") { - return genMinusMinusLeft(ctx, ast); + return genOpOp(ctx, ast, "sub", "LEFT"); } else if (ast.op == "**") { return genExp(ctx, ast); } else if (ast.op == "/") { - return genDiv(ctx, ast); + return genBinaryOp(ctx, ast, "div"); } else if (ast.op == "\\") { - return genIDiv(ctx, ast); + return genBinaryOp(ctx, ast, "idiv"); } else if (ast.op == "&") { return genBAnd(ctx, ast); } else if (ast.op == "&&") { @@ -1009,20 +1009,32 @@ function genVarMulAssignement(ctx, ast) { return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "*", values: ast.values}]}); } -function genPlusPlusRight(ctx, ast) { - return `(${genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "+", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]})}).add(__P__).sub(bigInt(1)).mod(__P__)`; -} +function genOpOp(ctx, ast, op, lr) { -function genPlusPlusLeft(ctx, ast) { - return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "+", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]}); -} + if (ast.values[0].type != "VARIABLE") return ctx.throwError(ast, "incrementing a non variable"); -function genMinusMinusRight(ctx, ast) { - return `(${genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "-", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]})}).add(__P__).sub(bigInt(1)).mod(__P__)`; -} + const vRef = ast.values[0].refId; + + const vevalRef = gen(ctx, ast.values[0]); + if (ctx.error) return; + const veval = ctx.refs[vevalRef]; + + if (veval.type != "BIGINT") return ctx.throwError(ast, "incrementing a non variable"); -function genMinusMinusLeft(ctx, ast) { - return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "-", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]}); + const resRef = newRef(ctx, "BIGINT", "_tmp"); + const res = ctx.refs[resRef]; + if (veval.used) { + instantiateRef(ctx, resRef); + ctx.code += `ctx->field->${op}(${res.label}, ${veval.label}, &(ctx->field->one));\n`; + } else { + res.value = [ctx.field[op](veval.value[0], bigInt(1))]; + } + genVarAssignment(ctx, ast, vRef, ast.values[0].selectors, resRef); + if (lr == "RIGHT") { + return vevalRef; + } else if (lr == "LEFT") { + return resRef; + } } function genBinaryOp(ctx, ast, op) { @@ -1143,46 +1155,6 @@ function genMod(ctx, ast) { return `bigInt(${a}).mod(bigInt(${b}))`; } -function genGt(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${a}).gt(bigInt(${b})) ? 1 : 0`; -} - -function genLte(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${a}).lesserOrEquals(bigInt(${b})) ? 1 : 0`; -} - -function genGte(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `bigInt(${a}).greaterOrEquals(bigInt(${b})) ? 1 : 0`; -} - -function genEq(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `(bigInt(${a}).eq(bigInt(${b})) ? 1 : 0)`; -} - -function genNeq(ctx, ast) { - const a = gen(ctx, ast.values[0]); - if (ctx.error) return; - const b = gen(ctx, ast.values[1]); - if (ctx.error) return; - return `(bigInt(${a}).eq(bigInt(${b})) ? 0 : 1)`; -} - function genUMinus(ctx, ast) { const a = gen(ctx, ast.values[0]); if (ctx.error) return; diff --git a/src/zqfield.js b/src/zqfield.js index 24c373f..9100791 100644 --- a/src/zqfield.js +++ b/src/zqfield.js @@ -9,6 +9,10 @@ module.exports = class ZqField { return a.add(b).mod(this.p); } + sub(a, b) { + return a.minus(b).mod(this.p); + } + mul(a, b) { return a.mul(b).mod(this.p); } @@ -37,5 +41,17 @@ module.exports = class ZqField { return a.neq(b) ? bigInt(1) : bigInt(0); } + div(a, b) { + return a.mul(b.modInv(this.p)).mod(this.p); + } + + idiv(a, b) { + return a.divide(b); + } + + mod(a, b) { + return a.mod(b); + } + }; diff --git a/test/basiccases.js b/test/basiccases.js index 6d48d0b..0d652db 100644 --- a/test/basiccases.js +++ b/test/basiccases.js @@ -6,12 +6,37 @@ const c_tester = require("../index.js").c_tester; const __P__ = new bigInt("21888242871839275222246405745257275088548364400416034343698204186575808495617"); +function normalize(o) { + if ((typeof(o) == "bigint") || o.isZero !== undefined) { + const res = bigInt(o); + return norm(res); + } else if (Array.isArray(o)) { + return o.map(normalize); + } else if (typeof o == "object") { + const res = {}; + for (let k in o) { + res[k] = normalize(o[k]); + } + return res; + } else { + const res = bigInt(o); + return norm(res); + } + + function norm(n) { + let res = n.mod(__P__); + if (res.isNegative()) res = __P__.add(res); + return res; + } +} + + async function doTest(circuit, testVectors) { const cir = await c_tester(path.join(__dirname, "circuits", circuit)); for (let i=0; i { + await doTest( + "inc.circom", + [ + [{in: 0}, {out: [5, 2]}], + [{in: 1}, {out: [6, 4]}], + [{in: 2}, {out: [7, 6]}], + [{in: 3}, {out: [8, 8]}], + [{in: __P__.minus(2)}, {out: [3,__P__.minus(2)]}], + ] + ); + }); + it("dec", async () => { + await doTest( + "dec.circom", + [ + [{in: 0}, {out: [1, __P__.minus(2)]}], + [{in: 1}, {out: [2, 0]}], + [{in: 2}, {out: [3, 2]}], + [{in: 3}, {out: [4, 4]}], + [{in: __P__.minus(2)}, {out: [__P__.minus(1),__P__.minus(6)]}], + ] + ); + }); + it("ops", async () => { + await doTest( + "ops.circom", + [ + [{in: [-2, 2]}, {add: 0, sub: -4, mul: -4}], + [{in: [-1, 1]}, {add: 0, sub: -2, mul: -1}], + [{in: [ 0, 0]}, {add: 0, sub: 0, mul: 0}], + [{in: [ 1,-1]}, {add: 0, sub: 2, mul: -1}], + [{in: [ 2,-2]}, {add: 0, sub: 4, mul: -4}], + [{in: [-2,-3]}, {add: -5, sub: 1, mul: 6}], + [{in: [ 2, 3]}, {add: 5, sub: -1, mul: 6}], + ] + ); + }); + it("ops2", async () => { + await doTest( + "ops2.circom", + [ + [{in: [-2, 2]}, {div: -1, idiv: bigInt("10944121435919637611123202872628637544274182200208017171849102093287904247807"), mod: 1}], + [{in: [-1, 1]}, {div: -1, idiv: -1, mod: 0}], + [{in: [ 1,-1]}, {div: -1, idiv: 0, mod: 1}], + ] + ); + }); }); diff --git a/test/circuits/dec.circom b/test/circuits/dec.circom new file mode 100644 index 0000000..14c105c --- /dev/null +++ b/test/circuits/dec.circom @@ -0,0 +1,23 @@ +template Main() { + signal input in; + signal output out[2]; + +// First play with variables; + + var c = 3; + var d = c--; // d --> 3 + var e = --c; // e --> 1 + + out[0] <== in + e; // in + 1 + +// Then play with signals + + c = in; + d = c--; //d <-- in; + e = --c; // d <-- in-2 + + out[1] <== in + e; // 2*in -2 + +} + +component main = Main(); diff --git a/test/circuits/inc.circom b/test/circuits/inc.circom new file mode 100644 index 0000000..5c11622 --- /dev/null +++ b/test/circuits/inc.circom @@ -0,0 +1,24 @@ +template Main() { + signal input in; + signal output out[2]; + +// First play with variables; + + var c = 3; + var d = c++; // d --> 3 + var e = ++c; // e --> 5 + + out[0] <== in + e; // in + 5 + +// Then play with signals + + c = in; + d = c++; //d <-- in; + e = ++c; // d <-- in+2 + + out[1] <== in + e; // 2*in +2 + +} + +component main = Main(); + diff --git a/test/circuits/ops.circom b/test/circuits/ops.circom new file mode 100644 index 0000000..e71b614 --- /dev/null +++ b/test/circuits/ops.circom @@ -0,0 +1,12 @@ +template Ops() { + signal input in[2]; + signal output add; + signal output sub; + signal output mul; + + add <-- in[0] + in[1]; + sub <-- in[0] - in[1]; + mul <-- in[0] * in[1]; +} + +component main = Ops(); diff --git a/test/circuits/ops2.circom b/test/circuits/ops2.circom new file mode 100644 index 0000000..3f0a37a --- /dev/null +++ b/test/circuits/ops2.circom @@ -0,0 +1,12 @@ +template Ops2() { + signal input in[2]; + signal output div; + signal output idiv; + signal output mod; + + div <-- in[0] / in[1]; + idiv <-- in[0] \ in[1]; + mod <-- in[0] % in[1]; +} + +component main = Ops2();