You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

197 lines
4.7 KiB

4 years ago
4 years ago
4 years ago
#include <sys/types.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <gmp.h>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

#include "calcwit.h"
#include "circom.h"
#include "utils.h"
  17. #define handle_error(msg) \
  18. do { perror(msg); exit(EXIT_FAILURE); } while (0)
  19. void loadBin(Circom_CalcWit *ctx, std::string filename) {
  20. int fd;
  21. struct stat sb;
  22. // map input
  23. fd = open(filename.c_str(), O_RDONLY);
  24. if (fd == -1)
  25. handle_error("open");
  26. if (fstat(fd, &sb) == -1) /* To obtain file size */
  27. handle_error("fstat");
  28. u8 *in;
  29. in = (u8 *)mmap(NULL, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
  30. if (in == MAP_FAILED)
  31. handle_error("mmap");
  32. close(fd);
  33. BigInt v;
  34. mpz_init2(v, 256);
  35. u8 *p = in;
  36. for (int i=0; i<_circuit.NInputs; i++) {
  37. int len = *(u8 *)p;
  38. p++;
  39. mpz_import(v,len , -1 , 1, 0, 0, p);
  40. p+=len;
  41. ctx->setSignal(0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
  42. }
  43. }
  44. typedef void (*ItFunc)(Circom_CalcWit *ctx, int idx, json val);
  45. void iterateArr(Circom_CalcWit *ctx, int o, Circom_Sizes sizes, json jarr, ItFunc f) {
  46. if (!jarr.is_array()) {
  47. assert((sizes[0] == 1)&&(sizes[1] == 0));
  48. f(ctx, o, jarr);
  49. } else {
  50. int n = sizes[0] / sizes[1];
  51. for (int i=0; i<n; i++) {
  52. iterateArr(ctx, o + i*sizes[1], sizes+1, jarr[i], f);
  53. }
  54. }
  55. }
  56. void itFunc(Circom_CalcWit *ctx, int o, json val) {
  57. BigInt v;
  58. mpz_init2(v, 256);
  59. std::string s;
  60. if (val.is_string()) {
  61. s = val.get<std::string>();
  62. } else if (val.is_number()) {
  63. double vd = val.get<double>();
  64. std::stringstream stream;
  65. stream << std::fixed << std::setprecision(0) << vd;
  66. s = stream.str();
  67. } else {
  68. handle_error("Invalid JSON type");
  69. }
  70. mpz_set_str (v, s.c_str(), 10);
  71. ctx->setSignal(0, o, &v);
  72. }
  73. void loadJson(Circom_CalcWit *ctx, std::string filename) {
  74. std::ifstream inStream(filename);
  75. json j;
  76. inStream >> j;
  77. for (json::iterator it = j.begin(); it != j.end(); ++it) {
  78. // std::cout << it.key() << " => " << it.value() << '\n';
  79. u64 h = fnv1a(it.key());
  80. int o = ctx->getSignalOffset(0, h);
  81. Circom_Sizes sizes = ctx->getSignalSizes(0, h);
  82. iterateArr(ctx, o, sizes, it.value(), itFunc);
  83. }
  84. }
  85. void writeOutBin(Circom_CalcWit *ctx, std::string filename) {
  86. FILE *write_ptr;
  87. write_ptr = fopen(filename.c_str(),"wb");
  88. BigInt v;
  89. mpz_init2(v, 256);
  90. u8 buffOut[256];
  91. for (int i=0;i<_circuit.NVars;i++) {
  92. size_t size=256;
  93. ctx->getWitness(i, &v);
  94. mpz_export(buffOut+1, &size, -1, 1, -1, 0, v);
  95. *buffOut = (u8)size;
  96. fwrite(buffOut, size+1, 1, write_ptr);
  97. }
  98. fclose(write_ptr);
  99. }
  100. void writeOutJson(Circom_CalcWit *ctx, std::string filename) {
  101. std::ofstream outFile;
  102. outFile.open (filename);
  103. outFile << "[\n";
  104. BigInt v;
  105. mpz_init2(v, 256);
  106. char pcV[256];
  107. for (int i=0;i<_circuit.NVars;i++) {
  108. ctx->getWitness(i, &v);
  109. mpz_get_str(pcV, 10, v);
  110. std::string sV = std::string(pcV);
  111. outFile << (i ? "," : " ") << "\"" << sV << "\"\n";
  112. }
  113. outFile << "]\n";
  114. outFile.close();
  115. }
  116. bool hasEnding (std::string const &fullString, std::string const &ending) {
  117. if (fullString.length() >= ending.length()) {
  118. return (0 == fullString.compare (fullString.length() - ending.length(), ending.length(), ending));
  119. } else {
  120. return false;
  121. }
  122. }
  123. int main(int argc, char *argv[]) {
  124. if (argc!=3) {
  125. std::string cl = argv[0];
  126. std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1);
  127. std::cout << "Usage: " << base_filename << " <input.<bin|json>> <output.<bin|json>>\n";
  128. } else {
  129. // open output
  130. Circom_CalcWit *ctx = new Circom_CalcWit(&_circuit);
  131. std::string infilename = argv[1];
  132. if (hasEnding(infilename, std::string(".bin"))) {
  133. loadBin(ctx, infilename);
  134. } else if (hasEnding(infilename, std::string(".json"))) {
  135. loadJson(ctx, infilename);
  136. } else {
  137. handle_error("Invalid input extension (.bin / .json)");
  138. }
  139. std::string outfilename = argv[2];
  140. if (hasEnding(outfilename, std::string(".bin"))) {
  141. writeOutBin(ctx, outfilename);
  142. } else if (hasEnding(outfilename, std::string(".json"))) {
  143. writeOutJson(ctx, outfilename);
  144. } else {
  145. handle_error("Invalid output extension (.bin / .json)");
  146. }
  147. delete ctx;
  148. exit(EXIT_SUCCESS);
  149. }
  150. }