mirror of
https://github.com/arnaucube/circom.git
synced 2026-02-06 18:56:40 +01:00
Multithread
This commit is contained in:
123
c/calcwit.cpp
123
c/calcwit.cpp
@@ -6,6 +6,7 @@
|
||||
#include <stdlib.h>
|
||||
#include <gmp.h>
|
||||
#include <assert.h>
|
||||
#include <thread>
|
||||
#include "calcwit.h"
|
||||
#include "utils.h"
|
||||
|
||||
@@ -17,6 +18,8 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
|
||||
signalAssigned[0] = true;
|
||||
#endif
|
||||
|
||||
mutexes = new std::mutex[NMUTEXES];
|
||||
cvs = new std::condition_variable[NMUTEXES];
|
||||
inputSignalsToTrigger = new int[circuit->NComponents];
|
||||
signalValues = new BigInt[circuit->NSignals];
|
||||
|
||||
@@ -34,6 +37,35 @@ Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
|
||||
reset();
|
||||
}
|
||||
|
||||
|
||||
Circom_CalcWit::~Circom_CalcWit() {
|
||||
delete field;
|
||||
|
||||
#ifdef SANITY_CHECK
|
||||
delete signalAssigned;
|
||||
#endif
|
||||
|
||||
delete[] cvs;
|
||||
delete[] mutexes;
|
||||
|
||||
for (int i=0; i<circuit->NSignals; i++) mpz_clear(signalValues[i]);
|
||||
|
||||
delete[] signalValues;
|
||||
delete[] inputSignalsToTrigger;
|
||||
|
||||
}
|
||||
|
||||
void Circom_CalcWit::syncPrintf(const char *format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
|
||||
printf_mutex.lock();
|
||||
vprintf(format, args);
|
||||
printf_mutex.unlock();
|
||||
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
void Circom_CalcWit::reset() {
|
||||
|
||||
#ifdef SANITY_CHECK
|
||||
@@ -47,20 +79,6 @@ void Circom_CalcWit::reset() {
|
||||
}
|
||||
|
||||
|
||||
Circom_CalcWit::~Circom_CalcWit() {
|
||||
delete field;
|
||||
|
||||
#ifdef SANITY_CHECK
|
||||
delete signalAssigned;
|
||||
#endif
|
||||
|
||||
for (int i=0; i<circuit->NSignals; i++) mpz_clear(signalValues[i]);
|
||||
|
||||
delete[] signalValues;
|
||||
delete inputSignalsToTrigger;
|
||||
|
||||
}
|
||||
|
||||
int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
|
||||
int hIdx;
|
||||
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
|
||||
@@ -121,7 +139,16 @@ void Circom_CalcWit::freeBigInts(PBigInt bi, int n) {
|
||||
delete[] bi;
|
||||
}
|
||||
|
||||
void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) {
|
||||
void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
|
||||
// syncPrintf("getSignal: %d\n", sIdx);
|
||||
if (currentComponentIdx != cIdx) {
|
||||
std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
|
||||
while (inputSignalsToTrigger[cIdx] != -1) {
|
||||
cvs[cIdx % NMUTEXES].wait(lk);
|
||||
}
|
||||
// cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
|
||||
lk.unlock();
|
||||
}
|
||||
#ifdef SANITY_CHECK
|
||||
if (signalAssigned[sIdx] == false) {
|
||||
fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
|
||||
@@ -129,9 +156,25 @@ void Circom_CalcWit::getSignal(int cIdx, int sIdx, PBigInt value) {
|
||||
}
|
||||
#endif
|
||||
mpz_set(*value, signalValues[sIdx]);
|
||||
/*
|
||||
char *valueStr = mpz_get_str(0, 10, *value);
|
||||
syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
|
||||
free(valueStr);
|
||||
*/
|
||||
}
|
||||
|
||||
void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) {
|
||||
void Circom_CalcWit::finished(int cIdx) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
|
||||
inputSignalsToTrigger[cIdx] = -1;
|
||||
}
|
||||
// syncPrintf("Finished: %d\n", cIdx);
|
||||
cvs[cIdx % NMUTEXES].notify_all();
|
||||
}
|
||||
|
||||
void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
|
||||
// syncPrintf("setSignal: %d\n", sIdx);
|
||||
|
||||
#ifdef SANITY_CHECK
|
||||
if (signalAssigned[sIdx] == true) {
|
||||
fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
|
||||
@@ -139,46 +182,68 @@ void Circom_CalcWit::setSignal(int cIdx, int sIdx, PBigInt value) {
|
||||
}
|
||||
signalAssigned[sIdx] = true;
|
||||
#endif
|
||||
/*
|
||||
// Log assignement
|
||||
/*
|
||||
char *valueStr = mpz_get_str(0, 10, *value);
|
||||
printf("%d --> %s\n", sIdx, valueStr);
|
||||
syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
|
||||
free(valueStr);
|
||||
*/
|
||||
mpz_set(signalValues[sIdx], *value);
|
||||
if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
|
||||
inputSignalsToTrigger[cIdx]--;
|
||||
if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
|
||||
if (inputSignalsToTrigger[cIdx]>0) {
|
||||
inputSignalsToTrigger[cIdx]--;
|
||||
if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void Circom_CalcWit::checkConstraint(PBigInt value1, PBigInt value2, char const *err) {
|
||||
void Circom_CalcWit::checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err) {
|
||||
#ifdef SANITY_CHECK
|
||||
if (mpz_cmp(*value1, *value2) != 0) {
|
||||
char *pcV1 = mpz_get_str(0, 10, *value1);
|
||||
char *pcV2 = mpz_get_str(0, 10, *value2);
|
||||
std::string sV1 = std::string(pcV1);
|
||||
std::string sV2 = std::string(pcV2);
|
||||
// throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
|
||||
fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
|
||||
free(pcV1);
|
||||
free(pcV2);
|
||||
throw std::runtime_error(std::string("Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
void Circom_CalcWit::triggerComponent(int newCIdx) {
|
||||
int oldCIdx = cIdx;
|
||||
cIdx = newCIdx;
|
||||
(*(circuit->components[newCIdx].fn))(this);
|
||||
cIdx = oldCIdx;
|
||||
//int oldCIdx = cIdx;
|
||||
// cIdx = newCIdx;
|
||||
if (circuit->components[newCIdx].newThread) {
|
||||
// syncPrintf("Triggered: %d\n", newCIdx);
|
||||
std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
|
||||
// t.join();
|
||||
t.detach();
|
||||
} else {
|
||||
(*(circuit->components[newCIdx].fn))(this, newCIdx);
|
||||
}
|
||||
// cIdx = oldCIdx;
|
||||
}
|
||||
|
||||
void Circom_CalcWit::log(PBigInt value) {
|
||||
char *pcV = mpz_get_str(0, 10, *value);
|
||||
printf("Log: %s\n", pcV);
|
||||
syncPrintf("Log: %s\n", pcV);
|
||||
free(pcV);
|
||||
}
|
||||
|
||||
void Circom_CalcWit::join() {
|
||||
for (int i=0; i<circuit->NComponents; i++) {
|
||||
std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
|
||||
while (inputSignalsToTrigger[i] != -1) {
|
||||
cvs[i % NMUTEXES].wait(lk);
|
||||
}
|
||||
// cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
|
||||
lk.unlock();
|
||||
// syncPrintf("Joined: %d\n", i);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
23
c/calcwit.h
23
c/calcwit.h
@@ -3,6 +3,10 @@
|
||||
|
||||
#include "circom.h"
|
||||
#include "zqfield.h"
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
|
||||
#define NMUTEXES 128
|
||||
|
||||
class Circom_CalcWit {
|
||||
|
||||
@@ -13,7 +17,12 @@ class Circom_CalcWit {
|
||||
// componentStatus -> For each component
|
||||
// >0 Signals required to trigger
|
||||
// == 0 Component triggered
|
||||
// == -1 Component finished
|
||||
int *inputSignalsToTrigger;
|
||||
std::mutex *mutexes;
|
||||
std::condition_variable *cvs;
|
||||
|
||||
std::mutex printf_mutex;
|
||||
|
||||
BigInt *signalValues;
|
||||
|
||||
@@ -22,10 +31,11 @@ class Circom_CalcWit {
|
||||
void triggerComponent(int newCIdx);
|
||||
void calculateWitness(void *input, void *output);
|
||||
|
||||
void syncPrintf(const char *format, ...);
|
||||
|
||||
|
||||
public:
|
||||
ZqField *field;
|
||||
int cIdx;
|
||||
// Functions called by the circuit
|
||||
Circom_CalcWit(Circom_Circuit *aCircuit);
|
||||
~Circom_CalcWit();
|
||||
@@ -38,17 +48,20 @@ public:
|
||||
PBigInt allocBigInts(int n);
|
||||
void freeBigInts(PBigInt bi, int n);
|
||||
|
||||
void getSignal(int cIdx, int sIdx, PBigInt value);
|
||||
void setSignal(int cIdx, int sIdx, PBigInt value);
|
||||
void getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value);
|
||||
void setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value);
|
||||
|
||||
void checkConstraint(PBigInt value1, PBigInt value2, char const *err);
|
||||
void checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err);
|
||||
|
||||
void log(PBigInt value);
|
||||
|
||||
void finished(int cIdx);
|
||||
void join();
|
||||
|
||||
|
||||
// Public functions
|
||||
inline void setInput(int idx, PBigInt val) {
|
||||
setSignal(0, circuit->wit2sig[idx], val);
|
||||
setSignal(0, 0, circuit->wit2sig[idx], val);
|
||||
}
|
||||
inline void getWitness(int idx, PBigInt val) {
|
||||
mpz_set(*val, signalValues[circuit->wit2sig[idx]]);
|
||||
|
||||
@@ -29,13 +29,14 @@ struct Circom_ComponentEntry {
|
||||
};
|
||||
typedef Circom_ComponentEntry *Circom_ComponentEntries;
|
||||
|
||||
typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx);
|
||||
typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx, int __cIdx);
|
||||
|
||||
struct Circom_Component {
|
||||
Circom_HashTable hashTable;
|
||||
Circom_ComponentEntries entries;
|
||||
Circom_ComponentFunction fn;
|
||||
int inputSignals;
|
||||
bool newThread;
|
||||
};
|
||||
|
||||
class Circom_Circuit {
|
||||
|
||||
@@ -48,7 +48,7 @@ void loadBin(Circom_CalcWit *ctx, std::string filename) {
|
||||
p++;
|
||||
mpz_import(v,len , -1 , 1, 0, 0, p);
|
||||
p+=len;
|
||||
ctx->setSignal(0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
|
||||
ctx->setSignal(0, 0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ void itFunc(Circom_CalcWit *ctx, int o, json val) {
|
||||
|
||||
mpz_set_str (v, s.c_str(), 10);
|
||||
|
||||
ctx->setSignal(0, o, &v);
|
||||
ctx->setSignal(0, 0, o, &v);
|
||||
}
|
||||
|
||||
|
||||
@@ -187,6 +187,7 @@ int main(int argc, char *argv[]) {
|
||||
handle_error("Invalid input extension (.bin / .json)");
|
||||
}
|
||||
|
||||
ctx->join();
|
||||
|
||||
std::string outfilename = argv[2];
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "zqfield.h"
|
||||
|
||||
ZqField::ZqField(PBigInt ap) {
|
||||
mpz_init2(tmp, 1024);
|
||||
mpz_init_set(p, *ap);
|
||||
mpz_init_set_ui(zero, 0);
|
||||
mpz_init_set_ui(one, 1);
|
||||
@@ -12,7 +11,6 @@ ZqField::ZqField(PBigInt ap) {
|
||||
}
|
||||
|
||||
ZqField::~ZqField() {
|
||||
mpz_clear(tmp);
|
||||
mpz_clear(p);
|
||||
mpz_clear(zero);
|
||||
mpz_clear(one);
|
||||
@@ -29,8 +27,8 @@ void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) {
|
||||
if (mpz_cmp(*a, *b) >= 0) {
|
||||
mpz_sub(*r, *a, *b);
|
||||
} else {
|
||||
mpz_sub(tmp, *b, *a);
|
||||
mpz_sub(*r, p, tmp);
|
||||
mpz_sub(*r, *b, *a);
|
||||
mpz_sub(*r, p, *r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,14 +41,20 @@ void ZqField::neg(PBigInt r, PBigInt a) {
|
||||
}
|
||||
|
||||
void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) {
|
||||
mpz_t tmp;
|
||||
mpz_init(tmp);
|
||||
mpz_mul(tmp,*a,*b);
|
||||
mpz_fdiv_r(*r, tmp, p);
|
||||
mpz_clear(tmp);
|
||||
}
|
||||
|
||||
void ZqField::div(PBigInt r, PBigInt a, PBigInt b) {
|
||||
mpz_t tmp;
|
||||
mpz_init(tmp);
|
||||
mpz_invert(tmp, *b, p);
|
||||
mpz_mul(tmp,*a,tmp);
|
||||
mpz_fdiv_r(*r, tmp, p);
|
||||
mpz_clear(tmp);
|
||||
}
|
||||
|
||||
void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) {
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
#include "circom.h"
|
||||
|
||||
class ZqField {
|
||||
mpz_t tmp;
|
||||
|
||||
public:
|
||||
BigInt p;
|
||||
BigInt one;
|
||||
|
||||
@@ -176,10 +176,11 @@ function buildCode(ctx) {
|
||||
"/*\n" +
|
||||
instanceDef +
|
||||
"\n*/\n" +
|
||||
`void ${fName}(Circom_CalcWit *ctx) {\n` +
|
||||
`void ${fName}(Circom_CalcWit *ctx, int __cIdx) {\n` +
|
||||
utils.ident(
|
||||
ctx.codeHeader + "\n" +
|
||||
ctx.code + "\n" +
|
||||
"ctx->finished(__cIdx);\n" +
|
||||
ctx.codeFooter
|
||||
) +
|
||||
"}\n";
|
||||
@@ -204,7 +205,7 @@ function buildComponentsArray(ctx) {
|
||||
ccodes.push(`Circom_Component _components[${ctx.components.length}] = {\n`);
|
||||
for (let i=0; i< ctx.components.length; i++) {
|
||||
ccodes.push(i>0 ? " ," : " ");
|
||||
ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}}\n`);
|
||||
ccodes.push(`{${ctx.components[i].htName},${ctx.components[i].etName},${ctx.components[i].fnName}, ${ctx.components[i].nInSignals}, true}\n`);
|
||||
}
|
||||
ccodes.push("};\n");
|
||||
const codeComponents = ccodes.join("");
|
||||
|
||||
18
src/c_gen.js
18
src/c_gen.js
@@ -515,7 +515,7 @@ function genGetSubComponentOffset(ctx, cIdxRef, label) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
ctx.code += `${offset.label} = ctx->getSubComponentOffset(${s}, 0x${h}LL /* ${label} */);\n`;
|
||||
return refOffset;
|
||||
@@ -532,7 +532,7 @@ function genGetSubComponentSizes(ctx, cIdxRef, label) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
ctx.code += `${sizes.label} = ctx->getSubComponentSizes(${s}, 0x${h}LL /* ${label} */);\n`;
|
||||
return sizesRef;
|
||||
@@ -549,7 +549,7 @@ function genGetSigalOffset(ctx, cIdxRef, label) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
ctx.code += `${offset.label} = ctx->getSignalOffset(${s}, 0x${h}LL /* ${label} */);\n`;
|
||||
return refOffset;
|
||||
@@ -565,7 +565,7 @@ function genGetSignalSizes(ctx, cIdxRef, label) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
const h = utils.fnvHash(label);
|
||||
ctx.code += `${sizes.label} = ctx->getSignalSizes(${s}, 0x${h}LL /* ${label} */);\n`;
|
||||
@@ -585,10 +585,10 @@ function genSetSignal(ctx, cIdxRef, sIdxRef, valueRef) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
const sIdx = ctx.refs[sIdxRef];
|
||||
ctx.code += `ctx->setSignal(${s}, ${sIdx.label}, ${v.label});\n`;
|
||||
ctx.code += `ctx->setSignal(__cIdx, ${s}, ${sIdx.label}, ${v.label});\n`;
|
||||
|
||||
return valueRef;
|
||||
}
|
||||
@@ -602,10 +602,10 @@ function genGetSignal(ctx, cIdxRef, sIdxRef) {
|
||||
const cIdx = ctx.refs[cIdxRef];
|
||||
s = cIdx.label;
|
||||
} else {
|
||||
s = "ctx->cIdx";
|
||||
s = "__cIdx";
|
||||
}
|
||||
const sIdx = ctx.refs[sIdxRef];
|
||||
ctx.code += `ctx->getSignal(${s}, ${sIdx.label}, ${res.label});\n`;
|
||||
ctx.code += `ctx->getSignal(__cIdx, ${s}, ${sIdx.label}, ${res.label});\n`;
|
||||
return resRef;
|
||||
}
|
||||
|
||||
@@ -747,7 +747,7 @@ function genConstraint(ctx, ast) {
|
||||
const strErr = ast.fileName + ":" + ast.first_line + ":" + ast.first_column;
|
||||
instantiateRef(ctx, aRef, a.value);
|
||||
instantiateRef(ctx, bRef, b.value);
|
||||
ctx.code += `ctx->checkConstraint(${a.label}, ${b.label}, "${strErr}");`;
|
||||
ctx.code += `ctx->checkConstraint(__cIdx, ${a.label}, ${b.label}, "${strErr}");`;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user