You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

201 lines
5.9 KiB

  1. const chai = require("chai");
  2. const assert = chai.assert;
  3. const fs = require("fs");
  4. var tmp = require("tmp-promise");
  5. const path = require("path");
  6. const util = require("util");
  7. const exec = util.promisify(require("child_process").exec);
  8. const loadR1cs = require("r1csfile").load;
  9. const ZqField = require("ffjavascript").ZqField;
  10. module.exports = wasm_tester;
  /**
   * Compile a circom circuit to WebAssembly inside a fresh temporary directory
   * and return a WasmTester bound to the compiled artifacts.
   *
   * @param {string} circomInput - path to the `.circom` source file.
   * @param {object} [_options] - extra compiler options (e.g. `json`, `O`).
   *   `wasm`, `sym`, `r1cs` and `output` are always overridden; `json`
   *   defaults to false.
   * @returns {Promise<WasmTester>} tester that owns the temporary directory;
   *   call its `release()` method to remove the directory.
   * @throws if the installed circom is older than 2.0.0, if compilation fails,
   *   or if the generated wasm cannot be loaded. On failure the temporary
   *   directory is removed before the error is rethrown.
   */
  async function wasm_tester(circomInput, _options) {
    assert(await compiler_above_version("2.0.0"), "Wrong compiler version. Must be at least 2.0.0");

    tmp.setGracefulCleanup();
    const dir = await tmp.dir({ prefix: "circom_", unsafeCleanup: true });

    const baseName = path.basename(circomInput, ".circom");
    const options = Object.assign({}, _options);
    options.wasm = true;
    options.sym = true;
    options.json = options.json || false; // constraints in json format
    options.r1cs = true;
    options.output = dir.path;

    try {
      await compile(circomInput, options);

      const WitnessCalculator = require("./witness_calculator");
      const wasmPath = path.join(dir.path, baseName + "_js", baseName + ".wasm");
      const wasm = await fs.promises.readFile(wasmPath);
      const wc = await WitnessCalculator(wasm);
      return new WasmTester(dir, baseName, wc);
    } catch (err) {
      // The caller never receives a tester on failure and so cannot call
      // release(); remove the directory here instead of waiting for exit.
      try {
        await dir.cleanup();
      } catch (cleanupErr) {
        // Best effort only: the compile/load failure is the error to report.
      }
      throw err;
    }
  }
  /**
   * Build the circom command-line flags selected by `options`.
   *
   * @param {object} options - `sym`, `r1cs`, `json` (truthy to enable),
   *   `output` (output directory) and `O` (0 or 1 selects --O0 / --O1).
   * @returns {string} the flags, each followed by a single space; `--wasm` is
   *   always included.
   */
  function compileFlags(options) {
    let flags = "--wasm ";
    if (options.sym) flags += "--sym ";
    if (options.r1cs) flags += "--r1cs ";
    if (options.json) flags += "--json ";
    if (options.output) flags += "--output " + options.output + " ";
    if (options.O === 0) flags += "--O0 ";
    if (options.O === 1) flags += "--O1 ";
    return flags;
  }

  /**
   * Run the circom compiler on `fileName` with the flags derived from `options`.
   *
   * @param {string} fileName - path to the `.circom` source.
   * @param {object} [options] - compiler options; see compileFlags.
   * @throws if the command fails (the promisified exec rejects) or if circom
   *   writes anything to stderr.
   */
  async function compile(fileName, options = {}) {
    // NOTE(review): the command is built as a shell string, so paths that
    // contain spaces or shell metacharacters are not quoted.
    const result = await exec("circom " + compileFlags(options) + fileName);
    assert(result.stderr === "", "circom compiler error \n" + result.stderr);
  }
  /**
   * Handle on a compiled circuit: computes witnesses with the wasm witness
   * calculator and checks them against the `.sym` and `.r1cs` files found in
   * the compiler output directory.
   */
  class WasmTester {
    /**
     * @param {object} dir - compiler output directory; `dir.path` is read and
     *   `dir.cleanup()` is called by release().
     * @param {string} baseName - circuit file name without the `.circom` suffix.
     * @param {object} witnessCalculator - object exposing
     *   `calculateWitness(input, sanityCheck)`.
     */
    constructor(dir, baseName, witnessCalculator) {
      this.dir = dir;
      this.baseName = baseName;
      this.witnessCalculator = witnessCalculator;
    }

    /** Remove the compiler output directory via `dir.cleanup()`. */
    async release() {
      await this.dir.cleanup();
    }

    /**
     * Compute the witness for `input` using the witness calculator.
     *
     * @param {object} input - circuit input signals.
     * @param {boolean} [sanityCheck] - forwarded unchanged to the calculator.
     * @returns {Promise<*>} whatever the calculator returns.
     */
    async calculateWitness(input, sanityCheck) {
      return await this.witnessCalculator.calculateWitness(input, sanityCheck);
    }

    /**
     * Parse `<baseName>.sym` into `this.symbols`, keyed by the 4th
     * comma-separated column (the signal name). Lines that do not have
     * exactly four columns are skipped. No-op once symbols are loaded.
     */
    async loadSymbols() {
      if (this.symbols) return;
      const symsStr = await fs.promises.readFile(
        path.join(this.dir.path, this.baseName + ".sym"),
        "utf8"
      );
      // Build into a local and publish only after parsing succeeds, so a failed
      // read never leaves an empty table that later calls treat as loaded.
      const symbols = {};
      for (const line of symsStr.split("\n")) {
        const arr = line.split(",");
        if (arr.length !== 4) continue;
        symbols[arr[3]] = {
          labelIdx: Number(arr[0]),
          varIdx: Number(arr[1]),
          componentIdx: Number(arr[2]),
        };
      }
      this.symbols = symbols;
    }

    /**
     * Load `<baseName>.r1cs` and store its field (`this.F`), variable count
     * (`this.nVars`) and constraints (`this.constraints`). No-op once loaded.
     */
    async loadConstraints() {
      if (this.constraints) return;
      // NOTE(review): the `true, false` flags are passed straight to r1csfile's
      // load(); confirm their meaning against that library.
      const r1cs = await loadR1cs(path.join(this.dir.path, this.baseName + ".r1cs"), true, false);
      this.F = new ZqField(r1cs.prime);
      this.nVars = r1cs.nVars;
      this.constraints = r1cs.constraints;
    }

    /**
     * Assert that the witness `actualOut` holds the values in `expectedOut`.
     *
     * `expectedOut` is walked recursively starting from the name "main":
     * arrays append `[i]`, plain objects append `.key`, and any other value is
     * compared, as a string, with `actualOut[symbols[name].varIdx]`.
     *
     * @param {Array} actualOut - witness indexed by the `.sym` varIdx column.
     * @param {*} expectedOut - expected values, nested like the circuit signals.
     * @throws if a signal name is not in the symbol table or a value differs.
     */
    async assertOut(actualOut, expectedOut) {
      const self = this;
      if (!self.symbols) await self.loadSymbols();

      checkObject("main", expectedOut);

      function checkObject(prefix, eOut) {
        if (Array.isArray(eOut)) {
          for (let i = 0; i < eOut.length; i++) {
            checkObject(prefix + "[" + i + "]", eOut[i]);
          }
        } else if ((typeof eOut === "object") && (eOut.constructor.name === "Object")) {
          for (let k in eOut) {
            checkObject(prefix + "." + k, eOut[k]);
          }
        } else {
          if (typeof self.symbols[prefix] === "undefined") {
            assert(false, "Output variable not defined: " + prefix);
          }
          const ba = actualOut[self.symbols[prefix].varIdx].toString();
          const be = eOut.toString();
          assert.strictEqual(ba, be, prefix);
        }
      }
    }

    /**
     * Render every known signal as a `name --> value` line.
     *
     * @param {Array} witness - witness indexed by the `.sym` varIdx column.
     * @returns {Promise<string>} newline-joined lines; a value is shown as
     *   "undefined" when utils.isDefined rejects the witness entry.
     */
    async getDecoratedOutput(witness) {
      // No module-level `utils` binding exists in this file, so load the
      // helper module here rather than referencing an out-of-scope name.
      const utils = require("./utils");
      const self = this;
      const lines = [];
      if (!self.symbols) await self.loadSymbols();
      for (let n in self.symbols) {
        let v;
        if (utils.isDefined(witness[self.symbols[n].varIdx])) {
          v = witness[self.symbols[n].varIdx].toString();
        } else {
          v = "undefined";
        }
        lines.push(`${n} --> ${v}`);
      }
      return lines.join("\n");
    }

    /**
     * Assert that `witness` satisfies every constraint, i.e. that
     * A(w) * B(w) - C(w) is zero in `this.F` for each [A, B, C].
     *
     * Each linear combination is an object whose keys index `witness` and
     * whose values are field coefficients.
     *
     * @param {Array} witness - full witness vector.
     * @throws if any constraint does not hold.
     */
    async checkConstraints(witness) {
      const self = this;
      if (!self.constraints) await self.loadConstraints();
      for (let i = 0; i < self.constraints.length; i++) {
        checkConstraint(self.constraints[i]);
      }

      function checkConstraint(constraint) {
        const F = self.F;
        const a = evalLC(constraint[0]);
        const b = evalLC(constraint[1]);
        const c = evalLC(constraint[2]);
        assert(F.isZero(F.sub(F.mul(a, b), c)), "Constraint doesn't match");
      }

      // Sum of coefficient * witness value over the combination's entries.
      function evalLC(lc) {
        const F = self.F;
        let v = F.zero;
        for (let w in lc) {
          v = F.add(
            v,
            F.mul(lc[w], witness[w])
          );
        }
        return v;
      }
    }
  }
  /**
   * Turn a dotted version string into an array of its numeric parts.
   *
   * @param {string} v - version text such as "2.0.0".
   * @returns {number[]} one base-10 integer per dot-separated part; a part
   *   that does not start with a digit becomes NaN.
   */
  function version_to_list(v) {
    const components = [];
    for (const part of v.split(".")) {
      components.push(parseInt(part, 10));
    }
    return components;
  }
  /**
   * Check whether version `v1` is newer than or equal to version `v2`.
   *
   * Components are compared left to right. A component missing from `v1`
   * counts as 0, so [2] is treated as [2, 0, 0]. Components of `v1` beyond the
   * length of `v2` are not examined.
   *
   * @param {number[]} v1 - candidate version, e.g. [2, 1, 5].
   * @param {number[]} v2 - required minimum version, e.g. [2, 0, 0].
   * @returns {boolean} true when v1 >= v2.
   */
  function check_versions(v1, v2) {
    for (let i = 0; i < v2.length; i++) {
      // Without the 0 default, `undefined` compares false both ways and a
      // shorter v1 such as [2] would satisfy a [2, 1, 0] requirement.
      const a = i < v1.length ? v1[i] : 0;
      const b = v2[i];
      if (a > b) return true;
      if (a < b) return false;
    }
    return true;
  }
  /**
   * Extract the first dotted version number (e.g. "2.1.5") from text such as
   * the output of `circom --version`.
   *
   * @param {string} output - text to search.
   * @returns {string|null} the version text, or null if no digits are present.
   */
  function extract_version_string(output) {
    const match = /\d+(?:\.\d+)*/.exec(output);
    return match === null ? null : match[0];
  }

  /**
   * Check whether the installed circom compiler is at least version `v`.
   *
   * @param {string} v - required minimum version, e.g. "2.0.0".
   * @returns {Promise<boolean>} true when `circom --version` reports a version
   *   >= v; false when no version number can be found in its output.
   * @throws if the `circom --version` command fails (the promisified exec rejects).
   */
  async function compiler_above_version(v) {
    // Await the promisified exec before reading stdout: calling .toString() on
    // the pending Promise yields "[object Promise]", which has no version in it.
    const { stdout } = await exec("circom --version");
    const found = extract_version_string(stdout);
    if (found === null) return false;
    const compiler_version = version_to_list(found);
    const vlist = version_to_list(v);
    return check_versions(compiler_version, vlist);
  }