Browse Source

div operators

feature/witness_bin
Jordi Baylina 5 years ago
parent
commit
eaf4396cb3
No known key found for this signature in database GPG Key ID: 7480C80C1BE43112
9 changed files with 224 additions and 63 deletions
  1. +27
    -2
      c/zqfield.cpp
  2. +4
    -0
      c/zqfield.h
  3. +31
    -59
      src/c_gen.js
  4. +16
    -0
      src/zqfield.js
  5. +75
    -2
      test/basiccases.js
  6. +23
    -0
      test/circuits/dec.circom
  7. +24
    -0
      test/circuits/inc.circom
  8. +12
    -0
      test/circuits/ops.circom
  9. +12
    -0
      test/circuits/ops2.circom

+ 27
- 2
c/zqfield.cpp

@ -15,8 +15,19 @@ ZqField::~ZqField() {
} }
void ZqField::add(PBigInt r, PBigInt a, PBigInt b) { void ZqField::add(PBigInt r, PBigInt a, PBigInt b) {
mpz_add(tmp,*a,*b);
mpz_fdiv_r(*r, tmp, p);
mpz_add(*r,*a,*b);
if (mpz_cmp(*r, p) >= 0) {
mpz_sub(*r, *r, p);
}
}
void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_cmp(*a, *b) >= 0) {
mpz_sub(*r, *a, *b);
} else {
mpz_sub(tmp, *b, *a);
mpz_sub(*r, p, tmp);
}
} }
void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) { void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) {
@ -24,6 +35,20 @@ void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) {
mpz_fdiv_r(*r, tmp, p); mpz_fdiv_r(*r, tmp, p);
} }
void ZqField::div(PBigInt r, PBigInt a, PBigInt b) {
mpz_invert(tmp, *b, p);
mpz_mul(tmp,*a,tmp);
mpz_fdiv_r(*r, tmp, p);
}
void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) {
mpz_fdiv_q(*r, *a, *b);
}
void ZqField::mod(PBigInt r, PBigInt a, PBigInt b) {
mpz_fdiv_r(*r, *a, *b);
}
void ZqField::lt(PBigInt r, PBigInt a, PBigInt b) { void ZqField::lt(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b); int c = mpz_cmp(*a, *b);
if (c<0) { if (c<0) {

+ 4
- 0
c/zqfield.h

@ -15,7 +15,11 @@ public:
void copyn(PBigInt a, PBigInt b, int n); void copyn(PBigInt a, PBigInt b, int n);
void add(PBigInt r,PBigInt a, PBigInt b); void add(PBigInt r,PBigInt a, PBigInt b);
void sub(PBigInt r,PBigInt a, PBigInt b);
void mul(PBigInt r,PBigInt a, PBigInt b); void mul(PBigInt r,PBigInt a, PBigInt b);
void div(PBigInt r,PBigInt a, PBigInt b);
void idiv(PBigInt r,PBigInt a, PBigInt b);
void mod(PBigInt r,PBigInt a, PBigInt b);
void lt(PBigInt r, PBigInt a, PBigInt b); void lt(PBigInt r, PBigInt a, PBigInt b);
void eq(PBigInt r, PBigInt a, PBigInt b); void eq(PBigInt r, PBigInt a, PBigInt b);
void gt(PBigInt r, PBigInt a, PBigInt b); void gt(PBigInt r, PBigInt a, PBigInt b);

+ 31
- 59
src/c_gen.js

@ -156,27 +156,27 @@ function gen(ctx, ast) {
} else if (ast.op == "+") { } else if (ast.op == "+") {
return genBinaryOp(ctx, ast, "add"); return genBinaryOp(ctx, ast, "add");
} else if (ast.op == "-") { } else if (ast.op == "-") {
return genSub(ctx, ast);
return genBinaryOp(ctx, ast, "sub");
} else if (ast.op == "UMINUS") { } else if (ast.op == "UMINUS") {
return genUMinus(ctx, ast); return genUMinus(ctx, ast);
} else if (ast.op == "*") { } else if (ast.op == "*") {
return genBinaryOp(ctx, ast, "mul"); return genBinaryOp(ctx, ast, "mul");
} else if (ast.op == "%") { } else if (ast.op == "%") {
return genMod(ctx, ast);
return genBinaryOp(ctx, ast, "mod");
} else if (ast.op == "PLUSPLUSRIGHT") { } else if (ast.op == "PLUSPLUSRIGHT") {
return genPlusPlusRight(ctx, ast);
return genOpOp(ctx, ast, "add", "RIGHT");
} else if (ast.op == "PLUSPLUSLEFT") { } else if (ast.op == "PLUSPLUSLEFT") {
return genPlusPlusLeft(ctx, ast);
return genOpOp(ctx, ast, "add", "LEFT");
} else if (ast.op == "MINUSMINUSRIGHT") { } else if (ast.op == "MINUSMINUSRIGHT") {
return genMinusMinusRight(ctx, ast);
return genOpOp(ctx, ast, "sub", "RIGHT");
} else if (ast.op == "MINUSMINUSLEFT") { } else if (ast.op == "MINUSMINUSLEFT") {
return genMinusMinusLeft(ctx, ast);
return genOpOp(ctx, ast, "sub", "LEFT");
} else if (ast.op == "**") { } else if (ast.op == "**") {
return genExp(ctx, ast); return genExp(ctx, ast);
} else if (ast.op == "/") { } else if (ast.op == "/") {
return genDiv(ctx, ast);
return genBinaryOp(ctx, ast, "div");
} else if (ast.op == "\\") { } else if (ast.op == "\\") {
return genIDiv(ctx, ast);
return genBinaryOp(ctx, ast, "idiv");
} else if (ast.op == "&") { } else if (ast.op == "&") {
return genBAnd(ctx, ast); return genBAnd(ctx, ast);
} else if (ast.op == "&&") { } else if (ast.op == "&&") {
@ -1009,20 +1009,32 @@ function genVarMulAssignement(ctx, ast) {
return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "*", values: ast.values}]}); return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "*", values: ast.values}]});
} }
function genPlusPlusRight(ctx, ast) {
return `(${genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "+", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]})}).add(__P__).sub(bigInt(1)).mod(__P__)`;
}
function genOpOp(ctx, ast, op, lr) {
function genPlusPlusLeft(ctx, ast) {
return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "+", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]});
}
if (ast.values[0].type != "VARIABLE") return ctx.throwError(ast, "incrementing a non variable");
function genMinusMinusRight(ctx, ast) {
return `(${genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "-", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]})}).add(__P__).sub(bigInt(1)).mod(__P__)`;
}
const vRef = ast.values[0].refId;
const vevalRef = gen(ctx, ast.values[0]);
if (ctx.error) return;
const veval = ctx.refs[vevalRef];
if (veval.type != "BIGINT") return ctx.throwError(ast, "incrementing a non variable");
function genMinusMinusLeft(ctx, ast) {
return genAssignement(ctx, {values: [ast.values[0], {type: "OP", op: "-", values: [ast.values[0], {type: "NUMBER", value: bigInt(1)}]}]});
const resRef = newRef(ctx, "BIGINT", "_tmp");
const res = ctx.refs[resRef];
if (veval.used) {
instantiateRef(ctx, resRef);
ctx.code += `ctx->field->${op}(${res.label}, ${veval.label}, &(ctx->field->one));\n`;
} else {
res.value = [ctx.field[op](veval.value[0], bigInt(1))];
}
genVarAssignment(ctx, ast, vRef, ast.values[0].selectors, resRef);
if (lr == "RIGHT") {
return vevalRef;
} else if (lr == "LEFT") {
return resRef;
}
} }
function genBinaryOp(ctx, ast, op) { function genBinaryOp(ctx, ast, op) {
@ -1143,46 +1155,6 @@ function genMod(ctx, ast) {
return `bigInt(${a}).mod(bigInt(${b}))`; return `bigInt(${a}).mod(bigInt(${b}))`;
} }
function genGt(ctx, ast) {
const a = gen(ctx, ast.values[0]);
if (ctx.error) return;
const b = gen(ctx, ast.values[1]);
if (ctx.error) return;
return `bigInt(${a}).gt(bigInt(${b})) ? 1 : 0`;
}
function genLte(ctx, ast) {
const a = gen(ctx, ast.values[0]);
if (ctx.error) return;
const b = gen(ctx, ast.values[1]);
if (ctx.error) return;
return `bigInt(${a}).lesserOrEquals(bigInt(${b})) ? 1 : 0`;
}
function genGte(ctx, ast) {
const a = gen(ctx, ast.values[0]);
if (ctx.error) return;
const b = gen(ctx, ast.values[1]);
if (ctx.error) return;
return `bigInt(${a}).greaterOrEquals(bigInt(${b})) ? 1 : 0`;
}
function genEq(ctx, ast) {
const a = gen(ctx, ast.values[0]);
if (ctx.error) return;
const b = gen(ctx, ast.values[1]);
if (ctx.error) return;
return `(bigInt(${a}).eq(bigInt(${b})) ? 1 : 0)`;
}
function genNeq(ctx, ast) {
const a = gen(ctx, ast.values[0]);
if (ctx.error) return;
const b = gen(ctx, ast.values[1]);
if (ctx.error) return;
return `(bigInt(${a}).eq(bigInt(${b})) ? 0 : 1)`;
}
function genUMinus(ctx, ast) { function genUMinus(ctx, ast) {
const a = gen(ctx, ast.values[0]); const a = gen(ctx, ast.values[0]);
if (ctx.error) return; if (ctx.error) return;

+ 16
- 0
src/zqfield.js

@ -9,6 +9,10 @@ module.exports = class ZqField {
return a.add(b).mod(this.p); return a.add(b).mod(this.p);
} }
sub(a, b) {
return a.minus(b).mod(this.p);
}
mul(a, b) { mul(a, b) {
return a.mul(b).mod(this.p); return a.mul(b).mod(this.p);
} }
@ -37,5 +41,17 @@ module.exports = class ZqField {
return a.neq(b) ? bigInt(1) : bigInt(0); return a.neq(b) ? bigInt(1) : bigInt(0);
} }
div(a, b) {
return a.mul(b.modInv(this.p)).mod(this.p);
}
idiv(a, b) {
return a.divide(b);
}
mod(a, b) {
return a.mod(b);
}
}; };

