You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

228 lines
7.0 KiB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
  1. #include <string>
  2. #include <stdexcept>
  3. #include <sstream>
  4. #include <iostream>
  5. #include <iomanip>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include <thread>
  9. #include "calcwit.h"
  10. #include "utils.h"
  11. Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
  12. circuit = aCircuit;
  13. #ifdef SANITY_CHECK
  14. signalAssigned = new bool[circuit->NSignals];
  15. signalAssigned[0] = true;
  16. #endif
  17. mutexes = new std::mutex[NMUTEXES];
  18. cvs = new std::condition_variable[NMUTEXES];
  19. inputSignalsToTrigger = new int[circuit->NComponents];
  20. signalValues = new FrElement[circuit->NSignals];
  21. // Set one signal
  22. Fr_copy(&signalValues[0], circuit->constants + 1);
  23. reset();
  24. }
  25. Circom_CalcWit::~Circom_CalcWit() {
  26. #ifdef SANITY_CHECK
  27. delete signalAssigned;
  28. #endif
  29. delete[] cvs;
  30. delete[] mutexes;
  31. delete[] signalValues;
  32. delete[] inputSignalsToTrigger;
  33. }
  34. void Circom_CalcWit::syncPrintf(const char *format, ...) {
  35. va_list args;
  36. va_start(args, format);
  37. printf_mutex.lock();
  38. vprintf(format, args);
  39. printf_mutex.unlock();
  40. va_end(args);
  41. }
  42. void Circom_CalcWit::reset() {
  43. #ifdef SANITY_CHECK
  44. for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
  45. #endif
  46. for (int i=0; i<circuit->NComponents; i++) {
  47. inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
  48. if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
  49. }
  50. }
  51. int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
  52. int hIdx;
  53. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  54. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  55. }
  56. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  57. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  58. throw std::runtime_error("invalid type");
  59. }
  60. return circuit->components[cIdx].entries[entryPos].offset;
  61. }
  62. Circom_Sizes Circom_CalcWit::getSubComponentSizes(int cIdx, u64 hash) {
  63. int hIdx;
  64. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  65. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  66. }
  67. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  68. if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
  69. throw std::runtime_error("invalid type");
  70. }
  71. return circuit->components[cIdx].entries[entryPos].sizes;
  72. }
  73. int Circom_CalcWit::getSignalOffset(int cIdx, u64 hash) {
  74. int hIdx;
  75. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  76. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  77. }
  78. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  79. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  80. throw std::runtime_error("invalid type");
  81. }
  82. return circuit->components[cIdx].entries[entryPos].offset;
  83. }
  84. Circom_Sizes Circom_CalcWit::getSignalSizes(int cIdx, u64 hash) {
  85. int hIdx;
  86. for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
  87. if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
  88. }
  89. int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
  90. if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
  91. throw std::runtime_error("invalid type");
  92. }
  93. return circuit->components[cIdx].entries[entryPos].sizes;
  94. }
  95. void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  96. // syncPrintf("getSignal: %d\n", sIdx);
  97. if ((circuit->components[cIdx].newThread)&&(currentComponentIdx != cIdx)) {
  98. std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  99. while (inputSignalsToTrigger[cIdx] != -1) {
  100. cvs[cIdx % NMUTEXES].wait(lk);
  101. }
  102. // cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
  103. lk.unlock();
  104. }
  105. #ifdef SANITY_CHECK
  106. if (signalAssigned[sIdx] == false) {
  107. fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
  108. assert(false);
  109. }
  110. #endif
  111. Fr_copy(value, signalValues + sIdx);
  112. /*
  113. char *valueStr = mpz_get_str(0, 10, *value);
  114. syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  115. free(valueStr);
  116. */
  117. }
  118. void Circom_CalcWit::finished(int cIdx) {
  119. {
  120. std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
  121. inputSignalsToTrigger[cIdx] = -1;
  122. }
  123. // syncPrintf("Finished: %d\n", cIdx);
  124. cvs[cIdx % NMUTEXES].notify_all();
  125. }
  126. void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
  127. // syncPrintf("setSignal: %d\n", sIdx);
  128. #ifdef SANITY_CHECK
  129. if (signalAssigned[sIdx] == true) {
  130. fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
  131. assert(false);
  132. }
  133. signalAssigned[sIdx] = true;
  134. #endif
  135. // Log assignement
  136. /*
  137. char *valueStr = mpz_get_str(0, 10, *value);
  138. syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
  139. free(valueStr);
  140. */
  141. Fr_copy(signalValues + sIdx, value);
  142. if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
  143. if (inputSignalsToTrigger[cIdx]>0) {
  144. inputSignalsToTrigger[cIdx]--;
  145. if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
  146. }
  147. }
  148. }
  149. void Circom_CalcWit::checkConstraint(int currentComponentIdx, PFrElement value1, PFrElement value2, char const *err) {
  150. #ifdef SANITY_CHECK
  151. FrElement tmp;
  152. Fr_eq(&tmp, value1, value2);
  153. if (!Fr_isTrue(&tmp)) {
  154. char *pcV1 = Fr_element2str(value1);
  155. char *pcV2 = Fr_element2str(value2);
  156. // throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
  157. fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
  158. free(pcV1);
  159. free(pcV2);
  160. assert(false);
  161. }
  162. #endif
  163. }
  164. void Circom_CalcWit::triggerComponent(int newCIdx) {
  165. //int oldCIdx = cIdx;
  166. // cIdx = newCIdx;
  167. if (circuit->components[newCIdx].newThread) {
  168. // syncPrintf("Triggered: %d\n", newCIdx);
  169. std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
  170. // t.join();
  171. t.detach();
  172. } else {
  173. (*(circuit->components[newCIdx].fn))(this, newCIdx);
  174. }
  175. // cIdx = oldCIdx;
  176. }
  177. void Circom_CalcWit::log(PFrElement value) {
  178. char *pcV = Fr_element2str(value);
  179. syncPrintf("Log: %s\n", pcV);
  180. free(pcV);
  181. }
  182. void Circom_CalcWit::join() {
  183. for (int i=0; i<circuit->NComponents; i++) {
  184. std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
  185. while (inputSignalsToTrigger[i] != -1) {
  186. cvs[i % NMUTEXES].wait(lk);
  187. }
  188. // cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
  189. lk.unlock();
  190. // syncPrintf("Joined: %d\n", i);
  191. }
  192. }