You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

229 lines
7.1 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. #include <string>
  2. #include <stdexcept>
  3. #include <sstream>
  4. #include <iostream>
  5. #include <iomanip>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include <stdarg.h>
  9. #include <thread>
  10. #include "calcwit.h"
  11. #include "utils.h"
  12. Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
  13. circuit = aCircuit;
  14. #ifdef SANITY_CHECK
  15. signalAssigned = new bool[circuit->NSignals];
  16. signalAssigned[0] = true;
  17. #endif
  18. mutexes = new std::mutex[NMUTEXES];
  19. cvs = new std::condition_variable[NMUTEXES];
  20. inputSignalsToTrigger = new int[circuit->NComponents];
  21. signalValues = new FrElement[circuit->NSignals];
  22. // Set one signal
  23. Fr_copy(&signalValues[0], circuit->constants + 1);
  24. reset();
  25. }
  26. Circom_CalcWit::~Circom_CalcWit() {
  27. #ifdef SANITY_CHECK
  28. delete signalAssigned;
  29. #endif
  30. delete[] cvs;
  31. delete[] mutexes;
  32. delete[] signalValues;
  33. delete[] inputSignalsToTrigger;
  34. }
  35. void Circom_CalcWit::syncPrintf(const char *format, ...) {
  36. va_list args;
  37. va_start(args, format);
  38. printf_mutex.lock();
  39. vprintf(format, args);
  40. printf_mutex.unlock();
  41. va_end(args);
  42. }
  43. void Circom_CalcWit::reset() {
  44. #ifdef SANITY_CHECK
  45. for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
  46. #endif
  47. for (int i=0; i<circuit->NComponents; i++) {
  48. inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
  49. if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
  50. }
  51. }
  52. int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
  53. int hIdx;
  54. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  55. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  56. }
  57. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  58. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  59. throw std::runtime_error("invalid type");
  60. }
  61. return circuit->components[cIdx].entries[entryPos].offset;
  62. }
  63. Circom_Sizes Circom_CalcWit::getSubComponentSizes(int cIdx, u64 hash) {
  64. int hIdx;
  65. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  66. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  67. }
  68. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  69. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  70. throw std::runtime_error("invalid type");
  71. }
  72. return circuit->components[cIdx].entries[entryPos].sizes;
  73. }
  74. int Circom_CalcWit::getSignalOffset(int cIdx, u64 hash) {
  75. int hIdx;
  76. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  77. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  78. }
  79. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  80. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  81. throw std::runtime_error("invalid type");
  82. }
  83. return circuit->components[cIdx].entries[entryPos].offset;
  84. }
  85. Circom_Sizes Circom_CalcWit::getSignalSizes(int cIdx, u64 hash) {
  86. int hIdx;
  87. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  88. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  89. }
  90. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  91. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  92. throw std::runtime_error("invalid type");
  93. }
  94. return circuit->components[cIdx].entries[entryPos].sizes;
  95. }
  96. void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  97. // syncPrintf("getSignal: %d\n", sIdx);
  98. if ((circuit->components[cIdx].newThread)&&(currentComponentIdx != cIdx)) {
  99. std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  100. while (inputSignalsToTrigger[cIdx] != -1) {
  101. cvs[cIdx % NMUTEXES].wait(lk);
  102. }
  103. // cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
  104. lk.unlock();
  105. }
  106. #ifdef SANITY_CHECK
  107. if (signalAssigned[sIdx] == false) {
  108. fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
  109. assert(false);
  110. }
  111. #endif
  112. Fr_copy(value, signalValues + sIdx);
  113. /*
  114. char *valueStr = mpz_get_str(0, 10, *value);
  115. syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  116. free(valueStr);
  117. */
  118. }
  119. void Circom_CalcWit::finished(int cIdx) {
  120. {
  121. std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  122. inputSignalsToTrigger[cIdx] = -1;
  123. }
  124. // syncPrintf("Finished: %d\n", cIdx);
  125. cvs[cIdx % NMUTEXES].notify_all();
  126. }
  127. void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  128. // syncPrintf("setSignal: %d\n", sIdx);
  129. #ifdef SANITY_CHECK
  130. if (signalAssigned[sIdx] == true) {
  131. fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
  132. assert(false);
  133. }
  134. signalAssigned[sIdx] = true;
  135. #endif
  136. // Log assignement
  137. /*
  138. char *valueStr = mpz_get_str(0, 10, *value);
  139. syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  140. free(valueStr);
  141. */
  142. Fr_copy(signalValues + sIdx, value);
  143. if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
  144. if (inputSignalsToTrigger[cIdx]>0) {
  145. inputSignalsToTrigger[cIdx]--;
  146. if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
  147. }
  148. }
  149. }
  150. void Circom_CalcWit::checkConstraint(int currentComponentIdx, PFrElement value1, PFrElement value2, char const *err) {
  151. #ifdef SANITY_CHECK
  152. FrElement tmp;
  153. Fr_eq(&tmp, value1, value2);
  154. if (!Fr_isTrue(&tmp)) {
  155. char *pcV1 = Fr_element2str(value1);
  156. char *pcV2 = Fr_element2str(value2);
  157. // throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
  158. fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
  159. free(pcV1);
  160. free(pcV2);
  161. assert(false);
  162. }
  163. #endif
  164. }
  165. void Circom_CalcWit::triggerComponent(int newCIdx) {
  166. //int oldCIdx = cIdx;
  167. // cIdx = newCIdx;
  168. if (circuit->components[newCIdx].newThread) {
  169. // syncPrintf("Triggered: %d\n", newCIdx);
  170. std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
  171. // t.join();
  172. t.detach();
  173. } else {
  174. (*(circuit->components[newCIdx].fn))(this, newCIdx);
  175. }
  176. // cIdx = oldCIdx;
  177. }
  178. void Circom_CalcWit::log(PFrElement value) {
  179. char *pcV = Fr_element2str(value);
  180. syncPrintf("Log: %s\n", pcV);
  181. free(pcV);
  182. }
  183. void Circom_CalcWit::join() {
  184. for (int i=0; i<circuit->NComponents; i++) {
  185. std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
  186. while (inputSignalsToTrigger[i] != -1) {
  187. cvs[i % NMUTEXES].wait(lk);
  188. }
  189. // cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
  190. lk.unlock();
  191. // syncPrintf("Joined: %d\n", i);
  192. }
  193. }