+ 75
- 2
test/basiccases.js

@ -6,12 +6,37 @@ const c_tester = require("../index.js").c_tester;
const __P__ = new bigInt("21888242871839275222246405745257275088548364400416034343698204186575808495617"); const __P__ = new bigInt("21888242871839275222246405745257275088548364400416034343698204186575808495617");
function normalize(o) {
if ((typeof(o) == "bigint") || o.isZero !== undefined) {
const res = bigInt(o);
return norm(res);
} else if (Array.isArray(o)) {
return o.map(normalize);
} else if (typeof o == "object") {
const res = {};
for (let k in o) {
res[k] = normalize(o[k]);
}
return res;
} else {
const res = bigInt(o);
return norm(res);
}
function norm(n) {
let res = n.mod(__P__);
if (res.isNegative()) res = __P__.add(res);
return res;
}
}
async function doTest(circuit, testVectors) { async function doTest(circuit, testVectors) {
const cir = await c_tester(path.join(__dirname, "circuits", circuit)); const cir = await c_tester(path.join(__dirname, "circuits", circuit));
for (let i=0; i<testVectors.length; i++) { for (let i=0; i<testVectors.length; i++) {
const w = await cir.calculateWitness(testVectors[i][0]);
await cir.assertOut(w, testVectors[i][1] );
const w = await cir.calculateWitness(normalize(testVectors[i][0]));
await cir.assertOut(w, normalize(testVectors[i][1]) );
} }
await cir.release(); await cir.release();
@ -129,4 +154,52 @@ describe("basic cases", function () {
] ]
); );
}); });
it("inc", async () => {
await doTest(
"inc.circom",
[
[{in: 0}, {out: [5, 2]}],
[{in: 1}, {out: [6, 4]}],
[{in: 2}, {out: [7, 6]}],
[{in: 3}, {out: [8, 8]}],
[{in: __P__.minus(2)}, {out: [3,__P__.minus(2)]}],
]
);
});
it("dec", async () => {
await doTest(
"dec.circom",
[
[{in: 0}, {out: [1, __P__.minus(2)]}],
[{in: 1}, {out: [2, 0]}],
[{in: 2}, {out: [3, 2]}],
[{in: 3}, {out: [4, 4]}],
[{in: __P__.minus(2)}, {out: [__P__.minus(1),__P__.minus(6)]}],
]
);
});
it("ops", async () => {
await doTest(
"ops.circom",
[
[{in: [-2, 2]}, {add: 0, sub: -4, mul: -4}],
[{in: [-1, 1]}, {add: 0, sub: -2, mul: -1}],
[{in: [ 0, 0]}, {add: 0, sub: 0, mul: 0}],
[{in: [ 1,-1]}, {add: 0, sub: 2, mul: -1}],
[{in: [ 2,-2]}, {add: 0, sub: 4, mul: -4}],
[{in: [-2,-3]}, {add: -5, sub: 1, mul: 6}],
[{in: [ 2, 3]}, {add: 5, sub: -1, mul: 6}],
]
);
});
it("ops2", async () => {
await doTest(
"ops2.circom",
[
[{in: [-2, 2]}, {div: -1, idiv: bigInt("10944121435919637611123202872628637544274182200208017171849102093287904247807"), mod: 1}],
[{in: [-1, 1]}, {div: -1, idiv: -1, mod: 0}],
[{in: [ 1,-1]}, {div: -1, idiv: 0, mod: 1}],
]
);
});
}); });

