You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

249 lines
7.5 KiB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
  1. #include <string>
  2. #include <stdexcept>
  3. #include <sstream>
  4. #include <iostream>
  5. #include <iomanip>
  6. #include <stdlib.h>
  7. #include <gmp.h>
  8. #include <assert.h>
  9. #include <thread>
  10. #include "calcwit.h"
  11. #include "utils.h"
  12. Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
  13. circuit = aCircuit;
  14. #ifdef SANITY_CHECK
  15. signalAssigned = new bool[circuit->NSignals];
  16. signalAssigned[0] = true;
  17. #endif
  18. mutexes = new std::mutex[NMUTEXES];
  19. cvs = new std::condition_variable[NMUTEXES];
  20. inputSignalsToTrigger = new int[circuit->NComponents];
  21. signalValues = new BigInt[circuit->NSignals];
  22. // Set one signal
  23. mpz_init_set_ui(signalValues[0], 1);
  24. // Initialize remaining signals
  25. for (int i=1; i<circuit->NSignals; i++) mpz_init2(signalValues[i], 256);
  26. BigInt p;
  27. mpz_init_set_str(p, circuit->P, 10);
  28. field = new ZqField(&p);
  29. mpz_clear(p);
  30. reset();
  31. }
  32. Circom_CalcWit::~Circom_CalcWit() {
  33. delete field;
  34. #ifdef SANITY_CHECK
  35. delete signalAssigned;
  36. #endif
  37. delete[] cvs;
  38. delete[] mutexes;
  39. for (int i=0; i<circuit->NSignals; i++) mpz_clear(signalValues[i]);
  40. delete[] signalValues;
  41. delete[] inputSignalsToTrigger;
  42. }
  43. void Circom_CalcWit::syncPrintf(const char *format, ...) {
  44. va_list args;
  45. va_start(args, format);
  46. printf_mutex.lock();
  47. vprintf(format, args);
  48. printf_mutex.unlock();
  49. va_end(args);
  50. }
  51. void Circom_CalcWit::reset() {
  52. #ifdef SANITY_CHECK
  53. for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
  54. #endif
  55. for (int i=0; i<circuit->NComponents; i++) {
  56. inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
  57. if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
  58. }
  59. }
  60. int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
  61. int hIdx;
  62. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  63. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  64. }
  65. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  66. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  67. throw std::runtime_error("invalid type");
  68. }
  69. return circuit->components[cIdx].entries[entryPos].offset;
  70. }
  71. Circom_Sizes Circom_CalcWit::getSubComponentSizes(int cIdx, u64 hash) {
  72. int hIdx;
  73. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  74. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  75. }
  76. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  77. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  78. throw std::runtime_error("invalid type");
  79. }
  80. return circuit->components[cIdx].entries[entryPos].sizes;
  81. }
  82. int Circom_CalcWit::getSignalOffset(int cIdx, u64 hash) {
  83. int hIdx;
  84. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  85. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  86. }
  87. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  88. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  89. throw std::runtime_error("invalid type");
  90. }
  91. return circuit->components[cIdx].entries[entryPos].offset;
  92. }
  93. Circom_Sizes Circom_CalcWit::getSignalSizes(int cIdx, u64 hash) {
  94. int hIdx;
  95. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  96. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  97. }
  98. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  99. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  100. throw std::runtime_error("invalid type");
  101. }
  102. return circuit->components[cIdx].entries[entryPos].sizes;
  103. }
  104. PBigInt Circom_CalcWit::allocBigInts(int n) {
  105. PBigInt res = new BigInt[n];
  106. for (int i=0; i<n; i++) mpz_init2(res[i], 256);
  107. return res;
  108. }
  109. void Circom_CalcWit::freeBigInts(PBigInt bi, int n) {
  110. for (int i=0; i<n; i++) mpz_clear(bi[i]);
  111. delete[] bi;
  112. }
  113. void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
  114. // syncPrintf("getSignal: %d\n", sIdx);
  115. if (currentComponentIdx != cIdx) {
  116. std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  117. while (inputSignalsToTrigger[cIdx] != -1) {
  118. cvs[cIdx % NMUTEXES].wait(lk);
  119. }
  120. // cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
  121. lk.unlock();
  122. }
  123. #ifdef SANITY_CHECK
  124. if (signalAssigned[sIdx] == false) {
  125. fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
  126. assert(false);
  127. }
  128. #endif
  129. mpz_set(*value, signalValues[sIdx]);
  130. /*
  131. char *valueStr = mpz_get_str(0, 10, *value);
  132. syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  133. free(valueStr);
  134. */
  135. }
  136. void Circom_CalcWit::finished(int cIdx) {
  137. {
  138. std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  139. inputSignalsToTrigger[cIdx] = -1;
  140. }
  141. // syncPrintf("Finished: %d\n", cIdx);
  142. cvs[cIdx % NMUTEXES].notify_all();
  143. }
  144. void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PBigInt value) {
  145. // syncPrintf("setSignal: %d\n", sIdx);
  146. #ifdef SANITY_CHECK
  147. if (signalAssigned[sIdx] == true) {
  148. fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
  149. assert(false);
  150. }
  151. signalAssigned[sIdx] = true;
  152. #endif
  153. // Log assignement
  154. /*
  155. char *valueStr = mpz_get_str(0, 10, *value);
  156. syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  157. free(valueStr);
  158. */
  159. mpz_set(signalValues[sIdx], *value);
  160. if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
  161. if (inputSignalsToTrigger[cIdx]>0) {
  162. inputSignalsToTrigger[cIdx]--;
  163. if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
  164. }
  165. }
  166. }
  167. void Circom_CalcWit::checkConstraint(int currentComponentIdx, PBigInt value1, PBigInt value2, char const *err) {
  168. #ifdef SANITY_CHECK
  169. if (mpz_cmp(*value1, *value2) != 0) {
  170. char *pcV1 = mpz_get_str(0, 10, *value1);
  171. char *pcV2 = mpz_get_str(0, 10, *value2);
  172. // throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
  173. fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
  174. free(pcV1);
  175. free(pcV2);
  176. assert(false);
  177. }
  178. #endif
  179. }
  180. void Circom_CalcWit::triggerComponent(int newCIdx) {
  181. //int oldCIdx = cIdx;
  182. // cIdx = newCIdx;
  183. if (circuit->components[newCIdx].newThread) {
  184. // syncPrintf("Triggered: %d\n", newCIdx);
  185. std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
  186. // t.join();
  187. t.detach();
  188. } else {
  189. (*(circuit->components[newCIdx].fn))(this, newCIdx);
  190. }
  191. // cIdx = oldCIdx;
  192. }
  193. void Circom_CalcWit::log(PBigInt value) {
  194. char *pcV = mpz_get_str(0, 10, *value);
  195. syncPrintf("Log: %s\n", pcV);
  196. free(pcV);
  197. }
  198. void Circom_CalcWit::join() {
  199. for (int i=0; i<circuit->NComponents; i++) {
  200. std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
  201. while (inputSignalsToTrigger[i] != -1) {
  202. cvs[i % NMUTEXES].wait(lk);
  203. }
  204. // cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
  205. lk.unlock();
  206. // syncPrintf("Joined: %d\n", i);
  207. }
  208. }