Browse Source

Multithread

feature/witness_bin
Jordi Baylina 4 years ago
parent
commit
b564201170
No known key found for this signature in database GPG Key ID: 7480C80C1BE43112
8 changed files with 136 additions and 53 deletions
  1. +93
    -28
      c/calcwit.cpp
  2. +18
    -5
      c/calcwit.h
  3. +2
    -1
      c/circom.h
  4. +3
    -2
      c/main.cpp
  5. +8
    -4
      c/zqfield.cpp
  6. +0
    -2
      c/zqfield.h
  7. +3
    -2
      src/c_build.js
  8. +9
    -9
      src/c_gen.js

+ 93
- 28
c/calcwit.cpp

@ -6,6 +6,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <gmp.h> #include <gmp.h>
#include <assert.h> #include <assert.h>
#include <thread>
#include "calcwit.h" #include "calcwit.h"
#include "utils.h" #include "utils.h"
@ -17,6 +18,8 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
signalAssigned[0] = true; signalAssigned[0] = true;
#endif #endif
mutexes = new std::mutex[NMUTEXES];
cvs = new std::condition_variable[NMUTEXES];
inputSignalsToTrigger = new int[circuit->NComponents]; inputSignalsToTrigger = new int[circuit->NComponents];
signalValues = new BigInt[circuit->NSignals]; signalValues = new BigInt[circuit->NSignals];
@ -34,18 +37,6 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
reset(); reset();
} }
void Circom_CalcWit::reset() {
#ifdef SANITY_CHECK
for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
#endif
for (int i=0; i<circuit->NComponents; i++) {
inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
}
}
Circom_CalcWit::~Circom_CalcWit() { Circom_CalcWit::~Circom_CalcWit() {
delete field; delete field;
@ -54,13 +45,40 @@ Circom_CalcWit::~Circom_CalcWit() {
delete signalAssigned; delete signalAssigned;
#endif #endif
delete[] cvs;
delete[] mutexes;
for (int i=0; i<circuit->NSignals; i++) mpz_clear(signalValues[i]); for (int i=0; i<circuit->NSignals; i++) mpz_clear(signalValues[i]);
delete[] signalValues; delete[] signalValues;
delete inputSignalsToTrigger;
delete[] inputSignalsToTrigger;
}
void Circom_CalcWit::syncPrintf(const char *format, ...) {
va_list args;
va_start(args, format);
printf_mutex.lock();
vprintf(format, args);
printf_mutex.unlock();
va_end(args);
}
void Circom_CalcWit::reset() {
#ifdef SANITY_CHECK
for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
#endif
for (int i=0; i<circuit->NComponents; i++) {
inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
}
} }
int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) { int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
int hIdx; int hIdx;
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) { for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
@ -121,7 +139,16 @@ void Circom_CalcWit::freeBigInts(PBigInt bi, int n) {
delete[] bi; delete[] bi;
} }
void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) {
void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
// syncPrintf("getSignal: %d\n", sIdx);
if (currentComponentIdx != cIdx) {
std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
while (inputSignalsToTrigger[cIdx] != -1) {
cvs[cIdx % NMUTEXES].wait(lk);
}
// cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
lk.unlock();
}
#ifdef SANITY_CHECK #ifdef SANITY_CHECK
if (signalAssigned[sIdx] == false) { if (signalAssigned[sIdx] == false) {
fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx); fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
@ -129,9 +156,25 @@ void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) {
} }
#endif #endif
mpz_set(*value, signalValues[sIdx]); mpz_set(*value, signalValues[sIdx]);
/*
char *valueStr = mpz_get_str(0, 10, *value);
syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
free(valueStr);
*/
}
void Circom_CalcWit::finished(int cIdx) {
{
std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
inputSignalsToTrigger[cIdx] = -1;
}
// syncPrintf("Finished: %d\n", cIdx);
cvs[cIdx % NMUTEXES].notify_all();
} }
void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) {
void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
// syncPrintf("setSignal: %d\n", sIdx);
#ifdef SANITY_CHECK #ifdef SANITY_CHECK
if (signalAssigned[sIdx] == true) { if (signalAssigned[sIdx] == true) {
fprintf(stderr, "Signal assigned twice: %d\n", sIdx); fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
@ -139,46 +182,68 @@ void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) {
} }
signalAssigned[sIdx] = true; signalAssigned[sIdx] = true;
#endif #endif
/*
// Log assignement // Log assignement
/*
char *valueStr = mpz_get_str(0, 10, *value); char *valueStr = mpz_get_str(0, 10, *value);
printf("%d --> %s\n", sIdx, valueStr);
syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
free(valueStr); free(valueStr);
*/ */
mpz_set(signalValues[sIdx], *value); mpz_set(signalValues[sIdx], *value);
if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) { if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
inputSignalsToTrigger[cIdx]--;
if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
if (inputSignalsToTrigger[cIdx]>0) {
inputSignalsToTrigger[cIdx]--;
if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
}
} }
} }
void Circom_CalcWit::checkConstraint(PBigInt value1, PBigInt value2, char const *err) {
void Circom_CalcWit::checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err) {
#ifdef SANITY_CHECK #ifdef SANITY_CHECK
if (mpz_cmp(*value1, *value2) != 0) { if (mpz_cmp(*value1, *value2) != 0) {
char *pcV1 = mpz_get_str(0, 10, *value1); char *pcV1 = mpz_get_str(0, 10, *value1);
char *pcV2 = mpz_get_str(0, 10, *value2); char *pcV2 = mpz_get_str(0, 10, *value2);
std::string sV1 = std::string(pcV1);
std::string sV2 = std::string(pcV2);
// throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
free(pcV1); free(pcV1);
free(pcV2); free(pcV2);
throw std::runtime_error(std::string("Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
assert(false);
} }
#endif #endif
} }
void Circom_CalcWit::triggerComponent(int newCIdx) { void Circom_CalcWit::triggerComponent(int newCIdx) {
int oldCIdx = cIdx;
cIdx = newCIdx;
(*(circuit->components[newCIdx].fn))(this);
cIdx = oldCIdx;
//int oldCIdx = cIdx;
// cIdx = newCIdx;
if (circuit->components[newCIdx].newThread) {
// syncPrintf("Triggered: %d\n", newCIdx);
std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
// t.join();
t.detach();
} else {
(*(circuit->components[newCIdx].fn))(this, newCIdx);
}
// cIdx = oldCIdx;
} }
void Circom_CalcWit::log(PBigInt value) { void Circom_CalcWit::log(PBigInt value) {
char *pcV = mpz_get_str(0, 10, *value); char *pcV = mpz_get_str(0, 10, *value);
printf("Log: %s\n", pcV);
syncPrintf("Log: %s\n", pcV);
free(pcV); free(pcV);
} }
void Circom_CalcWit::join() {
for (int i=0; i<circuit->NComponents; i++) {
std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
while (inputSignalsToTrigger[i] != -1) {
cvs[i % NMUTEXES].wait(lk);
}
// cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
lk.unlock();
// syncPrintf("Joined: %d\n", i);
}
}

+ 18
- 5
c/calcwit.h

@ -3,6 +3,10 @@
#include "circom.h" #include "circom.h"
#include "zqfield.h" #include "zqfield.h"
#include <mutex>
#include <condition_variable>
#define NMUTEXES 128
class Circom_CalcWit { class Circom_CalcWit {
@ -13,7 +17,12 @@ class Circom_CalcWit {
// componentStatus -> For each component // componentStatus -> For each component
// >0 Signals required to trigger // >0 Signals required to trigger
// == 0 Component triggered // == 0 Component triggered
// == -1 Component finished
int *inputSignalsToTrigger; int *inputSignalsToTrigger;
std::mutex *mutexes;
std::condition_variable *cvs;
std::mutex printf_mutex;
BigInt *signalValues; BigInt *signalValues;
@ -22,10 +31,11 @@ class Circom_CalcWit {
void triggerComponent(int newCIdx); void triggerComponent(int newCIdx);
void calculateWitness(void *input, void *output); void calculateWitness(void *input, void *output);
void syncPrintf(const char *format, ...);
public: public:
ZqField *field; ZqField *field;
int cIdx;
// Functions called by the circuit // Functions called by the circuit
Circom_CalcWit(Circom_Circuit *aCircuit); Circom_CalcWit(Circom_Circuit *aCircuit);
~Circom_CalcWit(); ~Circom_CalcWit();
@ -38,17 +48,20 @@ public:
PBigInt allocBigInts(int n); PBigInt allocBigInts(int n);
void freeBigInts(PBigInt bi, int n); void freeBigInts(PBigInt bi, int n);
void getSignal(int cIdx, int sIdx, PBigInt value);
void setSignal(int cIdx, int sIdx, PBigInt value);
void getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value);
void setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value);
void checkConstraint(PBigInt value1, PBigInt value2, char const *err);
void checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err);
void log(PBigInt value); void log(PBigInt value);
void finished(int cIdx);
void join();
// Public functions // Public functions
inline void setInput(int idx, PBigInt val) { inline void setInput(int idx, PBigInt val) {
setSignal(0, circuit->wit2sig[idx], val);
setSignal(0, 0, circuit->wit2sig[idx], val);
} }
inline void getWitness(int idx, PBigInt val) { inline void getWitness(int idx, PBigInt val) {
mpz_set(*val, signalValues[circuit->wit2sig[idx]]); mpz_set(*val, signalValues[circuit->wit2sig[idx]]);

+ 2
- 1
c/circom.h

@ -29,13 +29,14 @@ struct Circom_ComponentEntry {
}; };
typedef Circom_ComponentEntry *Circom_ComponentEntries; typedef Circom_ComponentEntry *Circom_ComponentEntries;
typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx);
typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx, int __cIdx);
struct Circom_Component { struct Circom_Component {
Circom_HashTable hashTable; Circom_HashTable hashTable;
Circom_ComponentEntries entries; Circom_ComponentEntries entries;
Circom_ComponentFunction fn; Circom_ComponentFunction fn;
int inputSignals; int inputSignals;
bool newThread;
}; };
class Circom_Circuit { class Circom_Circuit {

+ 3
- 2
c/main.cpp

@ -48,7 +48,7 @@ void loadBin(Circom_CalcWit *ctx, std::string filename) {
p++; p++;
mpz_import(v,len , -1 , 1, 0, 0, p); mpz_import(v,len , -1 , 1, 0, 0, p);
p+=len; p+=len;
ctx->setSignal(0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
ctx->setSignal(0, 0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
} }
} }
@ -88,7 +88,7 @@ void itFunc(Circom_CalcWit *ctx, int o, json val) {
mpz_set_str (v, s.c_str(), 10); mpz_set_str (v, s.c_str(), 10);
ctx->setSignal(0, o, &v);
ctx->setSignal(0, 0, o, &v);
} }
@ -187,6 +187,7 @@ int main(int argc, char *argv[]) {
handle_error("Invalid input extension (.bin / .json)"); handle_error("Invalid input extension (.bin / .json)");
} }
ctx->join();
std::string outfilename = argv[2]; std::string outfilename = argv[2];

