You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

234 lines
7.2 KiB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
  1. #include <string>
  2. #include <stdexcept>
  3. #include <sstream>
  4. #include <iostream>
  5. #include <iomanip>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include <stdarg.h>
  9. #include <thread>
  10. #include "calcwit.h"
  11. #include "utils.h"
  12. Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
  13. circuit = aCircuit;
  14. #ifdef SANITY_CHECK
  15. signalAssigned = new bool[circuit->NSignals];
  16. signalAssigned[0] = true;
  17. #endif
  18. mutexes = new std::mutex[NMUTEXES];
  19. cvs = new std::condition_variable[NMUTEXES];
  20. inputSignalsToTrigger = new int[circuit->NComponents];
  21. signalValues = new FrElement[circuit->NSignals];
  22. // Set one signal
  23. Fr_copy(&signalValues[0], circuit->constants + 1);
  24. reset();
  25. }
  26. Circom_CalcWit::~Circom_CalcWit() {
  27. #ifdef SANITY_CHECK
  28. delete signalAssigned;
  29. #endif
  30. delete[] cvs;
  31. delete[] mutexes;
  32. delete[] signalValues;
  33. delete[] inputSignalsToTrigger;
  34. }
  35. void Circom_CalcWit::syncPrintf(const char *format, ...) {
  36. va_list args;
  37. va_start(args, format);
  38. printf_mutex.lock();
  39. vprintf(format, args);
  40. printf_mutex.unlock();
  41. va_end(args);
  42. }
  43. void Circom_CalcWit::reset() {
  44. #ifdef SANITY_CHECK
  45. for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
  46. #endif
  47. for (int i=0; i<circuit->NComponents; i++) {
  48. inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
  49. }
  50. for (int i=0; i<circuit->NComponents; i++) {
  51. if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
  52. }
  53. }
  54. int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
  55. int hIdx;
  56. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  57. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  58. }
  59. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  60. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  61. throw std::runtime_error("invalid type");
  62. }
  63. return circuit->components[cIdx].entries[entryPos].offset;
  64. }
  65. Circom_Sizes Circom_CalcWit::getSubComponentSizes(int cIdx, u64 hash) {
  66. int hIdx;
  67. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  68. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  69. }
  70. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  71. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  72. throw std::runtime_error("invalid type");
  73. }
  74. return circuit->components[cIdx].entries[entryPos].sizes;
  75. }
  76. int Circom_CalcWit::getSignalOffset(int cIdx, u64 hash) {
  77. int hIdx;
  78. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  79. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  80. }
  81. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  82. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  83. throw std::runtime_error("invalid type");
  84. }
  85. return circuit->components[cIdx].entries[entryPos].offset;
  86. }
  87. Circom_Sizes Circom_CalcWit::getSignalSizes(int cIdx, u64 hash) {
  88. int hIdx;
  89. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  90. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  91. }
  92. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  93. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  94. throw std::runtime_error("invalid type");
  95. }
  96. return circuit->components[cIdx].entries[entryPos].sizes;
  97. }
  98. void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  99. // syncPrintf("getSignal: %d\n", sIdx);
  100. if ((circuit->components[cIdx].newThread)&&(currentComponentIdx != cIdx)) {
  101. std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  102. while (inputSignalsToTrigger[cIdx] != -1) {
  103. cvs[cIdx % NMUTEXES].wait(lk);
  104. }
  105. // cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
  106. lk.unlock();
  107. }
  108. #ifdef SANITY_CHECK
  109. if (signalAssigned[sIdx] == false) {
  110. fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
  111. assert(false);
  112. }
  113. #endif
  114. Fr_copy(value, signalValues + sIdx);
  115. /*
  116. char *valueStr = mpz_get_str(0, 10, *value);
  117. syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  118. free(valueStr);
  119. */
  120. }
  121. void Circom_CalcWit::finished(int cIdx) {
  122. {
  123. std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  124. inputSignalsToTrigger[cIdx] = -1;
  125. }
  126. // syncPrintf("Finished: %d\n", cIdx);
  127. cvs[cIdx % NMUTEXES].notify_all();
  128. }
  129. void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  130. // syncPrintf("setSignal: %d\n", sIdx);
  131. #ifdef SANITY_CHECK
  132. if (signalAssigned[sIdx] == true) {
  133. fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
  134. assert(false);
  135. }
  136. signalAssigned[sIdx] = true;
  137. #endif
  138. // Log assignement
  139. /*
  140. char *valueStr = mpz_get_str(0, 10, *value);
  141. syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  142. free(valueStr);
  143. */
  144. Fr_copy(signalValues + sIdx, value);
  145. if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
  146. if (inputSignalsToTrigger[cIdx]>0) {
  147. inputSignalsToTrigger[cIdx]--;
  148. if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
  149. } else {
  150. fprintf(stderr, "Input signals does not match with map: %d\n", sIdx);
  151. assert(false);
  152. }
  153. }
  154. }
  155. void Circom_CalcWit::checkConstraint(int currentComponentIdx, PFrElement value1, PFrElement value2, char const *err) {
  156. #ifdef SANITY_CHECK
  157. FrElement tmp;
  158. Fr_eq(&tmp, value1, value2);
  159. if (!Fr_isTrue(&tmp)) {
  160. char *pcV1 = Fr_element2str(value1);
  161. char *pcV2 = Fr_element2str(value2);
  162. // throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
  163. fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
  164. free(pcV1);
  165. free(pcV2);
  166. assert(false);
  167. }
  168. #endif
  169. }
  170. void Circom_CalcWit::triggerComponent(int newCIdx) {
  171. //int oldCIdx = cIdx;
  172. // cIdx = newCIdx;
  173. if (circuit->components[newCIdx].newThread) {
  174. // syncPrintf("Triggered: %d\n", newCIdx);
  175. std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
  176. // t.join();
  177. t.detach();
  178. } else {
  179. (*(circuit->components[newCIdx].fn))(this, newCIdx);
  180. }
  181. // cIdx = oldCIdx;
  182. }
  183. void Circom_CalcWit::log(PFrElement value) {
  184. char *pcV = Fr_element2str(value);
  185. syncPrintf("Log: %s\n", pcV);
  186. free(pcV);
  187. }
  188. void Circom_CalcWit::join() {
  189. for (int i=0; i<circuit->NComponents; i++) {
  190. std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
  191. while (inputSignalsToTrigger[i] != -1) {
  192. cvs[i % NMUTEXES].wait(lk);
  193. }
  194. // cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
  195. lk.unlock();
  196. // syncPrintf("Joined: %d\n", i);
  197. }
  198. }