From b564201170210848a56ec9fcfc709f2da22e48de Mon Sep 17 00:00:00 2001 From: Jordi Baylina Date: Fri, 20 Dec 2019 22:01:12 +0100 Subject: [PATCH] Multithread --- c/calcwit.cpp | 121 +++++++++++++++++++++++++++++++++++++------------ c/calcwit.h | 23 ++++++++-- c/circom.h | 3 +- c/main.cpp | 5 +- c/zqfield.cpp | 12 +++-- c/zqfield.h | 2 - src/c_build.js | 5 +- src/c_gen.js | 18 ++++---- 8 files changed, 136 insertions(+), 53 deletions(-) diff --git a/c/calcwit.cpp b/c/calcwit.cpp index eb634fe..a85af31 100644 --- a/c/calcwit.cpp +++ b/c/calcwit.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "calcwit.h" #include "utils.h" @@ -17,6 +18,8 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) { signalAssigned[0] = true; #endif + mutexes = new std::mutex[NMUTEXES]; + cvs = new std::condition_variable[NMUTEXES]; inputSignalsToTrigger = new int[circuit->NComponents]; signalValues = new BigInt[circuit->NSignals]; @@ -34,18 +37,6 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) { reset(); } -void Circom_CalcWit::reset() { - -#ifdef SANITY_CHECK - for (int i=1; iNComponents; i++) signalAssigned[i] = false; -#endif - - for (int i=0; iNComponents; i++) { - inputSignalsToTrigger[i] = circuit->components[i].inputSignals; - if (inputSignalsToTrigger[i] == 0) triggerComponent(i); - } -} - Circom_CalcWit::~Circom_CalcWit() { delete field; @@ -54,13 +45,40 @@ Circom_CalcWit::~Circom_CalcWit() { delete signalAssigned; #endif + delete[] cvs; + delete[] mutexes; + for (int i=0; iNSignals; i++) mpz_clear(signalValues[i]); delete[] signalValues; - delete inputSignalsToTrigger; + delete[] inputSignalsToTrigger; + +} + +void Circom_CalcWit::syncPrintf(const char *format, ...) { + va_list args; + va_start(args, format); + + printf_mutex.lock(); + vprintf(format, args); + printf_mutex.unlock(); + + va_end(args); +} + +void Circom_CalcWit::reset() { + +#ifdef SANITY_CHECK + for (int i=1; iNComponents; i++) signalAssigned[i] = false; +#endif + for (int i=0; iNComponents; i++) { + inputSignalsToTrigger[i] = circuit->components[i].inputSignals; + if (inputSignalsToTrigger[i] == 0) triggerComponent(i); + } } + int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) { int hIdx; for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) { @@ -121,7 +139,16 @@ void Circom_CalcWit::freeBigInts(PBigInt bi, int n) { delete[] bi; } -void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) { +void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) { + // syncPrintf("getSignal: %d\n", sIdx); + if (currentComponentIdx != cIdx) { + std::unique_lock lk(mutexes[cIdx % NMUTEXES]); + while (inputSignalsToTrigger[cIdx] != -1) { + cvs[cIdx % NMUTEXES].wait(lk); + } + // cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;}); + lk.unlock(); + } #ifdef SANITY_CHECK if (signalAssigned[sIdx] == false) { fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx); @@ -129,9 +156,25 @@ void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) { } #endif mpz_set(*value, signalValues[sIdx]); + /* + char *valueStr = mpz_get_str(0, 10, *value); + syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr); + free(valueStr); + */ +} + +void Circom_CalcWit::finished(int cIdx) { + { + std::lock_guard lk(mutexes[cIdx % NMUTEXES]); + inputSignalsToTrigger[cIdx] = -1; + } + // syncPrintf("Finished: %d\n", cIdx); + cvs[cIdx % NMUTEXES].notify_all(); } -void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) { +void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) { + // syncPrintf("setSignal: %d\n", sIdx); + #ifdef SANITY_CHECK if (signalAssigned[sIdx] == true) { fprintf(stderr, "Signal assigned twice: %d\n", sIdx); @@ -139,46 +182,68 @@ void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) { } signalAssigned[sIdx] = true; #endif - /* // Log assignement + /* char *valueStr = mpz_get_str(0, 10, *value); - printf("%d --> %s\n", sIdx, valueStr); + syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr); free(valueStr); */ mpz_set(signalValues[sIdx], *value); if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) { - inputSignalsToTrigger[cIdx]--; - if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx); + if (inputSignalsToTrigger[cIdx]>0) { + inputSignalsToTrigger[cIdx]--; + if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx); + } } } -void Circom_CalcWit::checkConstraint(PBigInt value1, PBigInt value2, char const *err) { +void Circom_CalcWit::checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err) { #ifdef SANITY_CHECK if (mpz_cmp(*value1, *value2) != 0) { char *pcV1 = mpz_get_str(0, 10, *value1); char *pcV2 = mpz_get_str(0, 10, *value2); - std::string sV1 = std::string(pcV1); - std::string sV2 = std::string(pcV2); + // throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 ); + fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2); free(pcV1); free(pcV2); - throw std::runtime_error(std::string("Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 ); + assert(false); } #endif } void Circom_CalcWit::triggerComponent(int newCIdx) { - int oldCIdx = cIdx; - cIdx = newCIdx; - (*(circuit->components[newCIdx].fn))(this); - cIdx = oldCIdx; + //int oldCIdx = cIdx; + // cIdx = newCIdx; + if (circuit->components[newCIdx].newThread) { + // syncPrintf("Triggered: %d\n", newCIdx); + std::thread t(circuit->components[newCIdx].fn, this, newCIdx); + // t.join(); + t.detach(); + } else { + (*(circuit->components[newCIdx].fn))(this, newCIdx); + } + // cIdx = oldCIdx; } void Circom_CalcWit::log(PBigInt value) { char *pcV = mpz_get_str(0, 10, *value); - printf("Log: %s\n", pcV); + syncPrintf("Log: %s\n", pcV); free(pcV); } +void Circom_CalcWit::join() { + for (int i=0; iNComponents; i++) { + std::unique_lock lk(mutexes[i % NMUTEXES]); + while (inputSignalsToTrigger[i] != -1) { + cvs[i % NMUTEXES].wait(lk); + } + // cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;}); + lk.unlock(); + // syncPrintf("Joined: %d\n", i); + } + +} + diff --git a/c/calcwit.h b/c/calcwit.h index 5f2ce22..1a02ee6 100644 --- a/c/calcwit.h +++ b/c/calcwit.h @@ -3,6 +3,10 @@ #include "circom.h" #include "zqfield.h" +#include +#include + +#define NMUTEXES 128 class Circom_CalcWit { @@ -13,7 +17,12 @@ class Circom_CalcWit { // componentStatus -> For each component // >0 Signals required to trigger // == 0 Component triggered + // == -1 Component finished int *inputSignalsToTrigger; + std::mutex *mutexes; + std::condition_variable *cvs; + + std::mutex printf_mutex; BigInt *signalValues; @@ -22,10 +31,11 @@ class Circom_CalcWit { void triggerComponent(int newCIdx); void calculateWitness(void *input, void *output); + void syncPrintf(const char *format, ...); + public: ZqField *field; - int cIdx; // Functions called by the circuit Circom_CalcWit(Circom_Circuit *aCircuit); ~Circom_CalcWit(); @@ -38,17 +48,20 @@ public: PBigInt allocBigInts(int n); void freeBigInts(PBigInt bi, int n); - void getSignal(int cIdx, int sIdx, PBigInt value); - void setSignal(int cIdx, int sIdx, PBigInt value); + void getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value); + void setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value); - void checkConstraint(PBigInt value1, PBigInt value2, char const *err); + void checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err); void log(PBigInt value); + void finished(int cIdx); + void join(); + // Public functions inline void setInput(int idx, PBigInt val) { - setSignal(0, circuit->wit2sig[idx], val); + setSignal(0, 0, circuit->wit2sig[idx], val); } inline void getWitness(int idx, PBigInt val) { mpz_set(*val, signalValues[circuit->wit2sig[idx]]); diff --git a/c/circom.h b/c/circom.h index db83e8b..da48033 100644 --- a/c/circom.h +++ b/c/circom.h @@ -29,13 +29,14 @@ struct Circom_ComponentEntry { }; typedef Circom_ComponentEntry *Circom_ComponentEntries; -typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx); +typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx, int __cIdx); struct Circom_Component { Circom_HashTable hashTable; Circom_ComponentEntries entries; Circom_ComponentFunction fn; int inputSignals; + bool newThread; }; class Circom_Circuit { diff --git a/c/main.cpp b/c/main.cpp index 26d3167..1f77b34 100644 --- a/c/main.cpp +++ b/c/main.cpp @@ -48,7 +48,7 @@ void loadBin(Circom_CalcWit *ctx, std::string filename) { p++; mpz_import(v,len , -1 , 1, 0, 0, p); p+=len; - ctx->setSignal(0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v); + ctx->setSignal(0, 0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v); } } @@ -88,7 +88,7 @@ void itFunc(Circom_CalcWit *ctx, int o, json val) { mpz_set_str (v, s.c_str(), 10); - ctx->setSignal(0, o, &v); + ctx->setSignal(0, 0, o, &v); } @@ -187,6 +187,7 @@ int main(int argc, char *argv[]) { handle_error("Invalid input extension (.bin / .json)"); } + ctx->join(); std::string outfilename = argv[2]; diff --git a/c/zqfield.cpp b/c/zqfield.cpp index 50551f4..84605f6 100644 --- a/c/zqfield.cpp +++ b/c/zqfield.cpp @@ -1,7 +1,6 @@ #include "zqfield.h" ZqField::ZqField(PBigInt ap) { - mpz_init2(tmp, 1024); mpz_init_set(p, *ap); mpz_init_set_ui(zero, 0); mpz_init_set_ui(one, 1); @@ -12,7 +11,6 @@ ZqField::ZqField(PBigInt ap) { } ZqField::~ZqField() { - mpz_clear(tmp); mpz_clear(p); mpz_clear(zero); mpz_clear(one); @@ -29,8 +27,8 @@ void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) { if (mpz_cmp(*a, *b) >= 0) { mpz_sub(*r, *a, *b); } else { - mpz_sub(tmp, *b, *a); - mpz_sub(*r, p, tmp); + mpz_sub(*r, *b, *a); + mpz_sub(*r, p, *r); } } @@ -43,14 +41,20 @@ void ZqField::neg(PBigInt r, PBigInt a) { } void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) { + mpz_t tmp; + mpz_init(tmp); mpz_mul(tmp,*a,*b); mpz_fdiv_r(*r, tmp, p); + mpz_clear(tmp); } void ZqField::div(PBigInt r, PBigInt a, PBigInt b) { + mpz_t tmp; + mpz_init(tmp); mpz_invert(tmp, *b, p); mpz_mul(tmp,*a,tmp); mpz_fdiv_r(*r, tmp, p); + mpz_clear(tmp); } void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) { diff --git a/c/zqfield.h b/c/zqfield.h index bac9b2a..1080387 100644 --- a/c/zqfield.h +++ b/c/zqfield.h @@ -4,8 +4,6 @@ #include "circom.h" class ZqField { - mpz_t tmp; - public: BigInt p; BigInt one; diff --git a/src/c_build.js b/src/c_build.js index f81e577..da2f27d 100644 --- a/src/c_build.js +++ b/src/c_build.js @@ -176,10 +176,11 @@ function buildCode(ctx) { "/*\n" + instanceDef + "\n*/\n" + - `void ${fName}(Circom_CalcWit *ctx) {\n` + + `void ${fName}(Circom_CalcWit *ctx, int __cIdx) {\n` + utils.ident( ctx.codeHeader + "\n" + ctx.code + "\n" + + "ctx->finished(__cIdx);\n" + ctx.codeFooter ) + "}\n"; @@ -204,7 +205,7 @@ function buildComponentsArray(ctx) { ccodes.push(`Circom_Component _components[${ctx.components.length}] = {\n`); for (let i=0; i< ctx.components.length; i++) { ccodes.push(i>0 ? " ," : " "); - ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}}\n`); + ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}, true}\n`); } ccodes.push("};\n"); const codeComponents = ccodes.join(""); diff --git a/src/c_gen.js b/src/c_gen.js index b79679d..a886b3b 100644 --- a/src/c_gen.js +++ b/src/c_gen.js @@ -515,7 +515,7 @@ function genGetSubComponentOffset(ctx, cIdxRef, label) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } ctx.code += `${offset.label} = ctx->getSubComponentOffset(${s}, 0x${h}LL /* ${label} */);\n`; return refOffset; @@ -532,7 +532,7 @@ function genGetSubComponentSizes(ctx, cIdxRef, label) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } ctx.code += `${sizes.label} = ctx->getSubComponentSizes(${s}, 0x${h}LL /* ${label} */);\n`; return sizesRef; @@ -549,7 +549,7 @@ function genGetSigalOffset(ctx, cIdxRef, label) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } ctx.code += `${offset.label} = ctx->getSignalOffset(${s}, 0x${h}LL /* ${label} */);\n`; return refOffset; @@ -565,7 +565,7 @@ function genGetSignalSizes(ctx, cIdxRef, label) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } const h = utils.fnvHash(label); ctx.code += `${sizes.label} = ctx->getSignalSizes(${s}, 0x${h}LL /* ${label} */);\n`; @@ -585,10 +585,10 @@ function genSetSignal(ctx, cIdxRef, sIdxRef, valueRef) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } const sIdx = ctx.refs[sIdxRef]; - ctx.code += `ctx->setSignal(${s}, ${sIdx.label}, ${v.label});\n`; + ctx.code += `ctx->setSignal(__cIdx, ${s}, ${sIdx.label}, ${v.label});\n`; return valueRef; } @@ -602,10 +602,10 @@ function genGetSignal(ctx, cIdxRef, sIdxRef) { const cIdx = ctx.refs[cIdxRef]; s = cIdx.label; } else { - s = "ctx->cIdx"; + s = "__cIdx"; } const sIdx = ctx.refs[sIdxRef]; - ctx.code += `ctx->getSignal(${s}, ${sIdx.label}, ${res.label});\n`; + ctx.code += `ctx->getSignal(__cIdx, ${s}, ${sIdx.label}, ${res.label});\n`; return resRef; } @@ -747,7 +747,7 @@ function genConstraint(ctx, ast) { const strErr = ast.fileName + ":" + ast.first_line + ":" + ast.first_column; instantiateRef(ctx, aRef, a.value); instantiateRef(ctx, bRef, b.value); - ctx.code += `ctx->checkConstraint(${a.label}, ${b.label}, "${strErr}");`; + ctx.code += `ctx->checkConstraint(__cIdx, ${a.label}, ${b.label}, "${strErr}");`; }