+ 8
- 4
c/zqfield.cpp

@ -1,7 +1,6 @@
#include "zqfield.h" #include "zqfield.h"
ZqField::ZqField(PBigInt ap) { ZqField::ZqField(PBigInt ap) {
mpz_init2(tmp, 1024);
mpz_init_set(p, *ap); mpz_init_set(p, *ap);
mpz_init_set_ui(zero, 0); mpz_init_set_ui(zero, 0);
mpz_init_set_ui(one, 1); mpz_init_set_ui(one, 1);
@ -12,7 +11,6 @@ ZqField::ZqField(PBigInt ap) {
} }
ZqField::~ZqField() { ZqField::~ZqField() {
mpz_clear(tmp);
mpz_clear(p); mpz_clear(p);
mpz_clear(zero); mpz_clear(zero);
mpz_clear(one); mpz_clear(one);
@ -29,8 +27,8 @@ void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_cmp(*a, *b) >= 0) { if (mpz_cmp(*a, *b) >= 0) {
mpz_sub(*r, *a, *b); mpz_sub(*r, *a, *b);
} else { } else {
mpz_sub(tmp, *b, *a);
mpz_sub(*r, p, tmp);
mpz_sub(*r, *b, *a);
mpz_sub(*r, p, *r);
} }
} }
@ -43,14 +41,20 @@ void ZqField::neg(PBigInt r, PBigInt a) {
} }
void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) { void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) {
mpz_t tmp;
mpz_init(tmp);
mpz_mul(tmp,*a,*b); mpz_mul(tmp,*a,*b);
mpz_fdiv_r(*r, tmp, p); mpz_fdiv_r(*r, tmp, p);
mpz_clear(tmp);
} }
void ZqField::div(PBigInt r, PBigInt a, PBigInt b) { void ZqField::div(PBigInt r, PBigInt a, PBigInt b) {
mpz_t tmp;
mpz_init(tmp);
mpz_invert(tmp, *b, p); mpz_invert(tmp, *b, p);
mpz_mul(tmp,*a,tmp); mpz_mul(tmp,*a,tmp);
mpz_fdiv_r(*r, tmp, p); mpz_fdiv_r(*r, tmp, p);
mpz_clear(tmp);
} }
void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) { void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) {

+ 0
- 2
c/zqfield.h

@ -4,8 +4,6 @@
#include "circom.h" #include "circom.h"
class ZqField { class ZqField {
mpz_t tmp;
public: public:
BigInt p; BigInt p;
BigInt one; BigInt one;

+ 3
- 2
src/c_build.js

@ -176,10 +176,11 @@ function buildCode(ctx) {
"/*\n" + "/*\n" +
instanceDef + instanceDef +
"\n*/\n" + "\n*/\n" +
`void ${fName}(Circom_CalcWit *ctx) {\n` +
`void ${fName}(Circom_CalcWit *ctx, int __cIdx) {\n` +
utils.ident( utils.ident(
ctx.codeHeader + "\n" + ctx.codeHeader + "\n" +
ctx.code + "\n" + ctx.code + "\n" +
"ctx->finished(__cIdx);\n" +
ctx.codeFooter ctx.codeFooter
) + ) +
"}\n"; "}\n";
@ -204,7 +205,7 @@ function buildComponentsArray(ctx) {
ccodes.push(`Circom_Component _components[${ctx.components.length}] = {\n`); ccodes.push(`Circom_Component _components[${ctx.components.length}] = {\n`);
for (let i=0; i< ctx.components.length; i++) { for (let i=0; i< ctx.components.length; i++) {
ccodes.push(i>0 ? " ," : " "); ccodes.push(i>0 ? " ," : " ");
ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}}\n`);
ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}, true}\n`);
} }
ccodes.push("};\n"); ccodes.push("};\n");
const codeComponents = ccodes.join(""); const codeComponents = ccodes.join("");

+ 9
- 9
src/c_gen.js

@ -515,7 +515,7 @@ function genGetSubComponentOffset(ctx, cIdxRef, label) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
ctx.code += `${offset.label} = ctx->getSubComponentOffset(${s}, 0x${h}LL /* ${label} */);\n`; ctx.code += `${offset.label} = ctx->getSubComponentOffset(${s}, 0x${h}LL /* ${label} */);\n`;
return refOffset; return refOffset;
@ -532,7 +532,7 @@ function genGetSubComponentSizes(ctx, cIdxRef, label) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
ctx.code += `${sizes.label} = ctx->getSubComponentSizes(${s}, 0x${h}LL /* ${label} */);\n`; ctx.code += `${sizes.label} = ctx->getSubComponentSizes(${s}, 0x${h}LL /* ${label} */);\n`;
return sizesRef; return sizesRef;
@ -549,7 +549,7 @@ function genGetSigalOffset(ctx, cIdxRef, label) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
ctx.code += `${offset.label} = ctx->getSignalOffset(${s}, 0x${h}LL /* ${label} */);\n`; ctx.code += `${offset.label} = ctx->getSignalOffset(${s}, 0x${h}LL /* ${label} */);\n`;
return refOffset; return refOffset;
@ -565,7 +565,7 @@ function genGetSignalSizes(ctx, cIdxRef, label) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
const h = utils.fnvHash(label); const h = utils.fnvHash(label);
ctx.code += `${sizes.label} = ctx->getSignalSizes(${s}, 0x${h}LL /* ${label} */);\n`; ctx.code += `${sizes.label} = ctx->getSignalSizes(${s}, 0x${h}LL /* ${label} */);\n`;
@ -585,10 +585,10 @@ function genSetSignal(ctx, cIdxRef, sIdxRef, valueRef) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
const sIdx = ctx.refs[sIdxRef]; const sIdx = ctx.refs[sIdxRef];
ctx.code += `ctx->setSignal(${s}, ${sIdx.label}, ${v.label});\n`;
ctx.code += `ctx->setSignal(__cIdx, ${s}, ${sIdx.label}, ${v.label});\n`;
return valueRef; return valueRef;
} }
@ -602,10 +602,10 @@ function genGetSignal(ctx, cIdxRef, sIdxRef) {
const cIdx = ctx.refs[cIdxRef]; const cIdx = ctx.refs[cIdxRef];
s = cIdx.label; s = cIdx.label;
} else { } else {
s = "ctx->cIdx";
s = "__cIdx";
} }
const sIdx = ctx.refs[sIdxRef]; const sIdx = ctx.refs[sIdxRef];
ctx.code += `ctx->getSignal(${s}, ${sIdx.label}, ${res.label});\n`;
ctx.code += `ctx->getSignal(__cIdx, ${s}, ${sIdx.label}, ${res.label});\n`;
return resRef; return resRef;
} }
@ -747,7 +747,7 @@ function genConstraint(ctx, ast) {
const strErr = ast.fileName + ":" + ast.first_line + ":" + ast.first_column; const strErr = ast.fileName + ":" + ast.first_line + ":" + ast.first_column;
instantiateRef(ctx, aRef, a.value); instantiateRef(ctx, aRef, a.value);
instantiateRef(ctx, bRef, b.value); instantiateRef(ctx, bRef, b.value);
ctx.code += `ctx->checkConstraint(${a.label}, ${b.label}, "${strErr}");`;
ctx.code += `ctx->checkConstraint(__cIdx, ${a.label}, ${b.label}, "${strErr}");`;
} }

Loading…
Cancel
Save