+ 23
- 0
test/circuits/dec.circom

@ -0,0 +1,23 @@
template Main() {
signal input in;
signal output out[2];
// First play with variables;
var c = 3;
var d = c--; // d --> 3
var e = --c; // e --> 1
out[0] <== in + e; // in + 1
// Then play with signals
c = in;
d = c--; //d <-- in;
e = --c; // d <-- in-2
out[1] <== in + e; // 2*in -2
}
component main = Main();

+ 24
- 0
test/circuits/inc.circom

@ -0,0 +1,24 @@
template Main() {
signal input in;
signal output out[2];
// First play with variables;
var c = 3;
var d = c++; // d --> 3
var e = ++c; // e --> 5
out[0] <== in + e; // in + 5
// Then play with signals
c = in;
d = c++; //d <-- in;
e = ++c; // d <-- in+2
out[1] <== in + e; // 2*in +2
}
component main = Main();

+ 12
- 0
test/circuits/ops.circom

@ -0,0 +1,12 @@
template Ops() {
signal input in[2];
signal output add;
signal output sub;
signal output mul;
add <-- in[0] + in[1];
sub <-- in[0] - in[1];
mul <-- in[0] * in[1];
}
component main = Ops();

+ 12
- 0
test/circuits/ops2.circom

@ -0,0 +1,12 @@
template Ops2() {
signal input in[2];
signal output div;
signal output idiv;
signal output mod;
div <-- in[0] / in[1];
idiv <-- in[0] \ in[1];
mod <-- in[0] % in[1];
}
component main = Ops2();

Loading…
Cancel
Save