Begining of wasm

This commit is contained in:
Jordi Baylina
2020-03-09 21:16:56 +01:00
parent 6c1a3e7687
commit 8f63d18ff4
70 changed files with 4135 additions and 1353 deletions

View File

@@ -0,0 +1,245 @@
<% function addS1S2() { %>
xor rdx, rdx
mov edx, eax
add edx, ecx
jo add_manageOverflow ; rsi already is the 64bits result
mov [rdi], rdx ; not necessary to adjust so just save and return
ret
add_manageOverflow: ; Do the operation in 64 bits
push rsi
movsx rsi, eax
movsx rdx, ecx
add rsi, rdx
call rawCopyS2L
pop rsi
ret
<% } %>
<% function addL1S2() { %>
add rsi, 8
movsx rdx, ecx
add rdi, 8
cmp rdx, 0
<% const rawAddLabel = global.tmpLabel() %>
jns <%= rawAddLabel %>
neg rdx
call rawSubLS
sub rdi, 8
sub rsi, 8
ret
<%= rawAddLabel %>:
call rawAddLS
sub rdi, 8
sub rsi, 8
ret
<% } %>
<% function addS1L2() { %>
lea rsi, [rdx + 8]
movsx rdx, eax
add rdi, 8
cmp rdx, 0
<% const rawAddLabel = global.tmpLabel() %>
jns <%= rawAddLabel %>
neg rdx
call rawSubLS
sub rdi, 8
sub rsi, 8
ret
<%= rawAddLabel %>:
call rawAddLS
sub rdi, 8
sub rsi, 8
ret
<% } %>
<% function addL1L2() { %>
add rdi, 8
add rsi, 8
add rdx, 8
call rawAddLL
sub rdi, 8
sub rsi, 8
ret
<% } %>
;;;;;;;;;;;;;;;;;;;;;;
; add
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_add:
mov rax, [rsi]
mov rcx, [rdx]
bt rax, 63 ; Check if is short first operand
jc add_l1
bt rcx, 63 ; Check if is short second operand
jc add_s1l2
add_s1s2: ; Both operands are short
<%= addS1S2() %>
add_l1:
bt rcx, 63 ; Check if is short second operand
jc add_l1l2
;;;;;;;;
add_l1s2:
bt rax, 62 ; check if montgomery first
jc add_l1ms2
add_l1ns2:
<%= global.setTypeDest("0x80"); %>
<%= addL1S2(); %>
add_l1ms2:
bt rcx, 62 ; check if montgomery second
jc add_l1ms2m
add_l1ms2n:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_b() %>
<%= addL1L2() %>
add_l1ms2m:
<%= global.setTypeDest("0xC0"); %>
<%= addL1L2() %>
;;;;;;;;
add_s1l2:
bt rcx, 62 ; check if montgomery first
jc add_s1l2m
add_s1l2n:
<%= global.setTypeDest("0x80"); %>
<%= addS1L2(); %>
add_s1l2m:
bt rax, 62 ; check if montgomery second
jc add_s1ml2m
add_s1nl2m:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_a() %>
<%= addL1L2() %>
add_s1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= addL1L2() %>
;;;;
add_l1l2:
bt rax, 62 ; check if montgomery first
jc add_l1ml2
add_l1nl2:
bt rcx, 62 ; check if montgomery second
jc add_l1nl2m
add_l1nl2n:
<%= global.setTypeDest("0x80"); %>
<%= addL1L2() %>
add_l1nl2m:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_a(); %>
<%= addL1L2() %>
add_l1ml2:
bt rcx, 62 ; check if montgomery seconf
jc add_l1ml2m
add_l1ml2n:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_b(); %>
<%= addL1L2() %>
add_l1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= addL1L2() %>
;;;;;;;;;;;;;;;;;;;;;;
; rawAddLL
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of type long
; Params:
; rsi <= Pointer to the long data of element 1
; rdx <= Pointer to the long data of element 2
; rdi <= Pointer to the long data of result
; Modified Registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;
rawAddLL:
; Add component by component with carry
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> rax, [rdx + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
jc rawAddLL_sq ; if overflow, substract q
; Compare with q
<% for (let i=0; i<n64; i++) { %>
<% if (i>0) { %>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
<% } %>
cmp rax, [q + <%= (n64-i-1)*8 %>]
jc rawAddLL_done ; q is bigget so done.
jnz rawAddLL_sq ; q is lower
<% } %>
; If equal substract q
rawAddLL_sq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%=i*8%>], rax
<% } %>
rawAddLL_done:
ret
;;;;;;;;;;;;;;;;;;;;;;
; rawAddLS
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of type long
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Pointer to the long data of element 1
; rdx <= Value to be added
;;;;;;;;;;;;;;;;;;;;;;
rawAddLS:
; Add component by component with carry
add rdx, [rsi]
mov [rdi] ,rdx
<% for (let i=1; i<n64; i++) { %>
mov rdx, 0
adc rdx, [rsi + <%=i*8%>]
mov [rdi + <%=i*8%>], rdx
<% } %>
jc rawAddLS_sq ; if overflow, substract q
; Compare with q
<% for (let i=0; i<n64; i++) { %>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
cmp rax, [q + <%= (n64-i-1)*8 %>]
jc rawAddLS_done ; q is bigget so done.
jnz rawAddLS_sq ; q is lower
<% } %>
; If equal substract q
rawAddLS_sq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%=i*8%>], rax
<% } %>
rawAddLS_done:
ret

View File

@@ -0,0 +1,217 @@
<% function binOpS1S2(op) { %>
cmp r8d, 0
<% const s1s2_solveNeg = global.tmpLabel() %>
js <%=s1s2_solveNeg%>
cmp r9d, 0
js <%=s1s2_solveNeg%>
xor rdx, rdx ; both ops are positive so do the op and return
mov edx, r8d
<%=op%> edx, r9d
mov [rdi], rdx ; not necessary to adjust so just save and return
ret
<%=s1s2_solveNeg%>:
<%= global.setTypeDest("0x80"); %>
<%= global.toLong_b() %>
<%= global.toLong_a() %>
<%= binOpL1L2(op) %>
<% } %>
<% function binOpS1L2(op) { %>
cmp r8d, 0
<% const s1l2_solveNeg = global.tmpLabel() %>
js <%=s1l2_solveNeg%>
movsx rax, r8d
<%=op%> rax, [rdx +8]
mov [rdi+8], rax
<% for (let i=1; i<n64; i++) { %>
xor rax, rax
<%=op%> rax, [rdx + <%= (i*8)+8 %>]
<% if (i== n64-1) { %>
and rax, [lboMask]
<% } %>
mov [rdi + <%= (i*8)+8 %> ], rax
<% } %>
ret
<%=s1l2_solveNeg%>:
<%= global.toLong_a() %>
<%= global.setTypeDest("0x80"); %>
<%= binOpL1L2(op) %>
<% } %>
<% function binOpL1S2(op) { %>
cmp r9d, 0
<% const l1s2_solveNeg = global.tmpLabel() %>
js <%=l1s2_solveNeg%>
movsx rax, r9d
<%=op%> rax, [rsi +8]
mov [rdi+8], rax
<% for (let i=1; i<n64; i++) { %>
xor rax, rax
<%=op%> rax, [rsi + <%= (i*8)+8 %>];
<% if (i== n64-1) { %>
and rax, [lboMask] ;
<% } %>
mov [rdi + <%= (i*8)+8 %> ], rax;
<% } %>
ret
<%=l1s2_solveNeg%>:
<%= global.toLong_b() %>
<%= global.setTypeDest("0x80"); %>
<%= binOpL1L2(op) %>
<% } %>
<% function binOpL1L2(op) { %>
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%= (i*8)+8 %>]
<%=op%> rax, [rdx + <%= (i*8)+8 %>]
<% if (i== n64-1) { %>
and rax, [lboMask]
<% } %>
mov [rdi + <%= (i*8)+8 %> ], rax
<% } %>
ret
<% } %>
<% function binOp(op) { %>
;;;;;;;;;;;;;;;;;;;;;;
; b<%= op %>
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_b<%=op%>:
mov r8, [rsi]
mov r9, [rdx]
bt r8, 63 ; Check if is short first operand
jc <%=op%>_l1
bt r9, 63 ; Check if is short second operand
jc <%=op%>_s1l2
<%=op%>_s1s2:
<%= binOpS1S2(op) %>
<%=op%>_l1:
bt r9, 63 ; Check if is short second operand
jc <%=op%>_l1l2
<%=op%>_l1s2:
bt r8, 62 ; check if montgomery first
jc <%=op%>_l1ms2
<%=op%>_l1ns2:
<%= global.setTypeDest("0x80"); %>
<%= binOpL1S2(op) %>
<%=op%>_l1ms2:
<%= global.setTypeDest("0x80"); %>
push r9 ; r9 is used in montgomery so we need to save it
<%= global.fromMont_a() %>
pop r9
<%= binOpL1S2(op) %>
<%=op%>_s1l2:
bt r9, 62 ; check if montgomery first
jc <%=op%>_s1l2m
<%=op%>_s1l2n:
<%= global.setTypeDest("0x80"); %>
<%= binOpS1L2(op) %>
<%=op%>_s1l2m:
<%= global.setTypeDest("0x80"); %>
push r8 ; r8 is used in montgomery so we need to save it
<%= global.fromMont_b() %>
pop r8
<%= binOpS1L2(op) %>
<%=op%>_l1l2:
bt r8, 62 ; check if montgomery first
jc <%=op%>_l1ml2
bt r9, 62 ; check if montgomery first
jc <%=op%>_l1nl2m
<%=op%>_l1nl2n:
<%= global.setTypeDest("0x80"); %>
<%= binOpL1L2(op) %>
<%=op%>_l1nl2m:
<%= global.setTypeDest("0x80"); %>
<%= global.fromMont_b() %>
<%= binOpL1L2(op) %>
<%=op%>_l1ml2:
bt r9, 62 ; check if montgomery first
jc <%=op%>_l1ml2m
<%=op%>_l1ml2n:
<%= global.setTypeDest("0x80"); %>
<%= global.fromMont_a() %>
<%= binOpL1L2(op) %>
<%=op%>_l1ml2m:
<%= global.setTypeDest("0x80"); %>
<%= global.fromMont_a() %>
<%= global.fromMont_b() %>
<%= binOpL1L2(op) %>
<% } %>
<%= binOp("and") %>
<%= binOp("or") %>
<%= binOp("xor") %>
;;;;;;;;;;;;;;;;;;;;;;
; bnot
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdi <= Pointer to result
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_bnot:
<%= global.setTypeDest("0x80"); %>
mov r8, [rsi]
bt r8, 63 ; Check if is long operand
jc bnot_l1
bnot_s:
<%= global.toLong_a() %>
jmp bnot_l1n
bnot_l1:
bt r8, 62 ; check if montgomery first
jnc bnot_l1n
bnot_l1m:
<%= global.fromMont_a() %>
bnot_l1n:
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%= i*8 + 8 %>]
not rax
<% if (i== n64-1) { %>
and rax, [lboMask]
<% } %>
mov [rdi + <%= i*8 + 8 %>], rax
<% } %>
ret

View File

@@ -0,0 +1,72 @@
const bigInt=require("big-integer");
const path = require("path");
const util = require("util");
const renderFile = util.promisify(require("ejs").renderFile);
const runningAsScript = !module.parent;
class ZqBuilder {
constructor(q, name) {
const self = this;
this.q=bigInt(q);
this.n64 = Math.floor((this.q.bitLength() - 1) / 64)+1;
this.name = name;
this.bigInt = bigInt;
this.lastTmp=0;
this.global = {};
this.global.tmpLabel = function(label) {
self.lastTmp++;
label = label || "tmp";
return label+"_"+self.lastTmp;
};
}
constantElement(v) {
let S = "";
const mask = bigInt("FFFFFFFFFFFFFFFF", 16);
for (let i=0; i<this.n64; i++) {
if (i>0) S = S+",";
let shex = v.shiftRight(i*64).and(mask).toString(16);
while (shex.length <16) shex = "0" + shex;
S = S + "0x" + shex;
}
return S;
}
}
async function buildField(q, name) {
const builder = new ZqBuilder(q, name);
const asm = await renderFile(path.join(__dirname, "fr.asm.ejs"), builder);
const c = await renderFile(path.join(__dirname, "fr.c.ejs"), builder);
const h = await renderFile(path.join(__dirname, "fr.h.ejs"), builder);
return {asm: asm, h: h, c: c};
}
if (runningAsScript) {
const fs = require("fs");
var argv = require("yargs")
.usage("Usage: $0 -q [primeNum] -n [name] -oc [out .c file] -oh [out .h file]")
.demandOption(["q","n"])
.alias("q", "prime")
.alias("n", "name")
.argv;
const q = bigInt(argv.q);
const asmFileName = (argv.oc) ? argv.oc : argv.name.toLowerCase() + ".asm";
const hFileName = (argv.oc) ? argv.oc : argv.name.toLowerCase() + ".h";
const cFileName = (argv.oc) ? argv.oc : argv.name.toLowerCase() + ".c";
buildField(q, argv.name).then( (res) => {
fs.writeFileSync(asmFileName, res.asm, "utf8");
fs.writeFileSync(hFileName, res.h, "utf8");
fs.writeFileSync(cFileName, res.c, "utf8");
});
} else {
module.exports = buildField;
}

View File

@@ -0,0 +1,75 @@
const chai = require("chai");
const assert = chai.assert;
const fs = require("fs");
var tmp = require("tmp-promise");
const path = require("path");
const util = require("util");
const exec = util.promisify(require("child_process").exec);
const BuildZqField = require("./buildzqfield");
module.exports = testField;
async function testField(prime, test) {
tmp.setGracefulCleanup();
const dir = await tmp.dir({prefix: "circom_", unsafeCleanup: true });
const source = await BuildZqField(prime, "Fr");
// console.log(dir.path);
await fs.promises.writeFile(path.join(dir.path, "fr.asm"), source.asm, "utf8");
await fs.promises.writeFile(path.join(dir.path, "fr.h"), source.h, "utf8");
await fs.promises.writeFile(path.join(dir.path, "fr.c"), source.c, "utf8");
await exec(`cp ${path.join(__dirname, "tester.cpp")} ${dir.path}`);
await exec("nasm -fmacho64 --prefix _ " +
` ${path.join(dir.path, "fr.asm")}`
);
await exec("g++" +
` ${path.join(dir.path, "tester.cpp")}` +
` ${path.join(dir.path, "fr.o")}` +
` ${path.join(dir.path, "fr.c")}` +
` -o ${path.join(dir.path, "tester")}` +
" -lgmp -g"
);
const inLines = [];
for (let i=0; i<test.length; i++) {
for (let j=0; j<test[i][0].length; j++) {
inLines.push(test[i][0][j]);
}
}
inLines.push("");
await fs.promises.writeFile(path.join(dir.path, "in.tst"), inLines.join("\n"), "utf8");
await exec(`${path.join(dir.path, "tester")}` +
` <${path.join(dir.path, "in.tst")}` +
` >${path.join(dir.path, "out.tst")}`);
const res = await fs.promises.readFile(path.join(dir.path, "out.tst"), "utf8");
const resLines = res.split("\n");
for (let i=0; i<test.length; i++) {
const expected = test[i][1].toString();
const calculated = resLines[i];
if (calculated != expected) {
console.log("FAILED");
for (let j=0; j<test[i][0].length; j++) {
console.log(test[i][0][j]);
}
console.log("Should Return: " + expected);
console.log("But Returns: " + calculated);
}
assert.equal(calculated, expected);
}
}

View File

@@ -0,0 +1,439 @@
<% function signL(reg, label_pos, label_neg) { %>
<% for (let i=n64-1; i>=0; i--) { %>
mov rax, [<%=reg%> + <%= 8+(i*8) %>]
cmp [half + <%= (i*8) %>], rax ; comare with (q-1)/2
jc <%=label_neg%> ; half<rax => e1-e2 is neg => e1 < e2
<% if (i>0) { %>
jnz <%=label_pos%> ; half>rax => e1 -e2 is pos => e1 > e2
<% } else { %>
jmp <%=label_pos%>
<% } %>
<% } %>
<% } %>
;;;;;;;;;;;;;;;;;;;;;;
; rgt - Raw Greater Than
;;;;;;;;;;;;;;;;;;;;;;
; returns in ax 1 id *rsi > *rdx
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rax <= Return 1 or 0
; Modified Registers:
; r8, r9, rax
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_rgt:
mov r8, [rsi]
mov r9, [rdx]
bt r8, 63 ; Check if is short first operand
jc rgt_l1
bt r9, 63 ; Check if is short second operand
jc rgt_s1l2
rgt_s1s2: ; Both operands are short
cmp r8d, r9d
jg rgt_ret1
jmp rgt_ret0
rgt_l1:
bt r9, 63 ; Check if is short second operand
jc rgt_l1l2
;;;;;;;;
rgt_l1s2:
bt r8, 62 ; check if montgomery first
jc rgt_l1ms2
rgt_l1ns2:
<%= global.toLong_b() %>
jmp rgtL1L2
rgt_l1ms2:
<%= global.toLong_b() %>
<%= global.fromMont_a() %>
jmp rgtL1L2
;;;;;;;;
rgt_s1l2:
bt r9, 62 ; check if montgomery second
jc rgt_s1l2m
rgt_s1l2n:
<%= global.toLong_a() %>
jmp rgtL1L2
rgt_s1l2m:
<%= global.toLong_a() %>
<%= global.fromMont_b() %>
jmp rgtL1L2
;;;;
rgt_l1l2:
bt r8, 62 ; check if montgomery first
jc rgt_l1ml2
rgt_l1nl2:
bt r9, 62 ; check if montgomery second
jc rgt_l1nl2m
rgt_l1nl2n:
jmp rgtL1L2
rgt_l1nl2m:
<%= global.fromMont_b() %>
jmp rgtL1L2
rgt_l1ml2:
bt r9, 62 ; check if montgomery second
jc rgt_l1ml2m
rgt_l1ml2n:
<%= global.fromMont_a() %>
jmp rgtL1L2
rgt_l1ml2m:
<%= global.fromMont_a() %>
<%= global.fromMont_b() %>
jmp rgtL1L2
;;;;;;
; rgtL1L2
;;;;;;
rgtL1L2:
<%= signL("rsi", "rgtl1l2_p1", "rgtl1l2_n1") %>
rgtl1l2_p1:
<%= signL("rdx", "rgtRawL1L2", "rgt_ret1") %>
rgtl1l2_n1:
<%= signL("rdx", "rgt_ret0", "rgtRawL1L2") %>
rgtRawL1L2:
<% for (let i=n64-1; i>=0; i--) { %>
mov rax, [rsi + <%= 8+(i*8) %>]
cmp [rdx + <%= 8+(i*8) %>], rax ; comare with (q-1)/2
jc rgt_ret1 ; rsi<rdx => 1st > 2nd
<% if (i>0) { %>
jnz rgt_ret0
<% } %>
<% } %>
rgt_ret0:
xor rax, rax
ret
rgt_ret1:
mov rax, 1
ret
;;;;;;;;;;;;;;;;;;;;;;
; rlt - Raw Less Than
;;;;;;;;;;;;;;;;;;;;;;
; returns in ax 1 id *rsi > *rdx
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rax <= Return 1 or 0
; Modified Registers:
; r8, r9, rax
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_rlt:
mov r8, [rsi]
mov r9, [rdx]
bt r8, 63 ; Check if is short first operand
jc rlt_l1
bt r9, 63 ; Check if is short second operand
jc rlt_s1l2
rlt_s1s2: ; Both operands are short
cmp r8d, r9d
jl rlt_ret1
jmp rlt_ret0
rlt_l1:
bt r9, 63 ; Check if is short second operand
jc rlt_l1l2
;;;;;;;;
rlt_l1s2:
bt r8, 62 ; check if montgomery first
jc rlt_l1ms2
rlt_l1ns2:
<%= global.toLong_b() %>
jmp rltL1L2
rlt_l1ms2:
<%= global.toLong_b() %>
<%= global.fromMont_a() %>
jmp rltL1L2
;;;;;;;;
rlt_s1l2:
bt r9, 62 ; check if montgomery second
jc rlt_s1l2m
rlt_s1l2n:
<%= global.toLong_a() %>
jmp rltL1L2
rlt_s1l2m:
<%= global.toLong_a() %>
<%= global.fromMont_b() %>
jmp rltL1L2
;;;;
rlt_l1l2:
bt r8, 62 ; check if montgomery first
jc rlt_l1ml2
rlt_l1nl2:
bt r9, 62 ; check if montgomery second
jc rlt_l1nl2m
rlt_l1nl2n:
jmp rltL1L2
rlt_l1nl2m:
<%= global.fromMont_b() %>
jmp rltL1L2
rlt_l1ml2:
bt r9, 62 ; check if montgomery second
jc rlt_l1ml2m
rlt_l1ml2n:
<%= global.fromMont_a() %>
jmp rltL1L2
rlt_l1ml2m:
<%= global.fromMont_a() %>
<%= global.fromMont_b() %>
jmp rltL1L2
;;;;;;
; rltL1L2
;;;;;;
rltL1L2:
<%= signL("rsi", "rltl1l2_p1", "rltl1l2_n1") %>
rltl1l2_p1:
<%= signL("rdx", "rltRawL1L2", "rlt_ret0") %>
rltl1l2_n1:
<%= signL("rdx", "rlt_ret1", "rltRawL1L2") %>
rltRawL1L2:
<% for (let i=n64-1; i>=0; i--) { %>
mov rax, [rsi + <%= 8+(i*8) %>]
cmp [rdx + <%= 8+(i*8) %>], rax ; comare with (q-1)/2
jc rlt_ret0 ; rsi<rdx => 1st > 2nd
jnz rlt_ret1
<% } %>
rlt_ret0:
xor rax, rax
ret
rlt_ret1:
mov rax, 1
ret
;;;;;;;;;;;;;;;;;;;;;;
; req - Raw Eq
;;;;;;;;;;;;;;;;;;;;;;
; returns in ax 1 id *rsi == *rdx
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rax <= Return 1 or 0
; Modified Registers:
; r8, r9, rax
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_req:
mov r8, [rsi]
mov r9, [rdx]
bt r8, 63 ; Check if is short first operand
jc req_l1
bt r9, 63 ; Check if is short second operand
jc req_s1l2
req_s1s2: ; Both operands are short
cmp r8d, r9d
je req_ret1
jmp req_ret0
req_l1:
bt r9, 63 ; Check if is short second operand
jc req_l1l2
;;;;;;;;
req_l1s2:
bt r8, 62 ; check if montgomery first
jc req_l1ms2
req_l1ns2:
<%= global.toLong_b() %>
jmp reqL1L2
req_l1ms2:
<%= global.toMont_b() %>
jmp reqL1L2
;;;;;;;;
req_s1l2:
bt r9, 62 ; check if montgomery second
jc req_s1l2m
req_s1l2n:
<%= global.toLong_a() %>
jmp reqL1L2
req_s1l2m:
<%= global.toMont_a() %>
jmp reqL1L2
;;;;
req_l1l2:
bt r8, 62 ; check if montgomery first
jc req_l1ml2
req_l1nl2:
bt r9, 62 ; check if montgomery second
jc req_l1nl2m
req_l1nl2n:
jmp reqL1L2
req_l1nl2m:
<%= global.toMont_a() %>
jmp reqL1L2
req_l1ml2:
bt r9, 62 ; check if montgomery second
jc req_l1ml2m
req_l1ml2n:
<%= global.toMont_b() %>
jmp reqL1L2
req_l1ml2m:
jmp reqL1L2
;;;;;;
; eqL1L2
;;;;;;
reqL1L2:
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%= 8+(i*8) %>]
cmp [rdx + <%= 8+(i*8) %>], rax
jne req_ret0 ; rsi<rdi => 1st > 2nd
<% } %>
req_ret1:
mov rax, 1
ret
req_ret0:
xor rax, rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; gt
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_gt:
call <%=name%>_rgt
mov [rdi], rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; lt
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_lt:
call <%=name%>_rlt
mov [rdi], rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; eq
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_eq:
call <%=name%>_req
mov [rdi], rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; neq
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_neq:
call <%=name%>_req
xor rax, 1
mov [rdi], rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; geq
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_geq:
call <%=name%>_rlt
xor rax, 1
mov [rdi], rax
ret
;;;;;;;;;;;;;;;;;;;;;;
; leq
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_leq:
call <%=name%>_rgt
xor rax, 1
mov [rdi], rax
ret

View File

@@ -0,0 +1,108 @@
<% function retOne() { %>
mov qword [rdi], 1
add rsp, <%= (n64+1)*8 %>
ret
<% } %>
<% function retZero() { %>
mov qword [rdi], 0
add rsp, <%= (n64+1)*8 %>
ret
<% } %>
<% function cmpLong(op, eq) { %>
<%
if (eq==true) {
if (["leq","geq"].indexOf(op) >= 0) retOne();
if (["lt","gt"].indexOf(op) >= 0) retZero();
}
%>
<% const label_gt = global.tmpLabel() %>
<% const label_lt = global.tmpLabel() %>
<% for (let i=n64-1; i>=0; i--) { %>
mov rax, [rsp + <%= 8+(i*8) %>]
cmp [half + <%= (i*8) %>], rax ; comare with (q-1)/2
jc <%=label_lt%> ; half<rax => e1-e2 is neg => e1 < e2
jnz <%=label_gt%> ; half>rax => e1 -e2 is pos => e1 > e2
<% } %>
; half == rax => e1-e2 is pos => e1 > e2
<%=label_gt%>:
<% if (["geq","gt"].indexOf(op) >= 0) retOne(); else retZero(); %>
<%=label_lt%>:
<% if (["leq","lt"].indexOf(op) >= 0) retOne(); else retZero(); %>
<% } // cmpLong%>
<% function cmpOp(op) { %>
;;;;;;;;;;;;;;;;;;;;;;
; <%= op %>
;;;;;;;;;;;;;;;;;;;;;;
; Compares two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result can be zero or one.
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_<%=op%>:
sub rsp, <%= (n64+1)*8 %> ; Save space for the result of the substraction
push rdi ; Save rdi
lea rdi, [rsp+8] ; We pushed rdi so we need to add 8
call <%=name%>_sub ; Do a substraction
call <%=name%>_toNormal ; Convert it to normal
pop rdi
mov rax, [rsp] ; We already poped do no need to add 8
bt rax, 63 ; check is result is long
jc <%=op%>_longCmp
<%=op%>_shortCmp:
cmp eax, 0
je <%=op%>_s_eq
js <%=op%>_s_lt
<%=op%>_s_gt:
<% if (["geq","gt", "neq"].indexOf(op) >= 0) retOne(); else retZero(); %>
<%=op%>_s_lt:
<% if (["leq","lt", "neq"].indexOf(op) >= 0) retOne(); else retZero(); %>
<%=op%>_s_eq:
<% if (["eq","geq", "leq"].indexOf(op) >= 0) retOne(); else retZero(); %>
<%=op%>_longCmp:
<% for (let i=n64-1; i>=0; i--) { %>
cmp qword [rsp + <%= 8+(i*8) %>], 0
jnz <%=op%>_neq
<% } %>
<%=op%>_eq:
<% if (op == "eq") {
retOne();
} else if (op == "neq") {
retZero();
} else {
cmpLong(op, true);
}
%>
<%=op%>_neq:
<% if (op == "neq") {
retOne();
} else if (op == "eq") {
retZero();
} else {
cmpLong(op, false);
}
%>
<% } %>
<%= cmpOp("eq") %>
<%= cmpOp("neq") %>
<%= cmpOp("lt") %>
<%= cmpOp("gt") %>
<%= cmpOp("leq") %>
<%= cmpOp("geq") %>

View File

@@ -0,0 +1,139 @@
;;;;;;;;;;;;;;;;;;;;;;
; copy
;;;;;;;;;;;;;;;;;;;;;;
; Copies
; Params:
; rsi <= the src
; rdi <= the dest
;
; Nidified registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_copy:
<% for (let i=0; i<=n64; i++) { %>
mov rax, [rsi + <%= i*8 %>]
mov [rdi + <%= i*8 %>], rax
<% } %>
ret
;;;;;;;;;;;;;;;;;;;;;;
; copy an array of integers
;;;;;;;;;;;;;;;;;;;;;;
; Copies
; Params:
; rsi <= the src
; rdi <= the dest
; rdx <= number of integers to copy
;
; Nidified registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_copyn:
<%=name%>_copyn_loop:
mov r8, rsi
mov r9, rdi
mov rax, <%= n64+1 %>
mul rdx
mov rcx, rax
cld
rep movsq
mov rsi, r8
mov rdi, r9
ret
;;;;;;;;;;;;;;;;;;;;;;
; rawCopyS2L
;;;;;;;;;;;;;;;;;;;;;;
; Convert a 64 bit integer to a long format field element
; Params:
; rsi <= the integer
; rdi <= Pointer to the overwritted element
;
; Nidified registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;;
rawCopyS2L:
mov al, 0x80
shl rax, 56
mov [rdi], rax ; set the result to LONG normal
cmp rsi, 0
js u64toLong_adjust_neg
mov [rdi + 8], rsi
xor rax, rax
<% for (let i=1; i<n64; i++) { %>
mov [rdi + <%= 8+i*8 %>], rax
<% } %>
ret
u64toLong_adjust_neg:
add rsi, [q] ; Set the first digit
mov [rdi + 8], rsi ;
mov rsi, -1 ; all ones
<% for (let i=1; i<n64; i++) { %>
mov rax, rsi ; Add to q
adc rax, [q + <%= i*8 %> ]
mov [rdi + <%= (i+1)*8 %>], rax
<% } %>
ret
;;;;;;;;;;;;;;;;;;;;;;
; toInt
;;;;;;;;;;;;;;;;;;;;;;
; Convert a 64 bit integer to a long format field element
; Params:
; rsi <= Pointer to the element
; Returs:
; rax <= The value
;;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_toInt:
mov rax, [rdi]
bt rax, 63
jc <%=name%>_long
movsx rax, eax
ret
<%=name%>_long:
bt rax, 62
jnc <%=name%>_longNormal
<%=name%>_longMontgomery:
call <%=name%>_toLongNormal
<%=name%>_longNormal:
mov rax, [rdi + 8]
mov rcx, rax
shr rcx, 31
jnz <%=name%>_longNeg
<% for (let i=1; i< n64; i++) { %>
mov rcx, [rdi + <%= i*8+8 %>]
test rcx, rcx
jnz <%=name%>_longNeg
<% } %>
ret
<%=name%>_longNeg:
mov rax, [rdi + 8]
sub rax, [q]
jnc <%=name%>_longErr
<% for (let i=1; i<n64; i++) { %>
mov rcx, [rdi + <%= i*8+8 %>]
sbb rcx, [q + <%= i*8 %>]
jnc <%=name%>_longErr
<% } %>
mov rcx, rax
sar rcx, 31
add rcx, 1
jnz <%=name%>_longErr
ret
<%=name%>_longErr:
push rdi
mov rdi, 0
call <%=name%>_fail
pop rdi

5528
ports/c/buildasm/fr.asm Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
global <%=name%>_copy
global <%=name%>_copyn
global <%=name%>_add
global <%=name%>_sub
global <%=name%>_neg
global <%=name%>_mul
global <%=name%>_square
global <%=name%>_band
global <%=name%>_bor
global <%=name%>_bxor
global <%=name%>_bnot
global <%=name%>_eq
global <%=name%>_neq
global <%=name%>_lt
global <%=name%>_gt
global <%=name%>_leq
global <%=name%>_geq
global <%=name%>_land
global <%=name%>_lor
global <%=name%>_lnot
global <%=name%>_toNormal
global <%=name%>_toLongNormal
global <%=name%>_toMontgomery
global <%=name%>_toInt
global <%=name%>_isTrue
global <%=name%>_q
extern <%=name%>_fail
DEFAULT REL
section .text
<%- include('utils.asm.ejs'); %>
<%- include('copy.asm.ejs'); %>
<%- include('montgomery.asm.ejs'); %>
<%- include('add.asm.ejs'); %>
<%- include('sub.asm.ejs'); %>
<%- include('neg.asm.ejs'); %>
<%- include('mul.asm.ejs'); %>
<%- include('binops.asm.ejs'); %>
<%- include('cmpops.asm.ejs'); %>
<%- include('logicalops.asm.ejs'); %>
section .data
<%=name%>_q:
dd 0
dd 0x80000000
q dq <%= constantElement(q) %>
half dq <%= constantElement(q.shiftRight(1)) %>
R2 dq <%= constantElement(bigInt.one.shiftLeft(n64*64*2).mod(q)) %>
R3 dq <%= constantElement(bigInt.one.shiftLeft(n64*64*3).mod(q)) %>
lboMask dq 0x<%= bigInt("8000000000000000",16).shiftRight(n64*64 - q.bitLength()).minus(bigInt.one).toString(16) %>

184
ports/c/buildasm/fr.c Normal file
View File

@@ -0,0 +1,184 @@
#include "fr.h"
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
#include <assert.h>
mpz_t q;
mpz_t zero;
mpz_t one;
mpz_t mask;
size_t nBits;
void Fr_toMpz(mpz_t r, PFrElement pE) {
Fr_toNormal(pE);
if (!(pE->type & Fr_LONG)) {
mpz_set_si(r, pE->shortVal);
if (pE->shortVal<0) {
mpz_add(r, r, q);
}
} else {
mpz_import(r, Fr_N64, -1, 8, -1, 0, (const void *)pE->longVal);
}
}
void Fr_fromMpz(PFrElement pE, mpz_t v) {
if (mpz_fits_sint_p(v)) {
pE->type = Fr_SHORT;
pE->shortVal = mpz_get_si(v);
} else {
pE->type = Fr_LONG;
for (int i=0; i<Fr_N64; i++) pE->longVal[i] = 0;
mpz_export((void *)(pE->longVal), NULL, -1, 8, -1, 0, v);
}
}
void Fr_init() {
mpz_init(q);
mpz_import(q, Fr_N64, -1, 8, -1, 0, (const void *)Fr_q.longVal);
mpz_init_set_ui(zero, 0);
mpz_init_set_ui(one, 1);
nBits = mpz_sizeinbase (q, 2);
mpz_init(mask);
mpz_mul_2exp(mask, one, nBits-1);
mpz_sub(mask, mask, one);
}
void Fr_str2element(PFrElement pE, char const *s) {
mpz_t mr;
mpz_init_set_str(mr, s, 10);
Fr_fromMpz(pE, mr);
}
char *Fr_element2str(PFrElement pE) {
mpz_t r;
if (!(pE->type & Fr_LONG)) {
if (pE->shortVal>=0) {
char *r = new char[32];
sprintf(r, "%d", pE->shortVal);
return r;
} else {
mpz_init_set_si(r, pE->shortVal);
mpz_add(r, r, q);
}
} else {
Fr_toNormal(pE);
mpz_init(r);
mpz_import(r, Fr_N64, -1, 8, -1, 0, (const void *)pE->longVal);
}
char *res = mpz_get_str (0, 10, r);
mpz_clear(r);
return res;
}
void Fr_idiv(PFrElement r, PFrElement a, PFrElement b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
Fr_toMpz(ma, a);
// char *s1 = mpz_get_str (0, 10, ma);
// printf("s1 %s\n", s1);
Fr_toMpz(mb, b);
// char *s2 = mpz_get_str (0, 10, mb);
// printf("s2 %s\n", s2);
mpz_fdiv_q(mr, ma, mb);
// char *sr = mpz_get_str (0, 10, mr);
// printf("r %s\n", sr);
Fr_fromMpz(r, mr);
}
void Fr_mod(PFrElement r, PFrElement a, PFrElement b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
Fr_toMpz(ma, a);
Fr_toMpz(mb, b);
mpz_fdiv_r(mr, ma, mb);
Fr_fromMpz(r, mr);
}
void Fr_shl(PFrElement r, PFrElement a, PFrElement b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
Fr_toMpz(ma, a);
Fr_toMpz(mb, b);
if (mpz_cmp_ui(mb, nBits) >= 0) {
mpz_set(mr, zero);
} else {
mpz_mul_2exp(mr, ma, mpz_get_ui(mb));
mpz_and(mr, mr, mask);
}
Fr_fromMpz(r, mr);
}
void Fr_shr(PFrElement r, PFrElement a, PFrElement b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
Fr_toMpz(ma, a);
Fr_toMpz(mb, b);
if (mpz_cmp_ui(mb, nBits) >= 0) {
mpz_set(mr, zero);
} else {
mpz_tdiv_q_2exp(mr, ma, mpz_get_ui(mb));
mpz_and(mr, mr, mask);
}
Fr_fromMpz(r, mr);
}
void Fr_pow(PFrElement r, PFrElement a, PFrElement b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
Fr_toMpz(ma, a);
Fr_toMpz(mb, b);
mpz_powm(mr, ma, mb, q);
Fr_fromMpz(r, mr);
}
void Fr_inv(PFrElement r, PFrElement a) {
mpz_t ma;
mpz_t mr;
mpz_init(ma);
mpz_init(mr);
Fr_toMpz(ma, a);
mpz_invert(mr, ma, q);
Fr_fromMpz(r, mr);
}
void Fr_div(PFrElement r, PFrElement a, PFrElement b) {
FrElement tmp;
Fr_inv(&tmp, b);
Fr_mul(r, a, &tmp);
}
void Fr_fail() {
assert(false);
}

184
ports/c/buildasm/fr.c.ejs Normal file
View File

@@ -0,0 +1,184 @@
#include "<%=name.toLowerCase()+".h"%>"
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
#include <assert.h>
mpz_t q;
mpz_t zero;
mpz_t one;
mpz_t mask;
size_t nBits;
void <%=name%>_toMpz(mpz_t r, P<%=name%>Element pE) {
<%=name%>_toNormal(pE);
if (!(pE->type & <%=name%>_LONG)) {
mpz_set_si(r, pE->shortVal);
if (pE->shortVal<0) {
mpz_add(r, r, q);
}
} else {
mpz_import(r, <%=name%>_N64, -1, 8, -1, 0, (const void *)pE->longVal);
}
}
void <%=name%>_fromMpz(P<%=name%>Element pE, mpz_t v) {
if (mpz_fits_sint_p(v)) {
pE->type = <%=name%>_SHORT;
pE->shortVal = mpz_get_si(v);
} else {
pE->type = <%=name%>_LONG;
for (int i=0; i<<%=name%>_N64; i++) pE->longVal[i] = 0;
mpz_export((void *)(pE->longVal), NULL, -1, 8, -1, 0, v);
}
}
void <%=name%>_init() {
mpz_init(q);
mpz_import(q, <%=name%>_N64, -1, 8, -1, 0, (const void *)Fr_q.longVal);
mpz_init_set_ui(zero, 0);
mpz_init_set_ui(one, 1);
nBits = mpz_sizeinbase (q, 2);
mpz_init(mask);
mpz_mul_2exp(mask, one, nBits-1);
mpz_sub(mask, mask, one);
}
void <%=name%>_str2element(P<%=name%>Element pE, char const *s) {
mpz_t mr;
mpz_init_set_str(mr, s, 10);
<%=name%>_fromMpz(pE, mr);
}
char *<%=name%>_element2str(P<%=name%>Element pE) {
mpz_t r;
if (!(pE->type & <%=name%>_LONG)) {
if (pE->shortVal>=0) {
char *r = new char[32];
sprintf(r, "%d", pE->shortVal);
return r;
} else {
mpz_init_set_si(r, pE->shortVal);
mpz_add(r, r, q);
}
} else {
<%=name%>_toNormal(pE);
mpz_init(r);
mpz_import(r, <%=name%>_N64, -1, 8, -1, 0, (const void *)pE->longVal);
}
char *res = mpz_get_str (0, 10, r);
mpz_clear(r);
return res;
}
void <%=name%>_idiv(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
// char *s1 = mpz_get_str (0, 10, ma);
// printf("s1 %s\n", s1);
<%=name%>_toMpz(mb, b);
// char *s2 = mpz_get_str (0, 10, mb);
// printf("s2 %s\n", s2);
mpz_fdiv_q(mr, ma, mb);
// char *sr = mpz_get_str (0, 10, mr);
// printf("r %s\n", sr);
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_mod(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
<%=name%>_toMpz(mb, b);
mpz_fdiv_r(mr, ma, mb);
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_shl(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
<%=name%>_toMpz(mb, b);
if (mpz_cmp_ui(mb, nBits) >= 0) {
mpz_set(mr, zero);
} else {
mpz_mul_2exp(mr, ma, mpz_get_ui(mb));
mpz_and(mr, mr, mask);
}
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_shr(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
<%=name%>_toMpz(mb, b);
if (mpz_cmp_ui(mb, nBits) >= 0) {
mpz_set(mr, zero);
} else {
mpz_tdiv_q_2exp(mr, ma, mpz_get_ui(mb));
mpz_and(mr, mr, mask);
}
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_pow(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
mpz_t ma;
mpz_t mb;
mpz_t mr;
mpz_init(ma);
mpz_init(mb);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
<%=name%>_toMpz(mb, b);
mpz_powm(mr, ma, mb, q);
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_inv(P<%=name%>Element r, P<%=name%>Element a) {
mpz_t ma;
mpz_t mr;
mpz_init(ma);
mpz_init(mr);
<%=name%>_toMpz(ma, a);
mpz_invert(mr, ma, q);
<%=name%>_fromMpz(r, mr);
}
void <%=name%>_div(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b) {
<%=name%>Element tmp;
<%=name%>_inv(&tmp, b);
<%=name%>_mul(r, a, &tmp);
}
void <%=name%>_fail() {
assert(false);
}

67
ports/c/buildasm/fr.h Normal file
View File

@@ -0,0 +1,67 @@
#ifndef __FR_H
#define __FR_H
#include <stdint.h>
#define Fr_N64 4
#define Fr_SHORT 0x00000000
#define Fr_LONG 0x80000000
#define Fr_LONGMONTGOMERY 0xC0000000
typedef struct __attribute__((__packed__)) {
int32_t shortVal;
uint32_t type;
uint64_t longVal[Fr_N64];
} FrElement;
typedef FrElement *PFrElement;
extern FrElement Fr_q;
extern "C" void Fr_copy(PFrElement r, PFrElement a);
extern "C" void Fr_copyn(PFrElement r, PFrElement a, int n);
extern "C" void Fr_add(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_sub(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_neg(PFrElement r, PFrElement a);
extern "C" void Fr_mul(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_square(PFrElement r, PFrElement a);
extern "C" void Fr_band(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_bor(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_bxor(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_bnot(PFrElement r, PFrElement a);
extern "C" void Fr_eq(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_neq(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_lt(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_gt(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_leq(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_geq(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_land(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_lor(PFrElement r, PFrElement a, PFrElement b);
extern "C" void Fr_lnot(PFrElement r, PFrElement a);
extern "C" void Fr_toNormal(PFrElement pE);
extern "C" void Fr_toLongNormal(PFrElement pE);
extern "C" void Fr_toMontgomery(PFrElement pE);
extern "C" int Fr_isTrue(PFrElement pE);
extern "C" int Fr_toInt(PFrElement pE);
extern "C" void Fr_fail();
extern FrElement Fr_q;
// Pending functions to convert
void Fr_str2element(PFrElement pE, char const*s);
char *Fr_element2str(PFrElement pE);
void Fr_idiv(PFrElement r, PFrElement a, PFrElement b);
void Fr_mod(PFrElement r, PFrElement a, PFrElement b);
void Fr_inv(PFrElement r, PFrElement a);
void Fr_div(PFrElement r, PFrElement a, PFrElement b);
void Fr_shl(PFrElement r, PFrElement a, PFrElement b);
void Fr_shr(PFrElement r, PFrElement a, PFrElement b);
void Fr_pow(PFrElement r, PFrElement a, PFrElement b);
void Fr_init();
#endif // __FR_H

67
ports/c/buildasm/fr.h.ejs Normal file
View File

@@ -0,0 +1,67 @@
#ifndef __<%=name.toUpperCase()%>_H
#define __<%=name.toUpperCase()%>_H
#include <stdint.h>
#define <%=name%>_N64 <%= n64 %>
#define <%=name%>_SHORT 0x00000000
#define <%=name%>_LONG 0x80000000
#define <%=name%>_LONGMONTGOMERY 0xC0000000
typedef struct __attribute__((__packed__)) {
int32_t shortVal;
uint32_t type;
uint64_t longVal[<%=name%>_N64];
} <%=name%>Element;
typedef <%=name%>Element *P<%=name%>Element;
extern <%=name%>Element <%=name%>_q;
extern "C" void <%=name%>_copy(P<%=name%>Element r, P<%=name%>Element a);
extern "C" void <%=name%>_copyn(P<%=name%>Element r, P<%=name%>Element a, int n);
extern "C" void <%=name%>_add(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_sub(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_neg(P<%=name%>Element r, P<%=name%>Element a);
extern "C" void <%=name%>_mul(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_square(P<%=name%>Element r, P<%=name%>Element a);
extern "C" void <%=name%>_band(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_bor(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_bxor(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_bnot(P<%=name%>Element r, P<%=name%>Element a);
extern "C" void <%=name%>_eq(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_neq(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_lt(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_gt(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_leq(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_geq(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_land(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_lor(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
extern "C" void <%=name%>_lnot(P<%=name%>Element r, P<%=name%>Element a);
extern "C" void <%=name%>_toNormal(P<%=name%>Element pE);
extern "C" void <%=name%>_toLongNormal(P<%=name%>Element pE);
extern "C" void <%=name%>_toMontgomery(P<%=name%>Element pE);
extern "C" int <%=name%>_isTrue(P<%=name%>Element pE);
extern "C" int <%=name%>_toInt(P<%=name%>Element pE);
extern "C" void <%=name%>_fail();
extern <%=name%>Element <%=name%>_q;
// Pending functions to convert
void <%=name%>_str2element(P<%=name%>Element pE, char const*s);
char *<%=name%>_element2str(P<%=name%>Element pE);
void <%=name%>_idiv(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_mod(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_inv(P<%=name%>Element r, P<%=name%>Element a);
void <%=name%>_div(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_shl(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_shr(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_pow(P<%=name%>Element r, P<%=name%>Element a, P<%=name%>Element b);
void <%=name%>_init();
#endif // __<%=name.toUpperCase()%>_H

BIN
ports/c/buildasm/fr.o Normal file

Binary file not shown.

View File

@@ -0,0 +1,97 @@
<% function isTrue(resReg, srcPtrReg) { %>
<% const longIsZero = global.tmpLabel() %>
<% const retOne = global.tmpLabel("retOne") %>
<% const retZero = global.tmpLabel("retZero") %>
<% const done = global.tmpLabel("done") %>
mov rax, [<%=srcPtrReg%>]
bt rax, 63
jc <%= longIsZero %>
test eax, eax
jz <%= retZero %>
jmp <%= retOne %>
<%= longIsZero %>:
<% for (let i=0; i<n64; i++) { %>
mov rax, [<%= srcPtrReg + " + " +(i*8+8) %>]
test rax, rax
jnz <%= retOne %>
<% } %>
<%= retZero %>:
mov qword <%=resReg%>, 0
jmp <%= done %>
<%= retOne %>:
mov qword <%=resReg%>, 1
<%= done %>:
<% } %>
<% function logicalOp(op) { %>
;;;;;;;;;;;;;;;;;;;;;;
; l<%= op %>
;;;;;;;;;;;;;;;;;;;;;;
; Logical <%= op %> between two elements
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result zero or one
; Modified Registers:
; rax, rcx, r8
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_l<%=op%>:
<%= isTrue("r8", "rsi") %>
<%= isTrue("rcx", "rdx") %>
<%=op%> rcx, r8
mov [rdi], rcx
ret
<% } %>
<% logicalOp("and"); %>
<% logicalOp("or"); %>
;;;;;;;;;;;;;;;;;;;;;;
; lnot
;;;;;;;;;;;;;;;;;;;;;;
; Do the logical not of an element
; Params:
; rsi <= Pointer to element to be tested
; rdi <= Pointer to result one if element1 is zero and zero otherwise
; Modified Registers:
; rax, rax, r8
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_lnot:
<%= isTrue("rcx", "rsi") %>
test rcx, rcx
jz lnot_retOne
lnot_retZero:
mov qword [rdi], 0
ret
lnot_retOne:
mov qword [rdi], 1
ret
;;;;;;;;;;;;;;;;;;;;;;
; isTrue
;;;;;;;;;;;;;;;;;;;;;;
; Convert a 64 bit integer to a long format field element
; Params:
; rsi <= Pointer to the element
; Returs:
; rax <= 1 if true 0 if false
;;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_isTrue:
<%= isTrue("rax", "rdi") %>
ret

64
ports/c/buildasm/main.c Normal file
View File

@@ -0,0 +1,64 @@
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include "fr.h"
int main() {
Fr_init();
/*
FrElement a = { 0, Fr_LONGMONTGOMERY, {1,1,1,1}};
FrElement b = { 0, Fr_LONGMONTGOMERY, {2,2,2,2}};
FrElement a={0x43e1f593f0000000ULL,0x2833e84879b97091ULL,0xb85045b68181585dULL,0x30644e72e131a029ULL};
FrElement b = {3,0,0,0};
FrElement c;
*/
// Fr_add(&(c[0]), a, a);
// Fr_add(&(c[0]), c, b);
/*
for (int i=0; i<1000000000; i++) {
Fr_mul(&c, &a, &b);
}
Fr_mul(&c,&a, &b);
*/
/*
FrElement a1[10];
FrElement a2[10];
for (int i=0; i<10; i++) {
a1[i].type = Fr_LONGMONTGOMERY;
a1[i].shortVal =0;
for (int j=0; j<Fr_N64; j++) {
a2[i].longVal[j] = i;
}
}
Fr_copyn(a2, a1, 10);
for (int i=0; i<10; i++) {
char *c1 = Fr_element2str(&a1[i]);
char *c2 = Fr_element2str(&a2[i]);
printf("%s\n%s\n\n", c1, c2);
free(c1);
free(c2);
}
*/
int tests[7] = { 0, 1, 2, -1, -2, 0x7FFFFFFF, (int)0x80000000};
for (int i=0; i<7;i++) {
FrElement a = { tests[i], Fr_SHORT, {0,0,0,0}};
Fr_toLongNormal(&a);
int b = Fr_toInt(&a);
int c = Fr_isTrue(&a);
printf("%d, %d, %d\n", tests[i], b, c);
}
FrElement err = { 0, Fr_LONGMONTGOMERY, {1,1,1,1}};
Fr_toInt(&err);
// printf("%llu, %llu, %llu, %llu\n", c.longVal[0], c.longVal[1], c.longVal[2], c.longVal[3]);
}

View File

@@ -0,0 +1,342 @@
<%
//////////////////////
// montgomeryTemplate
//////////////////////
// This function creates functions with the montgomery transformation
// applied
// the round hook allows to add diferent code in the iteration
//
// All the montgomery functions modifies:
// r8, r9, 10, r11, rax, rcx
//////////////////////
function montgomeryTemplate(fnName, round) {
let r0, r1, r2;
function setR(step) {
if ((step % 3) == 0) {
r0 = "r8";
r1 = "r9";
r2 = "r10";
} else if ((step % 3) == 1) {
r0 = "r9";
r1 = "r10";
r2 = "r8";
} else {
r0 = "r10";
r1 = "r8";
r2 = "r9";
}
}
const base = bigInt.one.shiftLeft(64);
const np64 = base.minus(q.modInv(base));
%>
<%=fnName%>:
sub rsp, <%= n64*8 %> ; Reserve space for ms
mov rcx, rdx ; rdx is needed for multiplications so keep it in cx
mov r11, 0x<%= np64.toString(16) %> ; np
xor r8,r8
xor r9,r9
xor r10,r10
<%
// Main loop
for (let i=0; i<n64*2; i++) {
setR(i);
round(i, r0, r1, r2);
%>
<%
for (let j=i-1; j>=0; j--) { // All ms
if (((i-j)<n64)&&(j<n64)) {
%>
mov rax, [rsp + <%= j*8 %>]
mul qword [q + <%= (i-j)*8 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
}
} // ms
%>
<%
if (i<n64) {
%>
mov rax, <%= r0 %>
mul r11
mov [rsp + <%= i*8 %>], rax
mul qword [q]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} else {
%>
mov [rdi + <%= (i-n64)*8 %> ], <%= r0 %>
xor <%= r0 %>,<%= r0 %>
<%
}
%>
<%
} // Main Loop
%>
test <%= r1 %>, <%= r1 %>
jnz <%=fnName%>_mulM_sq
; Compare with q
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
cmp rax, [q + <%= (n64-i-1)*8 %>]
jc <%=fnName%>_mulM_done ; q is bigget so done.
jnz <%=fnName%>_mulM_sq ; q is lower
<%
}
%>
; If equal substract q
<%=fnName%>_mulM_sq:
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [q + <%= i*8 %>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%= i*8 %>], rax
<%
}
%>
<%=fnName%>_mulM_done:
mov rdx, rcx ; recover rdx to its original place.
add rsp, <%= n64*8 %> ; recover rsp
ret
<%
} // Template
%>
;;;;;;;;;;;;;;;;;;;;;;
; rawMontgomeryMul
;;;;;;;;;;;;;;;;;;;;;;
; Multiply two elements in montgomery form
; Params:
; rsi <= Pointer to the long data of element 1
; rdx <= Pointer to the long data of element 2
; rdi <= Pointer to the long data of result
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%
montgomeryTemplate("rawMontgomeryMul", function(i, r0, r1, r2) {
// Same Digit
for (let o1=Math.max(0, i-n64+1); (o1<=i)&&(o1<n64); o1++) {
const o2= i-o1;
%>
mov rax, [rsi + <%= 8*o1 %>]
mul qword [rcx + <%= 8*o2 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} // Same digit
})
%>
;;;;;;;;;;;;;;;;;;;;;;
; rawMontgomerySquare
;;;;;;;;;;;;;;;;;;;;;;
; Square an element
; Params:
; rsi <= Pointer to the long data of element 1
; rdi <= Pointer to the long data of result
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%
montgomeryTemplate("rawMontgomerySquare", function(i, r0, r1, r2) {
// Same Digit
for (let o1=Math.max(0, i-n64+1); (o1<((i+1)>>1) )&&(o1<n64); o1++) {
const o2= i-o1;
%>
mov rax, [rsi + <%= 8*o1 %>]
mul qword [rsi + <%= 8*o2 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} // Same digit
%>
<% if (i%2 == 0) { %>
mov rax, [rsi + <%= 8*(i/2) %>]
mul rax
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<% } %>
<%
})
%>
;;;;;;;;;;;;;;;;;;;;;;
; rawMontgomeryMul1
;;;;;;;;;;;;;;;;;;;;;;
; Multiply two elements in montgomery form
; Params:
; rsi <= Pointer to the long data of element 1
; rdx <= second operand
; rdi <= Pointer to the long data of result
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%
montgomeryTemplate("rawMontgomeryMul1", function(i, r0, r1, r2) {
// Same Digit
if (i<n64) {
%>
mov rax, [rsi + <%= 8*i %>]
mul rcx
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} // Same digit
})
%>
;;;;;;;;;;;;;;;;;;;;;;
; rawFromMontgomery
;;;;;;;;;;;;;;;;;;;;;;
; Multiply two elements in montgomery form
; Params:
; rsi <= Pointer to the long data of element 1
; rdi <= Pointer to the long data of result
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%
montgomeryTemplate("rawFromMontgomery", function(i, r0, r1, r2) {
// Same Digit
if (i<n64) {
%>
add <%= r0 %>, [rdi + <%= 8*i %>]
adc <%= r1 %>, 0x0
adc <%= r2 %>, 0x0
<%
} // Same digit
})
%>
;;;;;;;;;;;;;;;;;;;;;;
; toMontgomery
;;;;;;;;;;;;;;;;;;;;;;
; Convert a number to Montgomery
; rdi <= Pointer element to convert
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;
<%=name%>_toMontgomery:
mov rax, [rdi]
bts rax, 62 ; check if montgomery
jc toMontgomery_doNothing
bts rax, 63
jc toMontgomeryLong
toMontgomeryShort:
mov [rdi], rax
add rdi, 8
push rsi
lea rsi, [R2]
movsx rdx, eax
cmp rdx, 0
js negMontgomeryShort
posMontgomeryShort:
call rawMontgomeryMul1
pop rsi
sub rdi, 8
ret
negMontgomeryShort:
neg rdx ; Do the multiplication positive and then negate the result.
call rawMontgomeryMul1
mov rsi, rdi
call rawNegL
pop rsi
sub rdi, 8
ret
toMontgomeryLong:
mov [rdi], rax
add rdi, 8
push rsi
mov rdx, rdi
lea rsi, [R2]
call rawMontgomeryMul
pop rsi
sub rdi, 8
toMontgomery_doNothing:
ret
;;;;;;;;;;;;;;;;;;;;;;
; toNormal
;;;;;;;;;;;;;;;;;;;;;;
; Convert a number from Montgomery
; rdi <= Pointer element to convert
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;
<%=name%>_toNormal:
mov rax, [rdi]
btc rax, 62 ; check if montgomery
jnc toNormal_doNothing
bt rax, 63 ; if short, it means it's converted
jnc toNormal_doNothing
toNormalLong:
mov [rdi], rax
add rdi, 8
call rawFromMontgomery
sub rdi, 8
toNormal_doNothing:
ret
;;;;;;;;;;;;;;;;;;;;;;
; toLongNormal
;;;;;;;;;;;;;;;;;;;;;;
; Convert a number to long normal
; rdi <= Pointer element to convert
; Modified registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;
<%=name%>_toLongNormal:
mov rax, [rdi]
bt rax, 62 ; check if montgomery
jc toLongNormal_fromMontgomery
bt rax, 63 ; check if long
jnc toLongNormal_fromShort
ret ; It is already long
toLongNormal_fromMontgomery:
add rdi, 8
call rawFromMontgomery
sub rdi, 8
ret
toLongNormal_fromShort:
mov r8, rsi ; save rsi
movsx rsi, eax
call rawCopyS2L
mov rsi, r8 ; recover rsi
ret

View File

@@ -0,0 +1,275 @@
<% function mulS1S2() { %>
xor rax, rax
mov eax, r8d
imul r9d
jo mul_manageOverflow ; rsi already is the 64bits result
mov [rdi], rax ; not necessary to adjust so just save and return
mul_manageOverflow: ; Do the operation in 64 bits
push rsi
movsx rax, r8d
movsx rcx, r9d
imul rcx
mov rsi, rax
call rawCopyS2L
pop rsi
<% } %>
<% function squareS1() { %>
xor rax, rax
mov eax, r8d
imul eax
jo square_manageOverflow ; rsi already is the 64bits result
mov [rdi], rax ; not necessary to adjust so just save and return
square_manageOverflow: ; Do the operation in 64 bits
push rsi
movsx rax, r8d
imul rax
mov rsi, rax
call rawCopyS2L
pop rsi
<% } %>
<% function mulL1S2(t) { %>
push rsi
add rsi, 8
movsx rdx, r9d
add rdi, 8
cmp rdx, 0
<% const rawPositiveLabel = global.tmpLabel() %>
jns <%= rawPositiveLabel %>
neg rdx
call rawMontgomeryMul1
mov rsi, rdi
call rawNegL
sub rdi, 8
pop rsi
<% const done = global.tmpLabel() %>
jmp <%= done %>
<%= rawPositiveLabel %>:
call rawMontgomeryMul1
sub rdi, 8
pop rsi
<%= done %>:
<% } %>
<% function mulS1L2() { %>
push rsi
lea rsi, [rdx + 8]
movsx rdx, r8d
add rdi, 8
cmp rdx, 0
<% const rawPositiveLabel = global.tmpLabel() %>
jns <%= rawPositiveLabel %>
neg rdx
call rawMontgomeryMul1
mov rsi, rdi
call rawNegL
sub rdi, 8
pop rsi
<% const done = global.tmpLabel() %>
jmp <%= done %>
<%= rawPositiveLabel %>:
call rawMontgomeryMul1
sub rdi, 8
pop rsi
<%= done %>:
<% } %>
<% function mulL1L2() { %>
add rdi, 8
add rsi, 8
add rdx, 8
call rawMontgomeryMul
sub rdi, 8
sub rsi, 8
<% } %>
<% function squareL1() { %>
add rdi, 8
add rsi, 8
call rawMontgomerySquare
sub rdi, 8
sub rsi, 8
<% } %>
<% function mulR3() { %>
push rsi
add rdi, 8
mov rsi, rdi
lea rdx, [R3]
call rawMontgomeryMul
sub rdi, 8
pop rsi
<% } %>
;;;;;;;;;;;;;;;;;;;;;;
; square
;;;;;;;;;;;;;;;;;;;;;;
; Squares a field element
; Params:
; rsi <= Pointer to element 1
; rdi <= Pointer to result
; [rdi] = [rsi] * [rsi]
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_square:
mov r8, [rsi]
bt r8, 63 ; Check if is short first operand
jc square_l1
square_s1: ; Both operands are short
<%= squareS1() %>
ret
square_l1:
bt r8, 62 ; check if montgomery first
jc square_l1m
square_l1n:
<%= global.setTypeDest("0xC0"); %>
<%= squareL1() %>
<%= mulR3() %>
ret
square_l1m:
<%= global.setTypeDest("0xC0"); %>
<%= squareL1() %>
ret
;;;;;;;;;;;;;;;;;;;;;;
; mul
;;;;;;;;;;;;;;;;;;;;;;
; Multiplies two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result
; [rdi] = [rsi] * [rdi]
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_mul:
mov r8, [rsi]
mov r9, [rdx]
bt r8, 63 ; Check if is short first operand
jc mul_l1
bt r9, 63 ; Check if is short second operand
jc mul_s1l2
mul_s1s2: ; Both operands are short
<%= mulS1S2() %>
ret
mul_l1:
bt r9, 63 ; Check if is short second operand
jc mul_l1l2
;;;;;;;;
mul_l1s2:
bt r8, 62 ; check if montgomery first
jc mul_l1ms2
mul_l1ns2:
bt r9, 62 ; check if montgomery first
jc mul_l1ns2m
mul_l1ns2n:
<%= global.setTypeDest("0xC0"); %>
<%= mulL1S2() %>
<%= mulR3() %>
ret
mul_l1ns2m:
<%= global.setTypeDest("0x80"); %>
<%= mulL1L2() %>
ret
mul_l1ms2:
bt r9, 62 ; check if montgomery second
jc mul_l1ms2m
mul_l1ms2n:
<%= global.setTypeDest("0x80"); %>
<%= mulL1S2() %>
ret
mul_l1ms2m:
<%= global.setTypeDest("0xC0"); %>
<%= mulL1L2() %>
ret
;;;;;;;;
mul_s1l2:
bt r8, 62 ; check if montgomery first
jc mul_s1ml2
mul_s1nl2:
bt r9, 62 ; check if montgomery first
jc mul_s1nl2m
mul_s1nl2n:
<%= global.setTypeDest("0xC0"); %>
<%= mulS1L2() %>
<%= mulR3() %>
ret
mul_s1nl2m:
<%= global.setTypeDest("0x80"); %>
<%= mulS1L2(); %>
ret
mul_s1ml2:
bt r9, 62 ; check if montgomery first
jc mul_s1ml2m
mul_s1ml2n:
<%= global.setTypeDest("0x80"); %>
<%= mulL1L2() %>
ret
mul_s1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= mulL1L2() %>
ret
;;;;
mul_l1l2:
bt r8, 62 ; check if montgomery first
jc mul_l1ml2
mul_l1nl2:
bt r9, 62 ; check if montgomery second
jc mul_l1nl2m
mul_l1nl2n:
<%= global.setTypeDest("0xC0"); %>
<%= mulL1L2() %>
<%= mulR3() %>
ret
mul_l1nl2m:
<%= global.setTypeDest("0x80"); %>
<%= mulL1L2() %>
ret
mul_l1ml2:
bt r9, 62 ; check if montgomery seconf
jc mul_l1ml2m
mul_l1ml2n:
<%= global.setTypeDest("0x80"); %>
<%= mulL1L2() %>
ret
mul_l1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= mulL1L2() %>
ret

View File

@@ -0,0 +1,78 @@
<% function negS() { %>
neg eax
jo neg_manageOverflow ; Check if overflow. (0x80000000 is the only case)
mov [rdi], rax ; not necessary to adjust so just save and return
ret
neg_manageOverflow: ; Do the operation in 64 bits
push rsi
movsx rsi, eax
neg rsi
call rawCopyS2L
pop rsi
ret
<% } %>
<% function negL() { %>
add rdi, 8
add rsi, 8
call rawNegL
sub rdi, 8
sub rsi, 8
ret
<% } %>
;;;;;;;;;;;;;;;;;;;;;;
; neg
;;;;;;;;;;;;;;;;;;;;;;
; Adds two elements of any kind
; Params:
; rsi <= Pointer to element to be negated
; rdi <= Pointer to result
; [rdi] = -[rsi]
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_neg:
mov rax, [rsi]
bt rax, 63 ; Check if is short first operand
jc neg_l
neg_s: ; Operand is short
<%= negS() %>
neg_l:
mov [rdi], rax ; Copy the type
<%= negL() %>
;;;;;;;;;;;;;;;;;;;;;;
; rawNeg
;;;;;;;;;;;;;;;;;;;;;;
; Negates a value
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Pointer to the long data of element 1
;
; [rdi] = - [rsi]
;;;;;;;;;;;;;;;;;;;;;;
rawNegL:
; Compare is zero
xor rax, rax
<% for (let i=0; i<n64; i++) { %>
cmp [rsi + <%=i*8%>], rax
jnz doNegate
<% } %>
; it's zero so just set to zero
<% for (let i=0; i<n64; i++) { %>
mov [rdi + <%=i*8%>], rax
<% } %>
ret
doNegate:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "sub" : "sbb" %> rax, [rsi + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
ret

View File

@@ -0,0 +1,33 @@
const tester = require("../c/buildasm/buildzqfieldtester2.js");
const bigInt = require("big-integer");
const __P__ = new bigInt("21888242871839275222246405745257275088548364400416034343698204186575808495617");
describe("basic cases", function () {
this.timeout(100000);
it("should do basic tests", async () => {
await tester(__P__, [
["add", 0, 0],
["add", 0, 1],
["add", 1, 0],
["add", 1, 1],
["add", 2, 1],
["add", 2, 10],
["add", -1, -1],
["add", -20, -10],
["add", "10604728079509999371218483608188593244163417117449316147628604036713980815027", "10604728079509999371218483608188593244163417117449316147628604036713980815027"],
["mul", 0, 0],
["mul", 0, 1],
["mul", 1, 0],
["mul", 1, 1],
["mul", 2, 1],
["mul", 2, 10],
["mul", -1, -1],
["mul", -20, -10],
["mul", "10604728079509999371218483608188593244163417117449316147628604036713980815027", "10604728079509999371218483608188593244163417117449316147628604036713980815027"],
]);
});
});

View File

@@ -0,0 +1,209 @@
const bigInt=require("big-integer");
class ZqBuilder {
constructor(q, name) {
this.q=bigInt(q);
this.h = [];
this.c = [];
this.name = name;
}
build() {
this._buildHeaders();
this._buildAdd();
this._buildMul();
this.c.push(""); this.h.push("");
return [this.h.join("\n"), this.c.join("\n")];
}
_buildHeaders() {
this.n64 = Math.floor((this.q.bitLength() - 1) / 64)+1;
this.h.push("typedef unsigned long long u64;");
this.h.push(`typedef u64 ${this.name}Element[${this.n64}];`);
this.h.push(`typedef u64 *P${this.name}Element;`);
this.h.push(`extern ${this.name}Element ${this.name}_q;`);
this.h.push(`#define ${this.name}_N64 ${this.n64}`);
this.c.push(`#include "${this.name.toLowerCase()}.h"`);
this._defineConstant(`${this.name}_q`, this.q);
this.c.push(""); this.h.push("");
}
_defineConstant(n, v) {
let S = `${this.name}Element ${n}={`;
const mask = bigInt("FFFFFFFFFFFFFFFF", 16);
for (let i=0; i<this.n64; i++) {
if (i>0) S = S+",";
let shex = v.shiftRight(i*64).and(mask).toString(16);
while (shex <16) shex = "0" + shex;
S = S + "0x" + shex + "ULL";
}
S += "};";
this.c.push(S);
}
_buildAdd() {
this.h.push(`void ${this.name}_add(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b);`);
this.c.push(`void ${this.name}_add(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b) {`);
this.c.push(" __asm__ __volatile__ (");
for (let i=0; i<this.n64; i++) {
this.c.push(` "movq ${i*8}(%2), %%rax;"`);
this.c.push(` "${i==0 ? "addq" : "adcq"} ${i*8}(%1), %%rax;"`);
this.c.push(` "movq %%rax, ${i*8}(%0);"`);
}
this.c.push(" \"jc SQ;\"");
for (let i=0; i<this.n64; i++) {
if (i>0) {
this.c.push(` "movq ${(this.n64 - i-1)*8}(%0), %%rax;"`);
}
this.c.push(` "cmp ${(this.n64 - i-1)*8}(%3), %%rax;"`);
this.c.push(" \"jg SQ;\"");
this.c.push(" \"jl DONE;\"");
}
this.c.push(" \"SQ:\"");
for (let i=0; i<this.n64; i++) {
this.c.push(` "movq ${i*8}(%3), %%rax;"`);
this.c.push(` "${i==0 ? "subq" : "sbbq"} %%rax, ${i*8}(%0);"`);
}
this.c.push(" \"DONE:\"");
this.c.push(` :: "r" (r), "r" (a), "r" (b), "r" (${this.name}_q) : "%rax", "memory");`);
this.c.push("}\n");
}
_buildMul() {
let r0, r1, r2;
function setR(step) {
if ((step % 3) == 0) {
r0 = "%%r8";
r1 = "%%r9";
r2 = "%%r10";
} else if ((step % 3) == 1) {
r0 = "%%r9";
r1 = "%%r10";
r2 = "%%r8";
} else {
r0 = "%%r10";
r1 = "%%r8";
r2 = "%%r9";
}
}
const base = bigInt.one.shiftLeft(64);
const np64 = base.minus(this.q.modInv(base));
this.h.push(`void ${this.name}_mul(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b);`);
this.c.push(`void ${this.name}_mul(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b) {`);
this.c.push(" __asm__ __volatile__ (");
this.c.push(` "subq $${this.n64*8}, %%rsp;"`);
this.c.push(` "movq $0x${np64.toString(16)}, %%r11;"`);
this.c.push(" \"movq $0x0, %%r8;\"");
this.c.push(" \"movq $0x0, %%r9;\"");
this.c.push(" \"movq $0x0, %%r10;\"");
for (let i=0; i<this.n64*2; i++) {
setR(i);
for (let o1=Math.max(0, i-this.n64+1); (o1<=i)&&(o1<this.n64); o1++) {
const o2= i-o1;
this.c.push(` "movq ${o1*8}(%1), %%rax;"`);
this.c.push(` "mulq ${o2*8}(%2);"`);
this.c.push(` "addq %%rax, ${r0};"`);
this.c.push(` "adcq %%rdx, ${r1};"`);
this.c.push(` "adcq $0x0, ${r2};"`);
}
for (let j=i-1; j>=0; j--) {
if (((i-j)<this.n64)&&(j<this.n64)) {
this.c.push(` "movq ${j*8}(%%rsp), %%rax;"`);
this.c.push(` "mulq ${(i-j)*8}(%3);"`);
this.c.push(` "addq %%rax, ${r0};"`);
this.c.push(` "adcq %%rdx, ${r1};"`);
this.c.push(` "adcq $0x0, ${r2};"`);
}
}
if (i<this.n64) {
this.c.push(` "movq ${r0}, %%rax;"`);
this.c.push(" \"mulq %%r11;\"");
this.c.push(` "movq %%rax, ${i*8}(%%rsp);"`);
this.c.push(" \"mulq (%3);\"");
this.c.push(` "addq %%rax, ${r0};"`);
this.c.push(` "adcq %%rdx, ${r1};"`);
this.c.push(` "adcq $0x0, ${r2};"`);
} else {
this.c.push(` "movq ${r0}, ${(i-this.n64)*8}(%0);"`);
this.c.push(` "movq $0, ${r0};"`);
}
}
this.c.push(` "cmp $0, ${r1};"`);
this.c.push(" \"jne SQ2;\"");
for (let i=0; i<this.n64; i++) {
this.c.push(` "movq ${(this.n64 - i-1)*8}(%0), %%rax;"`);
this.c.push(` "cmp ${(this.n64 - i-1)*8}(%3), %%rax;"`);
this.c.push(" \"jg SQ2;\"");
this.c.push(" \"jl DONE2;\"");
}
this.c.push(" \"SQ2:\"");
for (let i=0; i<this.n64; i++) {
this.c.push(` "movq ${i*8}(%3), %%rax;"`);
this.c.push(` "${i==0 ? "subq" : "sbbq"} %%rax, ${i*8}(%0);"`);
}
this.c.push(" \"DONE2:\"");
this.c.push(` "addq $${this.n64*8}, %%rsp;"`);
this.c.push(` :: "r" (r), "r" (a), "r" (b), "r" (${this.name}_q) : "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "memory");`);
this.c.push("}\n");
}
_buildIDiv() {
this.h.push(`void ${this.name}_idiv(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b);`);
this.c.push(`void ${this.name}_idiv(P${this.name}Element r, P${this.name}Element a, P${this.name}Element b) {`);
this.c.push(" __asm__ __volatile__ (");
this.c.push(" \"pxor %%xmm0, %%xmm0;\""); // Comparison Register
if (this.n64 == 1) {
this.c.push(` "mov %%rax, $${this.n64 - 8};"`);
} else {
this.c.push(` "mov %%rax, $${this.n64 -16};"`);
}
this.c.push(` :: "r" (r), "r" (a), "r" (b), "r" (${this.name}_q) : "%rax", "%rdx", "%r8", "%r9", "%r10", "%r11", "memory");`);
this.c.push("}\n");
}
}
var runningAsScript = !module.parent;
if (runningAsScript) {
const fs = require("fs");
var argv = require("yargs")
.usage("Usage: $0 -q [primeNum] -n [name] -oc [out .c file] -oh [out .h file]")
.demandOption(["q","n"])
.alias("q", "prime")
.alias("n", "name")
.argv;
const q = bigInt(argv.q);
const cFileName = (argv.oc) ? argv.oc : argv.name.toLowerCase() + ".c";
const hFileName = (argv.oh) ? argv.oh : argv.name.toLowerCase() + ".h";
const builder = new ZqBuilder(q, argv.name);
const res = builder.build();
fs.writeFileSync(hFileName, res[0], "utf8");
fs.writeFileSync(cFileName, res[1], "utf8");
} else {
module.exports = function(q, name) {
const builder = new ZqBuilder(q, name);
return builder.build();
};
}

View File

@@ -0,0 +1,68 @@
const chai = require("chai");
const assert = chai.assert;
const fs = require("fs");
var tmp = require("tmp-promise");
const path = require("path");
const util = require("util");
const exec = util.promisify(require("child_process").exec);
const bigInt = require("big-integer");
const BuildZqField = require("./buildzqfield");
const ZqField = require("fflib").ZqField;
module.exports = testField;
function toMontgomeryStr(a, prime) {
const n64 = Math.floor((prime.bitLength() - 1) / 64)+1;
return a.shiftLeft(n64*64).mod(prime).toString(10);
}
function fromMontgomeryStr(a, prime) {
const n64 = Math.floor((prime.bitLength() - 1) / 64)+1;
const R = bigInt.one.shiftLeft(n64*64).mod(prime);
const RI = R.modInv(prime);
return bigInt(a).times(RI).mod(prime);
}
async function testField(prime, test) {
tmp.setGracefulCleanup();
const F = new ZqField(prime);
const dir = await tmp.dir({prefix: "circom_", unsafeCleanup: true });
const [hSource, cSource] = BuildZqField(prime, "Fr");
await fs.promises.writeFile(path.join(dir.path, "fr.h"), hSource, "utf8");
await fs.promises.writeFile(path.join(dir.path, "fr.c"), cSource, "utf8");
await exec("g++" +
` ${path.join(__dirname, "tester.c")}` +
` ${path.join(dir.path, "fr.c")}` +
` -o ${path.join(dir.path, "tester")}` +
" -lgmp"
);
for (let i=0; i<test.length; i++) {
let a = bigInt(test[i][1]).mod(prime);
if (a.isNegative()) a = prime.add(a);
let b = bigInt(test[i][2]).mod(prime);
if (b.isNegative()) b = prime.add(b);
const ec = F[test[i][0]](a,b);
// console.log(toMontgomeryStr(a, prime));
// console.log(toMontgomeryStr(b, prime));
const res = await exec(`${path.join(dir.path, "tester")}` +
` ${test[i][0]}` +
` ${toMontgomeryStr(a, prime)}` +
` ${toMontgomeryStr(b, prime)}`
);
// console.log(res.stdout);
const c=fromMontgomeryStr(res.stdout, prime);
assert.equal(ec.toString(), c.toString());
}
}

View File

@@ -0,0 +1,302 @@
global <%=name%>_add
global <%=name%>_mul
global <%=name%>_q
DEFAULT REL
section .text
;;;;;;;;;;;;;;;;;;;;;;
; add
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_add:
; Add component by component with carry
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> rax, [rdx + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
jc add_sq ; if overflow, substract q
; Compare with q
<% for (let i=0; i<n64; i++) { %>
<% if (i>0) { %>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
<% } %>
cmp rax, [q + <%= (n64-i-1)*8 %>]
jg add_sq
jl add_done
<% } %>
; If equal substract q
add_sq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%=i*8%>], rax
mov [rdx + <%=i*8%>], rax
<% } %>
add_done:
ret
;;;;;;;;;;;;;;;;;;;;;;
; mul Montgomery
;;;;;;;;;;;;;;;;;;;;;;
mulM:
<%
let r0, r1, r2;
function setR(step) {
if ((step % 3) == 0) {
r0 = "r8";
r1 = "r9";
r2 = "r10";
} else if ((step % 3) == 1) {
r0 = "r9";
r1 = "r10";
r2 = "r8";
} else {
r0 = "r10";
r1 = "r8";
r2 = "r9";
}
}
const base = bigInt.one.shiftLeft(64);
const np64 = base.minus(q.modInv(base));
%>
sub rsp, <%= n64*8 %> ; Reserve space for ms
mov rcx, rdx ; rdx is needed for multiplications so keep it in cx
mov r11, 0x<%= np64.toString(16) %> ; np
xor r8,r8
xor r9,r9
xor r10,r10
<%
// Main loop
for (let i=0; i<n64*2; i++) {
setR(i);
%>
<%
// Same Digit
for (let o1=Math.max(0, i-n64+1); (o1<=i)&&(o1<n64); o1++) {
const o2= i-o1;
%>
mov rax, [rsi + <%= 8*o1 %>]
mul qword [rcx + <%= 8*o2 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} // Same digit
%>
<%
for (let j=i-1; j>=0; j--) { // All ms
if (((i-j)<n64)&&(j<n64)) {
%>
mov rax, [rsp + <%= j*8 %>]
mul qword [q + <%= (i-j)*8 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
}
} // ms
%>
<%
if (i<n64) {
%>
mov rax, <%= r0 %>
mul r11
mov [rsp + <%= i*8 %>], rax
mul qword [q]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} else {
%>
mov [rdi + <%= (i-n64)*8 %> ], <%= r0 %>
xor <%= r0 %>,<%= r0 %>
<%
}
%>
<%
} // Main Loop
%>
cmp <%= r1 %>, 0x0
jne mulM_sq
; Compare with q
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
cmp rax, [q + <%= (n64-i-1)*8 %>]
jg mulM_sq
jl mulM_done
<%
}
%>
; If equal substract q
mulM_sq:
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [q + <%= i*8 %>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%= i*8 %>], rax
mov [rdx + <%= i*8 %>], rax
<%
}
%>
mulM_done:
add rsp, <%= n64*8 %> ; recover rsp
ret
;;;;;;;;;;;;;;;;;;;;;;
; mul MontgomeryShort
;;;;;;;;;;;;;;;;;;;;;;
mulSM:
;;;;;;;;;;;;;;;;;;;;;;
; mul
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_mul:
mov rax, [rsi]
bt rax, 63
jc l1
mov rcx, [rdx]
bt rcx, 63
jc s1l2
s1s2: ; short first and second
mul ecx
jc rs2l ; If if doesn't feed in 32 bits convert the result to long
; The shorts multiplication is done. copy the val to destination and return
mov [rdi], rax
ret
rs2l: ; The result in the multiplication doen't feed
; we have the result in edx:eax we need to convert it to long
shl rdx, 32
mov edx, eax ; pack edx:eax to rdx
xor rax, rax ; Set the format to long
bts rax, 63
mov [rdi], rax ; move the first digit
cmp rdx, 0 ; check if redx is negative.
jl rs2ln
; edx is positive.
mov [rdi + 8], rdx ; Set the firs digit
xor rax, rax ; Set the remaining digits to 0
<% for (let i=1; i<n64; i++) { %>
mov [rdi + <%= (i+1)*8 %>], rax
<% } %>
ret
; edx is negative.
rs2ln:
add rdx, [q] ; Set the firs digit
mov [rdi + 8], rdx ;
mov rdx, -1 ; all ones
<% for (let i=1; i<n64; i++) { %>
mov rax, rdx ; Add to q
adc rax, [q + <%= i*8 %> ]
mov [rdi + <%= (i+1)*8 %>], rax
<% } %>
ret
l1:
mov rcx, [rdx]
bt rcx, 63
jc ll
l1s2:
xor rdx, rdx
mov edx, ecx
bt rax, 62
jc lsM
jmp lsN
s1l2:
mov rsi, rdx
xor rdx, rdx
mov edx, eax
bt rcx, 62
jc lsM
jmp lsN
lsN:
mov byte [rdi + 3], 0xC0 ; set the result to montgomery
add rsi, 8
add rdi, 8
call mulSM
mov rdx, R3
call mulM
ret
lsM:
mov byte [rdi + 3], 0x80 ; set the result to long normal
add rsi, 8
add rdi, 8
call mulSM
ret
ll:
bt rax, 62
jc lml
bt rcx, 62
jc lnlm
lnln:
mov byte [rdi + 3], 0xC0 ; set the result to long montgomery
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
mov rdi, rsi
mov rdx, R3
call mulM
ret
lml:
bt rcx, 62
jc lmlm
lnlm:
mov byte [rdi + 3], 0x80 ; set the result to long normal
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
ret
lmlm:
mov byte [rdi + 3], 0xC0 ; set the result to long montgomery
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
ret
section .data
<%=name%>_q:
dd 0
dd 0x80000000
q dq <%= constantElement(q) %>
R3 dq <%= constantElement(bigInt.one.shiftLeft(n64*64*3).mod(q)) %>

View File

@@ -0,0 +1,251 @@
;;;;;;;;;;;;;;;;;;;;;;
; mul Montgomery
;;;;;;;;;;;;;;;;;;;;;;
mulM:
<%
let r0, r1, r2;
function setR(step) {
if ((step % 3) == 0) {
r0 = "r8";
r1 = "r9";
r2 = "r10";
} else if ((step % 3) == 1) {
r0 = "r9";
r1 = "r10";
r2 = "r8";
} else {
r0 = "r10";
r1 = "r8";
r2 = "r9";
}
}
const base = bigInt.one.shiftLeft(64);
const np64 = base.minus(q.modInv(base));
%>
sub rsp, <%= n64*8 %> ; Reserve space for ms
mov rcx, rdx ; rdx is needed for multiplications so keep it in cx
mov r11, 0x<%= np64.toString(16) %> ; np
xor r8,r8
xor r9,r9
xor r10,r10
<%
// Main loop
for (let i=0; i<n64*2; i++) {
setR(i);
%>
<%
// Same Digit
for (let o1=Math.max(0, i-n64+1); (o1<=i)&&(o1<n64); o1++) {
const o2= i-o1;
%>
mov rax, [rsi + <%= 8*o1 %>]
mul qword [rcx + <%= 8*o2 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} // Same digit
%>
<%
for (let j=i-1; j>=0; j--) { // All ms
if (((i-j)<n64)&&(j<n64)) {
%>
mov rax, [rsp + <%= j*8 %>]
mul qword [q + <%= (i-j)*8 %>]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
}
} // ms
%>
<%
if (i<n64) {
%>
mov rax, <%= r0 %>
mul r11
mov [rsp + <%= i*8 %>], rax
mul qword [q]
add <%= r0 %>, rax
adc <%= r1 %>, rdx
adc <%= r2 %>, 0x0
<%
} else {
%>
mov [rdi + <%= (i-n64)*8 %> ], <%= r0 %>
xor <%= r0 %>,<%= r0 %>
<%
}
%>
<%
} // Main Loop
%>
cmp <%= r1 %>, 0x0
jne mulM_sq
; Compare with q
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [rdi + <%= (n64-i-1)*8 %>]
cmp rax, [q + <%= (n64-i-1)*8 %>]
jg mulM_sq
jl mulM_done
<%
}
%>
; If equal substract q
mulM_sq:
<%
for (let i=0; i<n64; i++) {
%>
mov rax, [q + <%= i*8 %>]
<%= i==0 ? "sub" : "sbb" %> [rdi + <%= i*8 %>], rax
<%
}
%>
mulM_done:
add rsp, <%= n64*8 %> ; recover rsp
ret
;;;;;;;;;;;;;;;;;;;;;;
; mul MontgomeryShort
;;;;;;;;;;;;;;;;;;;;;;
mulSM:
;;;;;;;;;;;;;;;;;;;;;;
; mul
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_mul:
mov rax, [rsi]
bt rax, 63
jc l1
mov rcx, [rdx]
bt rcx, 63
jc s1l2
s1s2: ; short first and second
mul ecx
jc rs2l ; If if doesn't feed in 32 bits convert the result to long
; The shorts multiplication is done. copy the val to destination and return
mov [rdi], rax
ret
rs2l: ; The result in the multiplication doen't feed
; we have the result in edx:eax we need to convert it to long
shl rdx, 32
mov edx, eax ; pack edx:eax to rdx
xor rax, rax ; Set the format to long
bts rax, 63
mov [rdi], rax ; move the first digit
cmp rdx, 0 ; check if redx is negative.
jl rs2ln
; edx is positive.
mov [rdi + 8], rdx ; Set the firs digit
xor rax, rax ; Set the remaining digits to 0
<% for (let i=1; i<n64; i++) { %>
mov [rdi + <%= (i+1)*8 %>], rax
<% } %>
ret
; edx is negative.
rs2ln:
add rdx, [q] ; Set the firs digit
mov [rdi + 8], rdx ;
mov rdx, -1 ; all ones
<% for (let i=1; i<n64; i++) { %>
mov rax, rdx ; Add to q
adc rax, [q + <%= i*8 %> ]
mov [rdi + <%= (i+1)*8 %>], rax
<% } %>
ret
l1:
mov rcx, [rdx]
bt rcx, 63
jc ll
l1s2:
xor rdx, rdx
mov edx, ecx
bt rax, 62
jc lsM
jmp lsN
s1l2:
mov rsi, rdx
xor rdx, rdx
mov edx, eax
bt rcx, 62
jc lsM
jmp lsN
lsN:
mov byte [rdi + 7], 0xC0 ; set the result to montgomery
add rsi, 8
add rdi, 8
call mulSM
mov rsi, rdi
lea rdx, [R3]
call mulM
ret
lsM:
mov byte [rdi + 7], 0x80 ; set the result to long normal
add rsi, 8
add rdi, 8
call mulSM
ret
ll:
bt rax, 62
jc lml
bt rcx, 62
jc lnlm
lnln:
mov byte [rdi + 7], 0xC0 ; set the result to long montgomery
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
mov rsi, rdi
lea rdx, [R3]
call mulM
ret
lml:
bt rcx, 62
jc lmlm
lnlm:
mov byte [rdi + 7], 0x80 ; set the result to long normal
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
ret
lmlm:
mov byte [rdi + 7], 0xC0 ; set the result to long montgomery
add rsi, 8
add rdi, 8
add rdx, 8
call mulM
ret

View File

@@ -0,0 +1,317 @@
<% function subS1S2() { %>
xor rdx, rdx
mov edx, eax
sub edx, ecx
jo sub_manageOverflow ; rsi already is the 64bits result
mov [rdi], rdx ; not necessary to adjust so just save and return
ret
sub_manageOverflow: ; Do the operation in 64 bits
push rsi
movsx rsi, eax
movsx rdx, ecx
sub rsi, rdx
call rawCopyS2L
pop rsi
ret
<% } %>
<% function subL1S2(t) { %>
add rsi, 8
movsx rdx, ecx
add rdi, 8
cmp rdx, 0
<% const rawSubLabel = global.tmpLabel() %>
jns <%= rawSubLabel %>
neg rdx
call rawAddLS
sub rdi, 8
sub rsi, 8
ret
<%= rawSubLabel %>:
call rawSubLS
sub rdi, 8
sub rsi, 8
ret
<% } %>
<% function subS1L2(t) { %>
cmp eax, 0
<% const s1NegLabel = global.tmpLabel() %>
js <%= s1NegLabel %>
; First Operand is positive
push rsi
add rdi, 8
movsx rsi, eax
add rdx, 8
call rawSubSL
sub rdi, 8
pop rsi
ret
<%= s1NegLabel %>: ; First operand is negative
push rsi
lea rsi, [rdx + 8]
movsx rdx, eax
add rdi, 8
neg rdx
call rawNegLS
sub rdi, 8
pop rsi
ret
<% } %>
<% function subL1L2(t) { %>
add rdi, 8
add rsi, 8
add rdx, 8
call rawSubLL
sub rdi, 8
sub rsi, 8
ret
<% } %>
;;;;;;;;;;;;;;;;;;;;;;
; sub
;;;;;;;;;;;;;;;;;;;;;;
; Substracts two elements of any kind
; Params:
; rsi <= Pointer to element 1
; rdx <= Pointer to element 2
; rdi <= Pointer to result
; Modified Registers:
; r8, r9, 10, r11, rax, rcx
;;;;;;;;;;;;;;;;;;;;;;
<%=name%>_sub:
mov rax, [rsi]
mov rcx, [rdx]
bt rax, 63 ; Check if is long first operand
jc sub_l1
bt rcx, 63 ; Check if is long second operand
jc sub_s1l2
sub_s1s2: ; Both operands are short
<%= subS1S2() %>
sub_l1:
bt rcx, 63 ; Check if is short second operand
jc sub_l1l2
;;;;;;;;
sub_l1s2:
bt rax, 62 ; check if montgomery first
jc sub_l1ms2
sub_l1ns2:
<%= global.setTypeDest("0x80"); %>
<%= subL1S2(); %>
sub_l1ms2:
bt rcx, 62 ; check if montgomery second
jc sub_l1ms2m
sub_l1ms2n:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_b() %>
<%= subL1L2() %>
sub_l1ms2m:
<%= global.setTypeDest("0xC0"); %>
<%= subL1L2() %>
;;;;;;;;
sub_s1l2:
bt rcx, 62 ; check if montgomery first
jc sub_s1l2m
sub_s1l2n:
<%= global.setTypeDest("0x80"); %>
<%= subS1L2(); %>
sub_s1l2m:
bt rax, 62 ; check if montgomery second
jc sub_s1ml2m
sub_s1nl2m:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_a() %>
<%= subL1L2() %>
sub_s1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= subL1L2() %>
;;;;
sub_l1l2:
bt rax, 62 ; check if montgomery first
jc sub_l1ml2
sub_l1nl2:
bt rcx, 62 ; check if montgomery second
jc sub_l1nl2m
sub_l1nl2n:
<%= global.setTypeDest("0x80"); %>
<%= subL1L2() %>
sub_l1nl2m:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_a(); %>
<%= subL1L2() %>
sub_l1ml2:
bt rcx, 62 ; check if montgomery seconf
jc sub_l1ml2m
sub_l1ml2n:
<%= global.setTypeDest("0xC0"); %>
<%= global.toMont_b(); %>
<%= subL1L2() %>
sub_l1ml2m:
<%= global.setTypeDest("0xC0"); %>
<%= subL1L2() %>
;;;;;;;;;;;;;;;;;;;;;;
; rawSubLS
;;;;;;;;;;;;;;;;;;;;;;
; Substracts a short element from the long element
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Pointer to the long data of element 1 where will be substracted
; rdx <= Value to be substracted
; [rdi] = [rsi] - rdx
; Modified Registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;
rawSubLS:
; Substract first digit
mov rax, [rsi]
sub rax, rdx
mov [rdi] ,rax
mov rdx, 0
<% for (let i=1; i<n64; i++) { %>
mov rax, [rsi + <%=i*8%>]
sbb rax, rdx
mov [rdi + <%=i*8%>], rax
<% } %>
jnc rawSubLS_done ; if overflow, add q
; Add q
rawSubLS_aq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> [rdi + <%=i*8%>], rax
<% } %>
rawSubLS_done:
ret
;;;;;;;;;;;;;;;;;;;;;;
; rawSubSL
;;;;;;;;;;;;;;;;;;;;;;
; Substracts a long element from a short element
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Value from where will bo substracted
; rdx <= Pointer to long of the value to be substracted
;
; [rdi] = rsi - [rdx]
; Modified Registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;
rawSubSL:
; Substract first digit
sub rsi, [rdx]
mov [rdi] ,rsi
<% for (let i=1; i<n64; i++) { %>
mov rax, 0
sbb rax, [rdx + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
jnc rawSubSL_done ; if overflow, add q
; Add q
rawSubSL_aq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> [rdi + <%=i*8%>], rax
<% } %>
rawSubSL_done:
ret
;;;;;;;;;;;;;;;;;;;;;;
; rawSubLL
;;;;;;;;;;;;;;;;;;;;;;
; Substracts a long element from a short element
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Pointer to long from where substracted
; rdx <= Pointer to long of the value to be substracted
;
; [rdi] = [rsi] - [rdx]
; Modified Registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;
rawSubLL:
; Substract first digit
<% for (let i=0; i<n64; i++) { %>
mov rax, [rsi + <%=i*8%>]
<%= i==0 ? "sub" : "sbb" %> rax, [rdx + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
jnc rawSubLL_done ; if overflow, add q
; Add q
rawSubLL_aq:
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> [rdi + <%=i*8%>], rax
<% } %>
rawSubLL_done:
ret
;;;;;;;;;;;;;;;;;;;;;;
; rawNegLS
;;;;;;;;;;;;;;;;;;;;;;
; Substracts a long element and a short element form 0
; Params:
; rdi <= Pointer to the long data of result
; rsi <= Pointer to long from where substracted
; rdx <= short value to be substracted too
;
; [rdi] = -[rsi] - rdx
; Modified Registers:
; rax
;;;;;;;;;;;;;;;;;;;;;;
rawNegLS:
mov rax, [q]
sub rax, rdx
mov [rdi], rax
<% for (let i=1; i<n64; i++) { %>
mov rax, [q + <%=i*8%> ]
sbb rax, 0
mov [rdi + <%=i*8%>], rax
<% } %>
setc dl
<% for (let i=0; i<n64; i++) { %>
mov rax, [rdi + <%=i*8%> ]
<%= i==0 ? "sub" : "sbb" %> rax, [rsi + <%=i*8%>]
mov [rdi + <%=i*8%>], rax
<% } %>
setc dh
or dl, dh
jz rawNegSL_done
; it is a negative value, so add q
<% for (let i=0; i<n64; i++) { %>
mov rax, [q + <%=i*8%>]
<%= i==0 ? "add" : "adc" %> [rdi + <%=i*8%>], rax
<% } %>
rawNegSL_done:
ret

BIN
ports/c/buildasm/tester Executable file

Binary file not shown.

218
ports/c/buildasm/tester.cpp Normal file
View File

@@ -0,0 +1,218 @@
#include <string>
#include <iostream>
#include <regex>
#include <string>
#include <iostream>
#include <stdexcept>
#include <sstream>
#include <stdio.h> /* printf, NULL */
#include <stdlib.h>
#include <cassert>
#include "fr.h"
typedef void (*Func1)(PFrElement, PFrElement);
typedef void (*Func2)(PFrElement, PFrElement, PFrElement);
typedef void *FuncAny;
typedef struct {
FuncAny fn;
int nOps;
} FunctionSpec;
std::map<std::string, FunctionSpec> functions;
std::vector<FrElement> stack;
void addFunction(std::string name, FuncAny f, int nOps) {
FunctionSpec fs;
fs.fn = f;
fs.nOps = nOps;
functions[name] = fs;
}
void fillMap() {
addFunction("add", (FuncAny)Fr_add, 2);
addFunction("sub", (FuncAny)Fr_sub, 2);
addFunction("neg", (FuncAny)Fr_neg, 1);
addFunction("mul", (FuncAny)Fr_mul, 2);
addFunction("square", (FuncAny)Fr_square, 1);
addFunction("idiv", (FuncAny)Fr_idiv, 2);
addFunction("inv", (FuncAny)Fr_inv, 1);
addFunction("div", (FuncAny)Fr_div, 2);
addFunction("band", (FuncAny)Fr_band, 2);
addFunction("bor", (FuncAny)Fr_bor, 2);
addFunction("bxor", (FuncAny)Fr_bxor, 2);
addFunction("bnot", (FuncAny)Fr_bnot, 1);
addFunction("eq", (FuncAny)Fr_eq, 2);
addFunction("neq", (FuncAny)Fr_neq, 2);
addFunction("lt", (FuncAny)Fr_lt, 2);
addFunction("gt", (FuncAny)Fr_gt, 2);
addFunction("leq", (FuncAny)Fr_leq, 2);
addFunction("geq", (FuncAny)Fr_geq, 2);
addFunction("land", (FuncAny)Fr_land, 2);
addFunction("lor", (FuncAny)Fr_lor, 2);
addFunction("lnot", (FuncAny)Fr_lnot, 1);
}
u_int64_t readInt(std::string &s) {
if (s.rfind("0x", 0) == 0) {
return std::stoull(s.substr(2), 0, 16);
} else {
return std::stoull(s, 0, 10);
}
}
void pushNumber(std::vector<std::string> &v) {
u_int64_t a;
if ((v.size()<1) || (v.size() > (Fr_N64+1))) {
printf("Invalid Size: %d - %d \n", v.size(), Fr_N64);
throw std::runtime_error("Invalid number of parameters for number");
}
FrElement e;
a = readInt(v[0]);
*(u_int64_t *)(&e) = a;
for (int i=0; i<Fr_N64; i++) {
if (i+1 < v.size()) {
a = readInt(v[i+1]);
} else {
a = 0;
}
e.longVal[i] = a;
}
stack.push_back(e);
}
void callFunction(FunctionSpec fs) {
if (stack.size() < fs.nOps) {
throw new std::runtime_error("Not enough elements in stack");
}
if (fs.nOps == 1) {
FrElement a = stack.back();
stack.pop_back();
FrElement c;
(*(Func1)fs.fn)(&c, &a);
stack.push_back(c);
} else if (fs.nOps == 2) {
FrElement b = stack.back();
stack.pop_back();
FrElement a = stack.back();
stack.pop_back();
FrElement c;
(*(Func2)fs.fn)(&c, &a, &b);
stack.push_back(c);
} else {
assert(false);
}
}
void processLine(std::string &line) {
std::regex re("(\\s*[,;]\\s*)|\\s+"); // whitespace
std::sregex_token_iterator begin( line.begin(), line.end(), re ,-1);
std::sregex_token_iterator end;
std::vector<std::string> tokens;
std::copy(begin, end, std::back_inserter(tokens));
// Remove initial empty tokens
while ((tokens.size() > 0)&&(tokens[0] == "")) {
tokens.erase(tokens.begin());
}
// Empty lines are valid but are not processed
if (tokens.size() == 0) return;
auto search = functions.find(tokens[0]);
if (search == functions.end()) {
pushNumber(tokens);
} else {
if (tokens.size() != 1) {
throw std::runtime_error("Functions does not accept parameters");
}
callFunction(search->second);
}
}
int main(void)
{
Fr_init();
fillMap();
std::string line;
int i=0;
while (std::getline(std::cin, line)) {
processLine(line);
// if (i%1000 == 0) printf("%d\n", i);
// printf("%d\n", i);
i++;
}
// Print the elements in the stack
//
for (int i=0; i<stack.size(); i++) {
char *s;
s = Fr_element2str(&stack[i]);
printf("%s\n", s);
free(s);
}
return EXIT_SUCCESS;
}
/*
#include <stdlib.h>
#include <string.h>
#include "fr.h"
typedef void (*Func2)(PFrElement, PFrElement, PFrElement);
typedef struct {
const char *fnName;
Func2 fn;
} FN;
#define NFN 2
FN fns[NFN] = {
{"add", Fr_add},
{"mul", Fr_mul},
};
int main(int argc, char **argv) {
if (argc <= 1) {
fprintf( stderr, "invalid number of parameters");
return 1;
}
for (int i=0; i< NFN;i++) {
if (strcmp(argv[1], fns[i].fnName) == 0) {
if (argc != 4) {
fprintf( stderr, "invalid number of parameters");
return 1;
}
FrElement a;
FrElement b;
Fr_str2element(&a, argv[2]);
Fr_str2element(&b, argv[3]);
FrElement c;
fns[i].fn(&c, &a, &b);
char *s;
s = Fr_element2str(&c);
printf("%s", s);
free(s);
return 0;
}
}
fprintf( stderr, "invalid operation %s", argv[1]);
return 1;
}
*/

View File

@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple Computer//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>English</string>
<key>CFBundleIdentifier</key>
<string>com.apple.xcode.dsym.tester</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundlePackageType</key>
<string>dSYM</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleVersion</key>
<string>1</string>
</dict>
</plist>

View File

@@ -0,0 +1,72 @@
<% global.setTypeDest = function (t) {
return (
` mov r11b, ${t}
shl r11, 56
mov [rdi], r11`);
} %>
<% global.toMont_a = function () {
return (
` push rdi
mov rdi, rsi
mov rsi, rdx
call ${name}_toMontgomery
mov rdx, rsi
mov rsi, rdi
pop rdi`);
} %>
<% global.toMont_b = function() {
return (
` push rdi
mov rdi, rdx
call ${name}_toMontgomery
mov rdx, rdi
pop rdi`);
} %>
<% global.fromMont_a = function () {
return (
` push rdi
mov rdi, rsi
mov rsi, rdx
call ${name}_toNormal
mov rdx, rsi
mov rsi, rdi
pop rdi`);
} %>
<% global.fromMont_b = function() {
return (
` push rdi
mov rdi, rdx
call ${name}_toNormal
mov rdx, rdi
pop rdi`);
} %>
<% global.toLong_a = function () {
return (
` push rdi
push rdx
mov rdi, rsi
movsx rsi, r8d
call rawCopyS2L
mov rsi, rdi
pop rdx
pop rdi`);
} %>
<% global.toLong_b = function() {
return (
` push rdi
push rsi
mov rdi, rdx
movsx rsi, r9d
call rawCopyS2L
mov rdx, rdi
pop rsi
pop rdi`);
} %>

627
ports/c/builder.js Normal file
View File

@@ -0,0 +1,627 @@
const streamFromMultiArray = require("../../src/streamfromarray_txt.js");
const bigInt = require("big-integer");
const utils = require("../../src/utils");
const assert = require("assert");
function ref2src(c) {
if ((c[0] == "R")||(c[0] == "RI")) {
return c[1];
} else if (c[0] == "V") {
return c[1].toString();
} else if (c[0] == "C") {
return `(ctx->circuit->constants + ${c[1]})`;
} else if (c[0] == "CC") {
return "__cIdx";
} else {
assert(false);
}
}
class CodeBuilderC {
constructor() {
this.ops = [];
}
addComment(comment) {
this.ops.push({op: "COMMENT", comment});
}
addBlock(block) {
this.ops.push({op: "BLOCK", block});
}
calcOffset(dLabel, offsets) {
this.ops.push({op: "CALCOFFSETS", dLabel, offsets});
}
assign(dLabel, src, sOffset) {
this.ops.push({op: "ASSIGN", dLabel, src, sOffset});
}
getSubComponentOffset(dLabel, component, hash, hashLabel) {
this.ops.push({op: "GETSUBCOMPONENTOFFSET", dLabel, component, hash, hashLabel});
}
getSubComponentSizes(dLabel, component, hash, hashLabel) {
this.ops.push({op: "GETSUBCOMPONENTSIZES", dLabel, component, hash, hashLabel});
}
getSignalOffset(dLabel, component, hash, hashLabel) {
this.ops.push({op: "GETSIGNALOFFSET", dLabel, component, hash, hashLabel});
}
getSignalSizes(dLabel, component, hash, hashLabel) {
this.ops.push({op: "GETSIGNALSIZES", dLabel, component, hash, hashLabel});
}
setSignal(component, signal, value) {
this.ops.push({op: "SETSIGNAL", component, signal, value});
}
getSignal(dLabel, component, signal) {
this.ops.push({op: "GETSIGNAL", dLabel, component, signal});
}
copyN(dLabel, offset, src, n) {
this.ops.push({op: "COPYN", dLabel, offset, src, n});
}
copyNRet(src, n) {
this.ops.push({op: "COPYNRET", src, n});
}
fieldOp(dLabel, fOp, params) {
this.ops.push({op: "FOP", dLabel, fOp, params});
}
ret() {
this.ops.push({op: "RET"});
}
addLoop(condLabel, body) {
this.ops.push({op: "LOOP", condLabel, body});
}
addIf(condLabel, thenCode, elseCode) {
this.ops.push({op: "IF", condLabel, thenCode, elseCode});
}
fnCall(fnName, retLabel, params) {
this.ops.push({op: "FNCALL", fnName, retLabel, params});
}
checkConstraint(a, b, strErr) {
this.ops.push({op: "CHECKCONSTRAINT", a, b, strErr});
}
log(val) {
this.ops.push({op: "LOG", val});
}
concat(cb) {
this.ops.push(...cb.ops);
}
hasCode() {
for (let i=0; i<this.ops.length; i++) {
if (this.ops[i].op != "COMMENT") return true;
}
return false;
}
_buildOffset(offsets) {
let rN=0;
let S = "";
offsets.forEach((o) => {
if ((o[0][0] == "V") && (o[1][0]== "V")) {
rN += o[0][1]*o[1][1];
return;
}
let f="";
if (o[0][0] == "V") {
if (o[0][1]==0) return;
f += o[0][1];
} else if (o[0][0] == "RI") {
if (o[0][1]==0) return;
f += o[0][1];
} else if (o[0][0] == "R") {
f += `Fr_toInt(${o[0][1]})`;
} else {
assert(false);
}
if (o[1][0] == "V") {
if (o[1][1]==0) return;
if (o[1][1]>1) {
f += "*" + o[1][1];
}
} else if (o[1][0] == "RS") {
f += `*${o[1][1]}[${o[1][2]}]`;
} else {
assert(false);
}
if (S!="") S+= " + ";
S += f;
});
if (rN>0) {
S = `${rN} + ${S}`;
}
return S;
}
build(code) {
this.ops.forEach( (o) => {
if (o.op == "COMMENT") {
code.push(`/* ${o.comment} */`);
} else if (o.op == "BLOCK") {
const codeBlock=[];
o.block.build(codeBlock);
code.push(utils.ident(codeBlock));
} else if (o.op == "CALCOFFSETS") {
code.push(`${o.dLabel} = ${this._buildOffset(o.offsets)};`);
} else if (o.op == "ASSIGN") {
const oS = ref2src(o.sOffset);
if (oS != "0") {
code.push(`${o.dLabel} = ${ref2src(o.src)} + ${oS};`);
} else {
code.push(`${o.dLabel} = ${ref2src(o.src)};`);
}
} else if (o.op == "GETSUBCOMPONENTOFFSET") {
code.push(`${o.dLabel} = ctx->getSubComponentOffset(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
} else if (o.op == "GETSUBCOMPONENTSIZES") {
code.push(`${o.dLabel} = ctx->getSubComponentSizes(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
} else if (o.op == "GETSIGNALOFFSET") {
code.push(`${o.dLabel} = ctx->getSignalOffset(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
} else if (o.op == "GETSIGNALSIZES") {
code.push(`${o.dLabel} = ctx->getSignalSizes(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
} else if (o.op == "SETSIGNAL") {
code.push(`ctx->setSignal(__cIdx, ${ref2src(o.component)}, ${ref2src(o.signal)}, ${ref2src(o.value)});`);
} else if (o.op == "GETSIGNAL") {
code.push(`ctx->getSignal(__cIdx, ${ref2src(o.component)}, ${ref2src(o.signal)}, ${o.dLabel});`);
} else if (o.op == "COPYN") {
const oS = ref2src(o.offset);
const dLabel = (oS != "0") ? (o.dLabel + "+" + oS) : o.dLabel;
code.push(`Fr_copyn(${dLabel}, ${ref2src(o.src)}, ${o.n});`);
} else if (o.op == "COPYNRET") {
code.push(`Fr_copyn(__retValue, ${ref2src(o.src)}, ${o.n});`);
} else if (o.op == "RET") {
code.push("goto returnFunc;");
} else if (o.op == "FOP") {
let paramsS = "";
for (let i=0; i<o.params.length; i++) {
if (i>0) paramsS += ", ";
paramsS += ref2src(o.params[i]);
}
code.push(`Fr_${o.fOp}(${o.dLabel}, ${paramsS});`);
} else if (o.op == "LOOP") {
code.push(`while (Fr_isTrue(${o.condLabel})) {`);
const body = [];
o.body.build(body);
code.push(utils.ident(body));
code.push("}");
} else if (o.op == "IF") {
code.push(`if (Fr_isTrue(${o.condLabel})) {`);
const thenCode = [];
o.thenCode.build(thenCode);
code.push(utils.ident(thenCode));
if (o.elseCode) {
code.push("} else {");
const elseCode = [];
o.elseCode.build(elseCode);
code.push(utils.ident(elseCode));
}
code.push("}");
} else if (o.op == "FNCALL") {
code.push(`${o.fnName}(ctx, ${o.retLabel}, ${o.params.join(",")});`);
} else if (o.op == "CHECKCONSTRAINT") {
code.push(`ctx->checkConstraint(__cIdx, ${ref2src(o.a)}, ${ref2src(o.b)}, "${o.strErr}");`);
} else if (o.op == "LOG") {
code.push(`ctx->log(${ref2src(o.val)});`);
}
});
}
}
class FunctionBuilderC {
constructor(name, instanceDef, type) {
this.name = name;
this.instanceDef = instanceDef;
this.type = type; // "COMPONENT" or "FUNCTIOM"
this.definedFrElements = [];
this.definedIntElements = [];
this.definedSizeElements = [];
this.definedPFrElements = [];
this.initializedElements = [];
this.initializedSignalOffset = [];
this.initializedSignalSizes = [];
}
defineFrElements(dLabel, size) {
this.definedFrElements.push({dLabel, size});
}
defineIntElement(dLabel) {
this.definedIntElements.push({dLabel});
}
defineSizesElement(dLabel) {
this.definedSizeElements.push({dLabel});
}
definePFrElement(dLabel) {
this.definedPFrElements.push({dLabel});
}
initializeFrElement(dLabel, offset, idConstant) {
this.initializedElements.push({dLabel, offset, idConstant});
}
initializeSignalOffset(dLabel, component, hash, hashLabel) {
this.initializedSignalOffset.push({dLabel, component, hash, hashLabel});
}
initializeSignalSizes(dLabel, component, hash, hashLabel) {
this.initializedSignalSizes.push({dLabel, component, hash, hashLabel});
}
setParams(params) {
this.params = params;
}
_buildHeader(code) {
this.definedFrElements.forEach( (o) => {
code.push(`FrElement ${o.dLabel}[${o.size}];`);
});
this.definedIntElements.forEach( (o) => {
code.push(`int ${o.dLabel};`);
});
this.definedSizeElements.forEach( (o) => {
code.push(`Circom_Sizes ${o.dLabel};`);
});
this.definedPFrElements.forEach( (o) => {
code.push(`PFrElement ${o.dLabel};`);
});
this.initializedElements.forEach( (o) => {
code.push(`Fr_copy(&(${o.dLabel}[${o.offset}]), ctx->circuit->constants +${o.idConstant});`);
});
this.initializedSignalOffset.forEach( (o) => {
code.push(`${o.dLabel} = ctx->getSignalOffset(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
});
this.initializedSignalSizes.forEach( (o) => {
code.push(`${o.dLabel} = ctx->getSignalSizes(${ref2src(o.component)}, 0x${o.hash}LL /* ${o.hashLabel} */);`);
});
}
_buildFooter(code) {
}
newCodeBuilder() {
return new CodeBuilderC();
}
setBody(body) {
this.body = body;
}
build(code) {
code.push(
"/*",
this.instanceDef,
"*/"
);
if (this.type=="COMPONENT") {
code.push(`void ${this.name}(Circom_CalcWit *ctx, int __cIdx) {`);
} else if (this.type=="FUNCTION") {
let sParams = "";
for (let i=0;i<this.params.length;i++ ) sParams += `, PFrElement ${this.params[i]}`;
code.push(`void ${this.name}(Circom_CalcWit *ctx, PFrElement __retValue ${sParams}) {`);
} else {
assert(false);
}
const fnCode = [];
this._buildHeader(fnCode);
this.body.build(fnCode);
if (this.type=="COMPONENT") {
fnCode.push("ctx->finished(__cIdx);");
} else if (this.type=="FUNCTION") {
fnCode.push("returnFunc: ;");
} else {
assert(false);
}
this._buildFooter(fnCode);
code.push(utils.ident(fnCode));
code.push("}");
}
}
class BuilderC {
constructor() {
this.hashMaps={};
this.componentEntriesTables={};
this.sizes ={};
this.constants = [];
this.functions = [];
this.components = [];
this.usedConstants = {};
}
setHeader(header) {
this.header=header;
}
// ht is an array of 256 element that can be undefined or [Hash, Idx, KeyName] elements.
addHashMap(name, hm) {
this.hashMaps[name] = hm;
}
addComponentEntriesTable(name, cet) {
this.componentEntriesTables[name] = cet;
}
addSizes(name, accSizes) {
this.sizes[name] = accSizes;
}
addConstant(c) {
c = bigInt(c);
const cS = c.toString();
if (this.usedConstants[cS]) return this.usedConstants[cS];
this.constants.push(c);
this.usedConstants[cS] = this.constants.length - 1;
return this.constants.length - 1;
}
addFunction(fnBuilder) {
this.functions.push(fnBuilder);
}
addComponent(component) {
this.components.push(component);
}
setMapIsInput(map) {
this.mapIsInput = map;
}
setWit2Sig(wit2sig) {
this.wit2sig = wit2sig;
}
newComponentFunctionBuilder(name, instanceDef) {
return new FunctionBuilderC(name, instanceDef, "COMPONENT");
}
newFunctionBuilder(name, instanceDef) {
return new FunctionBuilderC(name, instanceDef, "FUNCTION");
}
// Body functions
_buildHeader(code) {
code.push(
"#include \"circom.h\"",
"#include \"calcwit.h\"",
`#define NSignals ${this.header.NSignals}`,
`#define NComponents ${this.header.NComponents}`,
`#define NOutputs ${this.header.NOutputs}`,
`#define NInputs ${this.header.NInputs}`,
`#define NVars ${this.header.NVars}`,
`#define __P__ "${this.header.P.toString()}"`,
""
);
}
_buildHashMaps(code) {
code.push("// Hash Maps ");
for (let hmName in this.hashMaps ) {
const hm = this.hashMaps[hmName];
let c = `Circom_HashEntry ${hmName}[256] = {`;
for (let i=0; i<256; i++) {
c += i>0 ? "," : "";
if (hm[i]) {
c += `{0x${hm[i][0]}LL, ${hm[i][1]}} /* ${hm[i][2]} */`;
} else {
c += "{0,0}";
}
}
c += "};";
code.push(c);
}
}
_buildComponentEntriesTables(code) {
code.push("// Component Entry tables");
for (let cetName in this.componentEntriesTables) {
const cet = this.componentEntriesTables[cetName];
code.push(`Circom_ComponentEntry ${cetName}[${cet.length}] = {`);
for (let j=0; j<cet.length; j++) {
const ty = cet[j].type == "S" ? "_typeSignal" : "_typeComponent";
code.push(` ${j>0?",":" "}{${cet[j].offset},${cet[j].sizeName}, ${ty}}`);
}
code.push("};");
}
}
_buildSizes(code) {
code.push("// Sizes");
for (let sName in this.sizes) {
const accSizes = this.sizes[sName];
let c = `Circom_Size ${sName}[${accSizes.length}] = {`;
for (let i=0; i<accSizes.length; i++) {
if (i>0) c += ",";
c += accSizes[i];
}
c += "};";
code.push(c);
}
}
_buildConstants(code) {
const self = this;
const n64 = Math.floor((self.header.P.bitLength() - 1) / 64)+1;
const R = bigInt.one.shiftLeft(n64*64);
code.push("// Constants");
code.push(`FrElement _constants[${self.constants.length}] = {`);
for (let i=0; i<self.constants.length; i++) {
code.push((i>0 ? "," : " ") + "{" + number2Code(self.constants[i]) + "}");
}
code.push("};");
function number2Code(n) {
if (n.lt(bigInt("80000000", 16)) ) {
return addShortMontgomeryPositive(n);
}
if (n.geq(self.header.P.minus(bigInt("80000000", 16))) ) {
return addShortMontgomeryNegative(n);
}
return addLongMontgomery(n);
function addShortMontgomeryPositive(a) {
return `${a.toString()}, 0x40000000, { ${getLongString(toMontgomery(a))} }`;
}
function addShortMontgomeryNegative(a) {
const b = a.minus(self.header.P);
return `${b.toString()}, 0x40000000, { ${getLongString(toMontgomery(a))} }`;
}
function addLongMontgomery(a) {
return `0, 0xC0000000, { ${getLongString(toMontgomery(a))} }`;
}
function getLongString(a) {
let r = bigInt(a);
let S = "";
let i = 0;
while (!r.isZero()) {
if (S!= "") S = S+",";
S += "0x" + r.and(bigInt("FFFFFFFFFFFFFFFF", 16)).toString(16) + "LL";
i++;
r = r.shiftRight(64);
}
while (i<n64) {
if (S!= "") S = S+",";
S += "0LL";
i++;
}
return S;
}
function toMontgomery(a) {
return a.times(R).mod(self.header.P);
}
}
}
_buildFunctions(code) {
for (let i=0; i<this.functions.length; i++) {
const cfb = this.functions[i];
cfb.build(code);
}
}
_buildComponents(code) {
code.push("// Components");
code.push(`Circom_Component _components[${this.components.length}] = {`);
for (let i=0; i<this.components.length; i++) {
const c = this.components[i];
const sep = i>0 ? " ," : " ";
code.push(`${sep}{${c.hashMapName}, ${c.entryTableName}, ${c.functionName}, ${c.nInSignals}, ${c.newThread}}`);
}
code.push("};");
}
_buildMapIsInput(code) {
code.push("// mapIsInput");
code.push(`u32 _mapIsInput[${this.mapIsInput.length}] = {`);
let line = "";
for (let i=0; i<this.mapIsInput.length; i++) {
line += i>0 ? ", " : " ";
line += toHex(this.mapIsInput[i]);
if (((i+1) % 64)==0) {
code.push(" "+line);
line = "";
}
}
if (line != "") code.push(" "+line);
code.push("};");
function toHex(number) {
if (number < 0) number = 0xFFFFFFFF + number + 1;
let S=number.toString(16).toUpperCase();
while (S.length<8) S = "0" + S;
return "0x"+S;
}
}
_buildWit2Sig(code) {
code.push("// Witness to Signal Table");
code.push(`int _wit2sig[${this.wit2sig.length}] = {`);
let line = "";
for (let i=0; i<this.wit2sig.length; i++) {
line += i>0 ? "," : " ";
line += this.wit2sig[i];
if (((i+1) % 64) == 0) {
code.push(" "+line);
line = "";
}
}
if (line != "") code.push(" "+line);
code.push("};");
}
_buildCircuitVar(code) {
code.push(
"// Circuit Variable",
"Circom_Circuit _circuit = {" ,
" NSignals,",
" NComponents,",
" NInputs,",
" NOutputs,",
" NVars,",
" _wit2sig,",
" _components,",
" _mapIsInput,",
" _constants,",
" __P__",
"};"
);
}
build() {
const code=[];
this._buildHeader(code);
this._buildSizes(code);
this._buildConstants(code);
this._buildHashMaps(code);
this._buildComponentEntriesTables(code);
this._buildFunctions(code);
this._buildComponents(code);
this._buildMapIsInput(code);
this._buildWit2Sig(code);
this._buildCircuitVar(code);
return streamFromMultiArray(code);
}
}
module.exports = BuilderC;

234
ports/c/calcwit.cpp Normal file
View File

@@ -0,0 +1,234 @@
#include <string>
#include <stdexcept>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <stdlib.h>
#include <assert.h>
#include <stdarg.h>
#include <thread>
#include "calcwit.h"
#include "utils.h"
Circom_CalcWit::Circom_CalcWit(Circom_Circuit *aCircuit) {
circuit = aCircuit;
#ifdef SANITY_CHECK
signalAssigned = new bool[circuit->NSignals];
signalAssigned[0] = true;
#endif
mutexes = new std::mutex[NMUTEXES];
cvs = new std::condition_variable[NMUTEXES];
inputSignalsToTrigger = new int[circuit->NComponents];
signalValues = new FrElement[circuit->NSignals];
// Set one signal
Fr_copy(&signalValues[0], circuit->constants + 1);
reset();
}
Circom_CalcWit::~Circom_CalcWit() {
#ifdef SANITY_CHECK
delete signalAssigned;
#endif
delete[] cvs;
delete[] mutexes;
delete[] signalValues;
delete[] inputSignalsToTrigger;
}
void Circom_CalcWit::syncPrintf(const char *format, ...) {
va_list args;
va_start(args, format);
printf_mutex.lock();
vprintf(format, args);
printf_mutex.unlock();
va_end(args);
}
void Circom_CalcWit::reset() {
#ifdef SANITY_CHECK
for (int i=1; i<circuit->NComponents; i++) signalAssigned[i] = false;
#endif
for (int i=0; i<circuit->NComponents; i++) {
inputSignalsToTrigger[i] = circuit->components[i].inputSignals;
}
for (int i=0; i<circuit->NComponents; i++) {
if (inputSignalsToTrigger[i] == 0) triggerComponent(i);
}
}
int Circom_CalcWit::getSubComponentOffset(int cIdx, u64 hash) {
int hIdx;
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
}
int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
throw std::runtime_error("invalid type");
}
return circuit->components[cIdx].entries[entryPos].offset;
}
Circom_Sizes Circom_CalcWit::getSubComponentSizes(int cIdx, u64 hash) {
int hIdx;
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
}
int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
if (circuit->components[cIdx].entries[entryPos].type != _typeComponent) {
throw std::runtime_error("invalid type");
}
return circuit->components[cIdx].entries[entryPos].sizes;
}
int Circom_CalcWit::getSignalOffset(int cIdx, u64 hash) {
int hIdx;
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
}
int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
throw std::runtime_error("invalid type");
}
return circuit->components[cIdx].entries[entryPos].offset;
}
Circom_Sizes Circom_CalcWit::getSignalSizes(int cIdx, u64 hash) {
int hIdx;
for(hIdx = int(hash & 0xFF); hash!=circuit->components[cIdx].hashTable[hIdx].hash; hIdx++) {
if (!circuit->components[cIdx].hashTable[hIdx].hash) throw std::runtime_error("hash not found: " + int_to_hex(hash));
}
int entryPos = circuit->components[cIdx].hashTable[hIdx].pos;
if (circuit->components[cIdx].entries[entryPos].type != _typeSignal) {
throw std::runtime_error("invalid type");
}
return circuit->components[cIdx].entries[entryPos].sizes;
}
void Circom_CalcWit::getSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
// syncPrintf("getSignal: %d\n", sIdx);
if ((circuit->components[cIdx].newThread)&&(currentComponentIdx != cIdx)) {
std::unique_lock<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
while (inputSignalsToTrigger[cIdx] != -1) {
cvs[cIdx % NMUTEXES].wait(lk);
}
// cvs[cIdx % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[cIdx] == -1;});
lk.unlock();
}
#ifdef SANITY_CHECK
if (signalAssigned[sIdx] == false) {
fprintf(stderr, "Accessing a not assigned signal: %d\n", sIdx);
assert(false);
}
#endif
Fr_copy(value, signalValues + sIdx);
/*
char *valueStr = mpz_get_str(0, 10, *value);
syncPrintf("%d, Get %d --> %s\n", currentComponentIdx, sIdx, valueStr);
free(valueStr);
*/
}
void Circom_CalcWit::finished(int cIdx) {
{
std::lock_guard<std::mutex> lk(mutexes[cIdx % NMUTEXES]);
inputSignalsToTrigger[cIdx] = -1;
}
// syncPrintf("Finished: %d\n", cIdx);
cvs[cIdx % NMUTEXES].notify_all();
}
void Circom_CalcWit::setSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value) {
// syncPrintf("setSignal: %d\n", sIdx);
#ifdef SANITY_CHECK
if (signalAssigned[sIdx] == true) {
fprintf(stderr, "Signal assigned twice: %d\n", sIdx);
assert(false);
}
signalAssigned[sIdx] = true;
#endif
// Log assignement
/*
char *valueStr = mpz_get_str(0, 10, *value);
syncPrintf("%d, Set %d --> %s\n", currentComponentIdx, sIdx, valueStr);
free(valueStr);
*/
Fr_copy(signalValues + sIdx, value);
if ( BITMAP_ISSET(circuit->mapIsInput, sIdx) ) {
if (inputSignalsToTrigger[cIdx]>0) {
inputSignalsToTrigger[cIdx]--;
if (inputSignalsToTrigger[cIdx] == 0) triggerComponent(cIdx);
} else {
fprintf(stderr, "Input signals does not match with map: %d\n", sIdx);
assert(false);
}
}
}
void Circom_CalcWit::checkConstraint(int currentComponentIdx, PFrElement value1, PFrElement value2, char const *err) {
#ifdef SANITY_CHECK
FrElement tmp;
Fr_eq(&tmp, value1, value2);
if (!Fr_isTrue(&tmp)) {
char *pcV1 = Fr_element2str(value1);
char *pcV2 = Fr_element2str(value2);
// throw std::runtime_error(std::to_string(currentComponentIdx) + std::string(", Constraint doesn't match, ") + err + ". " + sV1 + " != " + sV2 );
fprintf(stderr, "Constraint doesn't match, %s: %s != %s", err, pcV1, pcV2);
free(pcV1);
free(pcV2);
assert(false);
}
#endif
}
void Circom_CalcWit::triggerComponent(int newCIdx) {
//int oldCIdx = cIdx;
// cIdx = newCIdx;
if (circuit->components[newCIdx].newThread) {
// syncPrintf("Triggered: %d\n", newCIdx);
std::thread t(circuit->components[newCIdx].fn, this, newCIdx);
// t.join();
t.detach();
} else {
(*(circuit->components[newCIdx].fn))(this, newCIdx);
}
// cIdx = oldCIdx;
}
void Circom_CalcWit::log(PFrElement value) {
char *pcV = Fr_element2str(value);
syncPrintf("Log: %s\n", pcV);
free(pcV);
}
void Circom_CalcWit::join() {
for (int i=0; i<circuit->NComponents; i++) {
std::unique_lock<std::mutex> lk(mutexes[i % NMUTEXES]);
while (inputSignalsToTrigger[i] != -1) {
cvs[i % NMUTEXES].wait(lk);
}
// cvs[i % NMUTEXES].wait(lk, [&]{return inputSignalsToTrigger[i] == -1;});
lk.unlock();
// syncPrintf("Joined: %d\n", i);
}
}

73
ports/c/calcwit.h Normal file
View File

@@ -0,0 +1,73 @@
#ifndef CIRCOM_CALCWIT_H
#define CIRCOM_CALCWIT_H
#include "circom.h"
#include "fr.h"
#include <mutex>
#include <condition_variable>
#define NMUTEXES 128
class Circom_CalcWit {
#ifdef SANITY_CHECK
bool *signalAssigned;
#endif
// componentStatus -> For each component
// >0 Signals required to trigger
// == 0 Component triggered
// == -1 Component finished
int *inputSignalsToTrigger;
std::mutex *mutexes;
std::condition_variable *cvs;
std::mutex printf_mutex;
FrElement *signalValues;
void triggerComponent(int newCIdx);
void calculateWitness(void *input, void *output);
void syncPrintf(const char *format, ...);
public:
Circom_Circuit *circuit;
// Functions called by the circuit
Circom_CalcWit(Circom_Circuit *aCircuit);
~Circom_CalcWit();
int getSubComponentOffset(int cIdx, u64 hash);
Circom_Sizes getSubComponentSizes(int cIdx, u64 hash);
int getSignalOffset(int cIdx, u64 hash);
Circom_Sizes getSignalSizes(int cIdx, u64 hash);
void getSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value);
void setSignal(int currentComponentIdx, int cIdx, int sIdx, PFrElement value);
void checkConstraint(int currentComponentIdx, PFrElement value1, PFrElement value2, char const *err);
void log(PFrElement value);
void finished(int cIdx);
void join();
// Public functions
inline void setInput(int idx, PFrElement val) {
setSignal(0, 0, circuit->wit2sig[idx], val);
}
inline void getWitness(int idx, PFrElement val) {
Fr_copy(val, &signalValues[circuit->wit2sig[idx]]);
}
void reset();
};
#endif // CIRCOM_CALCWIT_H

58
ports/c/circom.h Normal file
View File

@@ -0,0 +1,58 @@
#ifndef __CIRCOM_H
#define __CIRCOM_H
#include <gmp.h>
#include <stdint.h>
#include "fr.h"
class Circom_CalcWit;
typedef unsigned long long u64;
typedef uint32_t u32;
typedef uint8_t u8;
typedef int Circom_Size;
typedef Circom_Size *Circom_Sizes;
struct Circom_HashEntry {
u64 hash;
int pos;
};
typedef Circom_HashEntry *Circom_HashTable;
typedef enum { _typeSignal, _typeComponent} Circom_EntryType;
struct Circom_ComponentEntry {
int offset;
Circom_Sizes sizes;
Circom_EntryType type;
};
typedef Circom_ComponentEntry *Circom_ComponentEntries;
typedef void (*Circom_ComponentFunction)(Circom_CalcWit *ctx, int __cIdx);
struct Circom_Component {
Circom_HashTable hashTable;
Circom_ComponentEntries entries;
Circom_ComponentFunction fn;
int inputSignals;
bool newThread;
};
class Circom_Circuit {
public:
int NSignals;
int NComponents;
int NInputs;
int NOutputs;
int NVars;
int *wit2sig;
Circom_Component *components;
u32 *mapIsInput;
PFrElement constants;
const char *P;
};
#define BITMAP_ISSET(m, b) (m[b>>5] & (1 << (b&0x1F)))
extern struct Circom_Circuit _circuit;
#endif

1
ports/c/fr.c Symbolic link
View File

@@ -0,0 +1 @@
buildasm/fr.c

1
ports/c/fr.h Symbolic link
View File

@@ -0,0 +1 @@
buildasm/fr.h

1
ports/c/fr.o Symbolic link
View File

@@ -0,0 +1 @@
buildasm/fr.o

202
ports/c/main.cpp Normal file
View File

@@ -0,0 +1,202 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <iomanip>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
#include "calcwit.h"
#include "circom.h"
#include "utils.h"
#define handle_error(msg) \
do { perror(msg); exit(EXIT_FAILURE); } while (0)
void loadBin(Circom_CalcWit *ctx, std::string filename) {
int fd;
struct stat sb;
// map input
fd = open(filename.c_str(), O_RDONLY);
if (fd == -1)
handle_error("open");
if (fstat(fd, &sb) == -1) /* To obtain file size */
handle_error("fstat");
u8 *in;
in = (u8 *)mmap(NULL, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (in == MAP_FAILED)
handle_error("mmap");
close(fd);
FrElement v;
u8 *p = in;
for (int i=0; i<_circuit.NInputs; i++) {
v.type = Fr_LONG;
for (int j=0; j<Fr_N64; j++) {
v.longVal[j] = *(u64 *)p;
}
p += 8;
ctx->setSignal(0, 0, _circuit.wit2sig[1 + _circuit.NOutputs + i], &v);
}
}
typedef void (*ItFunc)(Circom_CalcWit *ctx, int idx, json val);
void iterateArr(Circom_CalcWit *ctx, int o, Circom_Sizes sizes, json jarr, ItFunc f) {
if (!jarr.is_array()) {
assert((sizes[0] == 1)&&(sizes[1] == 0));
f(ctx, o, jarr);
} else {
int n = sizes[0] / sizes[1];
for (int i=0; i<n; i++) {
iterateArr(ctx, o + i*sizes[1], sizes+1, jarr[i], f);
}
}
}
void itFunc(Circom_CalcWit *ctx, int o, json val) {
FrElement v;
std::string s;
if (val.is_string()) {
s = val.get<std::string>();
} else if (val.is_number()) {
double vd = val.get<double>();
std::stringstream stream;
stream << std::fixed << std::setprecision(0) << vd;
s = stream.str();
} else {
handle_error("Invalid JSON type");
}
Fr_str2element (&v, s.c_str());
ctx->setSignal(0, 0, o, &v);
}
void loadJson(Circom_CalcWit *ctx, std::string filename) {
std::ifstream inStream(filename);
json j;
inStream >> j;
for (json::iterator it = j.begin(); it != j.end(); ++it) {
// std::cout << it.key() << " => " << it.value() << '\n';
u64 h = fnv1a(it.key());
int o;
try {
o = ctx->getSignalOffset(0, h);
} catch (std::runtime_error e) {
std::ostringstream errStrStream;
errStrStream << "Error loadin variable: " << it.key() << "\n" << e.what();
throw std::runtime_error(errStrStream.str() );
}
Circom_Sizes sizes = ctx->getSignalSizes(0, h);
iterateArr(ctx, o, sizes, it.value(), itFunc);
}
}
void writeOutBin(Circom_CalcWit *ctx, std::string filename) {
FILE *write_ptr;
write_ptr = fopen(filename.c_str(),"wb");
FrElement v;
u8 buffOut[256];
for (int i=0;i<_circuit.NVars;i++) {
size_t size=256;
ctx->getWitness(i, &v);
Fr_toLongNormal(&v);
fwrite(v.longVal, Fr_N64*8, 1, write_ptr);
}
fclose(write_ptr);
}
void writeOutJson(Circom_CalcWit *ctx, std::string filename) {
std::ofstream outFile;
outFile.open (filename);
outFile << "[\n";
FrElement v;
for (int i=0;i<_circuit.NVars;i++) {
ctx->getWitness(i, &v);
char *pcV = Fr_element2str(&v);
std::string sV = std::string(pcV);
outFile << (i ? "," : " ") << "\"" << sV << "\"\n";
free(pcV);
}
outFile << "]\n";
outFile.close();
}
bool hasEnding (std::string const &fullString, std::string const &ending) {
if (fullString.length() >= ending.length()) {
return (0 == fullString.compare (fullString.length() - ending.length(), ending.length(), ending));
} else {
return false;
}
}
int main(int argc, char *argv[]) {
Fr_init();
if (argc!=3) {
std::string cl = argv[0];
std::string base_filename = cl.substr(cl.find_last_of("/\\") + 1);
std::cout << "Usage: " << base_filename << " <input.<bin|json>> <output.<bin|json>>\n";
} else {
// open output
Circom_CalcWit *ctx = new Circom_CalcWit(&_circuit);
std::string infilename = argv[1];
if (hasEnding(infilename, std::string(".bin"))) {
loadBin(ctx, infilename);
} else if (hasEnding(infilename, std::string(".json"))) {
loadJson(ctx, infilename);
} else {
handle_error("Invalid input extension (.bin / .json)");
}
ctx->join();
printf("Finished!\n");
std::string outfilename = argv[2];
if (hasEnding(outfilename, std::string(".bin"))) {
writeOutBin(ctx, outfilename);
} else if (hasEnding(outfilename, std::string(".json"))) {
writeOutJson(ctx, outfilename);
} else {
handle_error("Invalid output extension (.bin / .json)");
}
delete ctx;
exit(EXIT_SUCCESS);
}
}

47
ports/c/mainjson.cpp Normal file
View File

@@ -0,0 +1,47 @@
#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
#include "utils.h"
#include "circom.h"
#include "calcwit.h"
auto j = R"(
{
"in": "314"
}
)"_json;
typedef void (*ItFunc)(int idx, json val);
void iterateArr(int o, Circom_Sizes sizes, json jarr, ItFunc f) {
if (!jarr.is_array()) {
assert((sizes[0] == 1)&&(sizes[1] == 0));
f(o, jarr);
} else {
int n = sizes[0] / sizes[1];
for (int i=0; i<n; i++) {
iterateArr(o + i*sizes[1], sizes+1, jarr[i], f);
}
}
}
void itFunc(int o, json v) {
std::cout << o << " <-- " << v << '\n';
}
int main(int argc, char **argv) {
Circom_CalcWit *ctx = new Circom_CalcWit(&_circuit);
for (json::iterator it = j.begin(); it != j.end(); ++it) {
// std::cout << it.key() << " => " << it.value() << '\n';
u64 h = fnv1a(it.key());
int o = ctx->getSignalOffset(0, h);
Circom_Sizes sizes = ctx->getSignalSizes(0, h);
iterateArr(o, sizes, it.value(), itFunc);
}
}

186
ports/c/tester.js Normal file
View File

@@ -0,0 +1,186 @@
const chai = require("chai");
const assert = chai.assert;
const fs = require("fs");
var tmp = require("tmp-promise");
const path = require("path");
const compiler = require("../../src/compiler");
const util = require("util");
const exec = util.promisify(require("child_process").exec);
const stringifyBigInts = require("../../src/utils").stringifyBigInts;
const unstringifyBigInts = require("../../src/utils").unstringifyBigInts;
const bigInt = require("big-integer");
const utils = require("../../src/utils");
const loadR1cs = require("../../src/r1csfile").loadR1cs;
const ZqField = require("fflib").ZqField;
module.exports = c_tester;
async function c_tester(circomFile, _options) {
tmp.setGracefulCleanup();
const dir = await tmp.dir({prefix: "circom_", unsafeCleanup: true });
// console.log(dir.path);
const baseName = path.basename(circomFile, ".circom");
const options = Object.assign({}, _options);
options.cSourceWriteStream = fs.createWriteStream(path.join(dir.path, baseName + ".cpp"));
options.symWriteStream = fs.createWriteStream(path.join(dir.path, baseName + ".sym"));
options.r1csFileName = path.join(dir.path, baseName + ".r1cs");
await compiler(circomFile, options);
const cdir = path.join(__dirname, "..", "c");
await exec("cp" +
` ${path.join(dir.path, baseName + ".cpp")}` +
" /tmp/circuit.cpp"
);
await exec("g++" +
` ${path.join(cdir, "main.cpp")}` +
` ${path.join(cdir, "calcwit.cpp")}` +
` ${path.join(cdir, "utils.cpp")}` +
` ${path.join(cdir, "fr.c")}` +
` ${path.join(cdir, "fr.o")}` +
` ${path.join(dir.path, baseName + ".cpp")} ` +
` -o ${path.join(dir.path, baseName)}` +
` -I ${cdir}` +
" -lgmp -std=c++11 -DSANITY_CHECK"
);
// console.log(dir.path);
return new CTester(dir, baseName);
}
class CTester {
constructor(dir, baseName) {
this.dir=dir;
this.baseName = baseName;
}
async release() {
await this.dir.cleanup();
}
async calculateWitness(input) {
await fs.promises.writeFile(
path.join(this.dir.path, "in.json"),
JSON.stringify(stringifyBigInts(input), null, 1)
);
await exec(`${path.join(this.dir.path, this.baseName)}` +
` ${path.join(this.dir.path, "in.json")}` +
` ${path.join(this.dir.path, "out.json")}`
);
const resStr = await fs.promises.readFile(
path.join(this.dir.path, "out.json")
);
const res = unstringifyBigInts(JSON.parse(resStr));
return res;
}
async loadSymbols() {
if (this.symbols) return;
this.symbols = {};
const symsStr = await fs.promises.readFile(
path.join(this.dir.path, this.baseName + ".sym"),
"utf8"
);
const lines = symsStr.split("\n");
for (let i=0; i<lines.length; i++) {
const arr = lines[i].split(",");
if (arr.length!=3) continue;
this.symbols[arr[2]] = {
idx: Number(arr[0]),
idxWit: Number(arr[1])
};
}
}
async loadConstraints() {
const self = this;
if (this.constraints) return;
const r1cs = await loadR1cs(path.join(this.dir.path, this.baseName + ".r1cs"),true, false);
self.field = new ZqField(r1cs.prime);
self.nWires = r1cs.nWires;
self.constraints = r1cs.constraints;
}
async assertOut(actualOut, expectedOut) {
const self = this;
if (!self.symbols) await self.loadSymbols();
checkObject("main", expectedOut);
function checkObject(prefix, eOut) {
if (Array.isArray(eOut)) {
for (let i=0; i<eOut.length; i++) {
checkObject(prefix + "["+i+"]", eOut[i]);
}
} else if ((typeof eOut == "object")&&(eOut.constructor.name == "Object")) {
for (let k in eOut) {
checkObject(prefix + "."+k, eOut[k]);
}
} else {
if (typeof self.symbols[prefix] == "undefined") {
assert(false, "Output variable not defined: "+ prefix);
}
const ba = bigInt(actualOut[self.symbols[prefix].idxWit]).toString();
const be = bigInt(eOut).toString();
assert.strictEqual(ba, be, prefix);
}
}
}
async getDecoratedOutput(witness) {
const self = this;
const lines = [];
if (!self.symbols) await self.loadSymbols();
for (let n in self.symbols) {
let v;
if (utils.isDefined(witness[self.symbols[n].idxWit])) {
v = witness[self.symbols[n].idxWit].toString();
} else {
v = "undefined";
}
lines.push(`${n} --> ${v}`);
}
return lines.join("\n");
}
async checkConstraints(witness) {
const self = this;
if (!self.constraints) await self.loadConstraints();
for (let i=0; i<self.constraints.length; i++) {
checkConstraint(self.constraints[i]);
}
function checkConstraint(constraint) {
const F = self.field;
const a = evalLC(constraint.a);
const b = evalLC(constraint.b);
const c = evalLC(constraint.c);
assert (F.sub(F.mul(a,b), c).isZero(), "Constraint doesn't match");
}
function evalLC(lc) {
const F = self.field;
let v = F.zero;
for (let w in lc) {
v = F.add(
v,
F.mul( lc[w], witness[w] )
);
}
return v;
}
}
}

25
ports/c/utils.cpp Normal file
View File

@@ -0,0 +1,25 @@
#include <string>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <stdlib.h>
#include "utils.h"
std::string int_to_hex( u64 i )
{
std::stringstream stream;
stream << "0x"
<< std::setfill ('0') << std::setw(16)
<< std::hex << i;
return stream.str();
}
u64 fnv1a(std::string s) {
u64 hash = 0xCBF29CE484222325LL;
for(char& c : s) {
hash ^= u64(c);
hash *= 0x100000001B3LL;
}
return hash;
}

10
ports/c/utils.h Normal file
View File

@@ -0,0 +1,10 @@
#ifndef __UTILS_H
#define __UTILS_H
#include "circom.h"
std::string int_to_hex( u64 i );
u64 fnv1a(std::string s);
#endif // __UTILS_H

199
ports/c/zqfield.cpp Normal file
View File

@@ -0,0 +1,199 @@
#include "zqfield.h"
ZqField::ZqField(PBigInt ap) {
mpz_init_set(p, *ap);
mpz_init_set_ui(zero, 0);
mpz_init_set_ui(one, 1);
nBits = mpz_sizeinbase (p, 2);
mpz_init(mask);
mpz_mul_2exp(mask, one, nBits-1);
mpz_sub(mask, mask, one);
}
ZqField::~ZqField() {
mpz_clear(p);
mpz_clear(zero);
mpz_clear(one);
}
void ZqField::add(PBigInt r, PBigInt a, PBigInt b) {
mpz_add(*r,*a,*b);
if (mpz_cmp(*r, p) >= 0) {
mpz_sub(*r, *r, p);
}
}
void ZqField::sub(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_cmp(*a, *b) >= 0) {
mpz_sub(*r, *a, *b);
} else {
mpz_sub(*r, *b, *a);
mpz_sub(*r, p, *r);
}
}
void ZqField::neg(PBigInt r, PBigInt a) {
if (mpz_sgn(*a) > 0) {
mpz_sub(*r, p, *a);
} else {
mpz_set(*r, *a);
}
}
void ZqField::mul(PBigInt r, PBigInt a, PBigInt b) {
mpz_t tmp;
mpz_init(tmp);
mpz_mul(tmp,*a,*b);
mpz_fdiv_r(*r, tmp, p);
mpz_clear(tmp);
}
void ZqField::div(PBigInt r, PBigInt a, PBigInt b) {
mpz_t tmp;
mpz_init(tmp);
mpz_invert(tmp, *b, p);
mpz_mul(tmp,*a,tmp);
mpz_fdiv_r(*r, tmp, p);
mpz_clear(tmp);
}
void ZqField::idiv(PBigInt r, PBigInt a, PBigInt b) {
mpz_fdiv_q(*r, *a, *b);
}
void ZqField::mod(PBigInt r, PBigInt a, PBigInt b) {
mpz_fdiv_r(*r, *a, *b);
}
void ZqField::pow(PBigInt r, PBigInt a, PBigInt b) {
mpz_powm(*r, *a, *b, p);
}
void ZqField::lt(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c<0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::eq(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c==0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::gt(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c>0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::leq(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c<=0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::geq(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c>=0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::neq(PBigInt r, PBigInt a, PBigInt b) {
int c = mpz_cmp(*a, *b);
if (c!=0) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::land(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_sgn(*a) && mpz_sgn(*b)) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::lor(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_sgn(*a) || mpz_sgn(*b)) {
mpz_set(*r, one);
} else {
mpz_set(*r, zero);
}
}
void ZqField::lnot(PBigInt r, PBigInt a) {
if (mpz_sgn(*a)) {
mpz_set(*r, zero);
} else {
mpz_set(*r, one);
}
}
int ZqField::isTrue(PBigInt a) {
return mpz_sgn(*a);
}
void ZqField::copyn(PBigInt a, PBigInt b, int n) {
for (int i=0;i<n; i++) mpz_set(a[i], b[i]);
}
void ZqField::band(PBigInt r, PBigInt a, PBigInt b) {
mpz_and(*r, *a, *b);
mpz_and(*r, *r, mask);
}
void ZqField::bor(PBigInt r, PBigInt a, PBigInt b) {
mpz_ior(*r, *a, *b);
mpz_and(*r, *r, mask);
}
void ZqField::bxor(PBigInt r, PBigInt a, PBigInt b) {
mpz_xor(*r, *a, *b);
mpz_and(*r, *r, mask);
}
void ZqField::bnot(PBigInt r, PBigInt a) {
mpz_xor(*r, *a, mask);
mpz_and(*r, *r, mask);
}
void ZqField::shl(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_cmp_ui(*b, nBits) >= 0) {
mpz_set(*r, zero);
} else {
mpz_mul_2exp(*r, *a, mpz_get_ui(*b));
mpz_and(*r, *r, mask);
}
}
void ZqField::shr(PBigInt r, PBigInt a, PBigInt b) {
if (mpz_cmp_ui(*b, nBits) >= 0) {
mpz_set(*r, zero);
} else {
mpz_tdiv_q_2exp(*r, *a, mpz_get_ui(*b));
mpz_and(*r, *r, mask);
}
}
int ZqField::toInt(PBigInt a) {
return mpz_get_si (*a);
}

49
ports/c/zqfield.h Normal file
View File

@@ -0,0 +1,49 @@
#ifndef ZQFIELD_H
#define ZQFIELD_H
#include "circom.h"
class ZqField {
public:
BigInt p;
BigInt one;
BigInt zero;
size_t nBits;
BigInt mask;
ZqField(PBigInt ap);
~ZqField();
void copyn(PBigInt a, PBigInt b, int n);
void add(PBigInt r,PBigInt a, PBigInt b);
void sub(PBigInt r,PBigInt a, PBigInt b);
void neg(PBigInt r,PBigInt a);
void mul(PBigInt r,PBigInt a, PBigInt b);
void div(PBigInt r,PBigInt a, PBigInt b);
void idiv(PBigInt r,PBigInt a, PBigInt b);
void mod(PBigInt r,PBigInt a, PBigInt b);
void pow(PBigInt r,PBigInt a, PBigInt b);
void lt(PBigInt r, PBigInt a, PBigInt b);
void eq(PBigInt r, PBigInt a, PBigInt b);
void gt(PBigInt r, PBigInt a, PBigInt b);
void leq(PBigInt r, PBigInt a, PBigInt b);
void geq(PBigInt r, PBigInt a, PBigInt b);
void neq(PBigInt r, PBigInt a, PBigInt b);
void land(PBigInt r, PBigInt a, PBigInt b);
void lor(PBigInt r, PBigInt a, PBigInt b);
void lnot(PBigInt r, PBigInt a);
void band(PBigInt r, PBigInt a, PBigInt b);
void bor(PBigInt r, PBigInt a, PBigInt b);
void bxor(PBigInt r, PBigInt a, PBigInt b);
void bnot(PBigInt r, PBigInt a);
void shl(PBigInt r, PBigInt a, PBigInt b);
void shr(PBigInt r, PBigInt a, PBigInt b);
int isTrue(PBigInt a);
int toInt(PBigInt a);
};
#endif // ZQFIELD_H

748
ports/wasm/build_runtime.js Normal file
View File

@@ -0,0 +1,748 @@
const errs = require("./errs");
const buildWasmFf = require("fflib").buildWasmFf;
module.exports = function buildRuntime(module, builder) {
function buildInit() {
const f = module.addFunction("init");
f.addLocal("i", "i32");
const c = f.getCodeBuilder();
// Set the stack to current memory
f.addCode(
c.i32_store(
c.i32_const(4),
c.i32_shl(
c.i32_and(
c.current_memory(),
c.i32_const(0xFFFFFFF8)
),
c.i32_const(16)
)
)
);
f.addCode(
// i=0
c.setLocal("i", c.i32_const(0)),
c.block(c.loop(
// if (i==NComponents) break
c.br_if(1, c.i32_eq(c.getLocal("i"), c.i32_const(builder.header.NComponents))),
// inputSignalsToTrigger[i] = components[i].nInputSignals
c.i32_store(
c.i32_add(
c.i32_const(builder.pInputSignalsToTrigger),
c.i32_mul(
c.getLocal("i"),
c.i32_const(4)
)
),
c.i32_load(
c.i32_add(
c.i32_load(c.i32_const(builder.ppComponents)),
c.i32_mul(
c.getLocal("i"),
c.i32_const(builder.sizeofComponent) // Sizeof component
)
),
builder.offsetComponentNInputSignals
)
),
// i=i+1
c.setLocal(
"i",
c.i32_add(
c.getLocal("i"),
c.i32_const(1)
)
),
c.br(0)
))
);
if (builder.sanityCheck) {
f.addCode(
// i=0
c.setLocal("i", c.i32_const(0)),
c.block(c.loop(
// if (i==NSignals) break
c.br_if(1, c.i32_eq(c.getLocal("i"), c.i32_const(builder.header.NSignals))),
// signalsAssigned[i] = false
c.i32_store(
c.i32_add(
c.i32_const(builder.pSignalsAssigned),
c.i32_mul(
c.getLocal("i"),
c.i32_const(4)
)
),
c.i32_const(0)
),
// i=i+1
c.setLocal(
"i",
c.i32_add(
c.getLocal("i"),
c.i32_const(1)
)
),
c.br(0)
))
);
}
f.addCode(
c.call(
"Fr_copy",
c.i32_const(builder.pSignals),
c.i32_add(
c.i32_load(c.i32_const(builder.ppConstants)),
c.i32_const(builder.addConstant(1) * builder.sizeFr)
)
)
);
if (builder.sanityCheck) {
f.addCode(
c.i32_store(
c.i32_const(builder.pSignalsAssigned),
c.i32_const(1)
)
);
}
f.addCode(
// i=0
c.setLocal("i", c.i32_const(0)),
c.block(c.loop(
// if (i==NComponents) break
c.br_if(1, c.i32_eq(c.getLocal("i"), c.i32_const(builder.header.NComponents))),
// if (inputSignalsToTrigger[i] == 0) triggerComponent(i)
c.if(
c.i32_eqz(
c.i32_load(
c.i32_add(
c.i32_const(builder.pInputSignalsToTrigger),
c.i32_mul(
c.getLocal("i"),
c.i32_const(4)
)
)
)
),
c.call(
"triggerComponent",
c.getLocal("i")
)
),
// i=i+1
c.setLocal(
"i",
c.i32_add(
c.getLocal("i"),
c.i32_const(1)
)
),
c.br(0)
))
);
}
function buildTriggerComponent() {
const f = module.addFunction("triggerComponent");
f.addParam("component", "i32");
const c = f.getCodeBuilder();
f.addCode(
c.call_indirect(
c.getLocal("component"), // Idx in table
c.getLocal("component") // Parameter
)
);
}
function buildHash2ComponentEntry() {
const f = module.addFunction("hash2ComponentEntry");
f.addParam("component", "i32");
f.addParam("hash", "i64");
f.setReturnType("i32");
f.addLocal("pComponent", "i32");
f.addLocal("pHashTable", "i32");
f.addLocal("hIdx", "i32");
f.addLocal("h", "i64");
const c = f.getCodeBuilder();
f.addCode(
c.setLocal(
"pComponent",
c.i32_add(
c.i32_load(c.i32_const(builder.ppComponents)), // pComponents
c.i32_mul(
c.getLocal("component"),
c.i32_const(20) // sizeof(Component)
)
)
),
c.setLocal(
"pHashTable",
c.i32_load(c.getLocal("pComponent"))
),
c.setLocal(
"hIdx",
c.i32_and(
c.i32_wrap_i64(c.getLocal("hash")),
c.i32_const(0xFF)
)
),
c.block(c.loop(
c.setLocal(
"h",
c.i64_load(
c.i32_add(
c.getLocal("pHashTable"),
c.i32_mul(
c.getLocal("hIdx"),
c.i32_const(12)
)
)
)
),
c.br_if(1, c.i64_eq(c.getLocal("h"), c.getLocal("hash"))),
c.if(
c.i64_eqz(c.getLocal("h")),
c.call(
"err",
c.i32_const(errs.HASH_NOT_FOUND.code),
c.i32_const(errs.HASH_NOT_FOUND.pointer)
)
),
c.setLocal(
"hIdx",
c.i32_and(
c.i32_add(
c.getLocal("hIdx"),
c.i32_const(1)
),
c.i32_const(0xFF)
)
),
c.br(0)
)),
c.i32_add( // pComponentEntry
c.i32_load( // pComponentEntryTable
c.i32_add(
c.getLocal("pComponent"),
c.i32_const(4)
)
),
c.i32_mul(
c.i32_load( // idx to the componentEntry
c.i32_add(
c.getLocal("pHashTable"),
c.i32_mul(
c.getLocal("hIdx"),
c.i32_const(12)
)
),
8
),
c.i32_const(12)
)
)
);
}
function buildGetFromComponentEntry(fnName, offset, type) {
const f = module.addFunction(fnName);
f.addParam("pR", "i32");
f.addParam("component", "i32");
f.addParam("hash", "i64");
f.addLocal("pComponentEntry", "i32");
const c = f.getCodeBuilder();
f.addCode(
c.setLocal(
"pComponentEntry",
c.call(
"hash2ComponentEntry",
c.getLocal("component"),
c.getLocal("hash")
)
),
c.if( // If type is not signal
c.i32_ne(
c.i32_load(
c.getLocal("pComponentEntry"),
8 // type offset
),
c.i32_const(type)
),
c.call(
"err",
c.i32_const(errs.INVALID_TYPE.code),
c.i32_const(errs.INVALID_TYPE.pointer)
)
),
c.i32_store(
c.getLocal("pR"),
c.i32_load(
c.getLocal("pComponentEntry"),
offset
)
)
);
const f2 = module.addFunction(fnName + "32");
f2.addParam("pR", "i32");
f2.addParam("component", "i32");
f2.addParam("hashMSB", "i32");
f2.addParam("hashLSB", "i32");
const c2 = f2.getCodeBuilder();
f2.addCode(
c2.call(
fnName,
c2.getLocal("pR"),
c2.getLocal("component"),
c2.i64_or(
c2.i64_shl(
c2.i64_extend_i32_u(c2.getLocal("hashMSB")),
c2.i64_const(32)
),
c2.i64_extend_i32_u(c2.getLocal("hashLSB"))
)
)
);
}
function buildGetSignal() {
const f = module.addFunction("getSignal");
f.addParam("cIdx", "i32");
f.addParam("pR", "i32");
f.addParam("component", "i32");
f.addParam("signal", "i32");
const c = f.getCodeBuilder();
if (builder.sanityCheck) {
f.addCode(
c.if(
c.i32_eqz(
c.i32_load(
c.i32_add(
c.i32_const(builder.pSignalsAssigned),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(4)
)
),
)
),
c.call(
"err",
c.i32_const(errs.ACCESSING_NOT_ASSIGNED_SIGNAL.code),
c.i32_const(errs.ACCESSING_NOT_ASSIGNED_SIGNAL.pointer)
)
)
);
}
f.addCode(
c.call(
"Fr_copy",
c.getLocal("pR"),
c.i32_add(
c.i32_const(builder.pSignals),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(builder.sizeFr)
)
)
)
);
}
function buildSetSignal() {
const f = module.addFunction("setSignal");
f.addParam("cIdx", "i32");
f.addParam("component", "i32");
f.addParam("signal", "i32");
f.addParam("pVal", "i32");
f.addLocal("signalsToTrigger", "i32");
const c = f.getCodeBuilder();
if (builder.sanityCheck) {
f.addCode(
c.if(
c.i32_load(
c.i32_add(
c.i32_const(builder.pSignalsAssigned),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(4)
)
),
),
c.call(
"err",
c.i32_const(errs.SIGNAL_ASSIGNED_TWICE.code),
c.i32_const(errs.SIGNAL_ASSIGNED_TWICE.pointer)
)
),
c.i32_store(
c.i32_add(
c.i32_const(builder.pSignalsAssigned),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(4)
)
),
c.i32_const(1)
),
);
}
f.addCode(
c.call(
"Fr_copy",
c.i32_add(
c.i32_const(builder.pSignals),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(builder.sizeFr)
)
),
c.getLocal("pVal"),
)
);
f.addCode(
c.if( // If ( mapIsInput[s >> 5] & 1 << (s & 0x1f) )
c.i32_and(
c.i32_load(
c.i32_add(
c.i32_load(c.i32_const(builder.ppMapIsInput)),
c.i32_shl(
c.i32_shr_u(
c.getLocal("signal"),
c.i32_const(5)
),
c.i32_const(2)
)
)
),
c.i32_shl(
c.i32_const(1),
c.i32_and(
c.getLocal("signal"),
c.i32_const(0x1F)
)
)
),
[
...c.setLocal(
"signalsToTrigger",
c.i32_load(
c.i32_add(
c.i32_const(builder.pInputSignalsToTrigger),
c.i32_mul(
c.getLocal("component"),
c.i32_const(4)
)
)
)
),
...c.if( // if (signalsToTrigger > 0)
c.i32_gt_u(
c.getLocal("signalsToTrigger"),
c.i32_const(0)
),
[
...c.setLocal( // signalsToTrigger--
"signalsToTrigger",
c.i32_sub(
c.getLocal("signalsToTrigger"),
c.i32_const(1)
)
),
...c.i32_store(
c.i32_add(
c.i32_const(builder.pInputSignalsToTrigger),
c.i32_mul(
c.getLocal("component"),
c.i32_const(4)
)
),
c.getLocal("signalsToTrigger"),
),
...c.if( // if (signalsToTrigger==0) triggerCompomnent(component)
c.i32_eqz(c.getLocal("signalsToTrigger")),
c.call(
"triggerComponent",
c.getLocal("component")
)
)
],
c.call(
"err2",
c.i32_const(errs.MAPISINPUT_DONT_MATCH.code),
c.i32_const(errs.MAPISINPUT_DONT_MATCH.pointer),
c.getLocal("component"),
c.getLocal("signal")
)
)
]
)
);
}
function buildComponentFinished() {
const f = module.addFunction("componentFinished");
f.addParam("cIdx", "i32");
const c = f.getCodeBuilder();
f.addCode(c.ret([]));
}
function buildCheckConstraint() {
const pTmp = module.alloc(builder.sizeFr);
const f = module.addFunction("checkConstraint");
f.addParam("cIdx", "i32");
f.addParam("pA", "i32");
f.addParam("pB", "i32");
f.addParam("pStr", "i32");
const c = f.getCodeBuilder();
if (builder.sanityCheck) {
f.addCode(
c.call(
"Fr_eq",
c.getLocal(c.i32_const(pTmp)),
c.getLocal("pA"),
c.getLocal("pB")
),
c.if (
c.eqz(
c.call(
"Fr_isTrue",
c.getLocal(c.i32_const(pTmp)),
)
),
c.call(
"err4",
c.i32_const(errs.CONSTRAIN_DOES_NOT_MATCH.code),
c.i32_const(errs.CONSTRAIN_DOES_NOT_MATCH.pointer),
c.getLocal("cIdx"),
c.getLocal("pA"),
c.getLocal("pB"),
c.getLocal("pStr"),
)
)
);
}
}
function buildGetNVars() {
const f = module.addFunction("getNVars");
f.setReturnType("i32");
const c = f.getCodeBuilder();
f.addCode(c.i32_const(builder.header.NVars));
}
function buildGetFrLen() {
const f = module.addFunction("getFrLen");
f.setReturnType("i32");
const c = f.getCodeBuilder();
f.addCode(
c.i32_const(builder.sizeFr));
}
function buildGetPRawPrime() {
const f = module.addFunction("getPRawPrime");
f.setReturnType("i32");
const c = f.getCodeBuilder();
f.addCode(
c.i32_const(module.modules["Fr_F1m"].pq));
}
function buildGetPWitness() {
const f = module.addFunction("getPWitness");
f.addParam("w", "i32");
f.addLocal("signal", "i32");
f.setReturnType("i32");
const c = f.getCodeBuilder();
f.addCode(
c.setLocal(
"signal",
c.i32_load( // wit2sig[w]
c.i32_add(
c.i32_load( c.i32_const(builder.ppWit2sig)),
c.i32_mul(
c.getLocal("w"),
c.i32_const(4)
)
)
)
)
);
if (builder.sanityCheck) {
f.addCode(
c.if(
c.i32_eqz(
c.i32_load(
c.i32_add(
c.i32_const(builder.pSignalsAssigned),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(4)
)
),
)
),
c.call(
"err",
c.i32_const(errs.ACCESSING_NOT_ASSIGNED_SIGNAL.code),
c.i32_const(errs.ACCESSING_NOT_ASSIGNED_SIGNAL.pointer)
)
)
);
}
f.addCode(
c.i32_add(
c.i32_const(builder.pSignals),
c.i32_mul(
c.getLocal("signal"),
c.i32_const(builder.sizeFr)
)
)
);
}
function buildFrToInt() {
const f = module.addFunction("Fr_toInt");
f.addParam("p", "i32");
f.setReturnType("i32");
const c = f.getCodeBuilder();
f.addCode(
c.i32_load(c.getLocal("p"))
);
// TODO Handle long and montgomery.
}
const fErr = module.addIimportFunction("err", "runtime");
fErr.addParam("code", "i32");
fErr.addParam("pStr", "i32");
const fErr1 = module.addIimportFunction("err1", "runtime");
fErr1.addParam("code", "i32");
fErr1.addParam("pStr", "i32");
fErr1.addParam("param1", "i32");
const fErr2 = module.addIimportFunction("err2", "runtime");
fErr2.addParam("code", "i32");
fErr2.addParam("pStr", "i32");
fErr2.addParam("param1", "i32");
fErr2.addParam("param2", "i32");
const fErr3 = module.addIimportFunction("err3", "runtime");
fErr3.addParam("code", "i32");
fErr3.addParam("pStr", "i32");
fErr3.addParam("param1", "i32");
fErr3.addParam("param2", "i32");
fErr3.addParam("param3", "i32");
const fErr4 = module.addIimportFunction("err4", "runtime");
fErr4.addParam("code", "i32");
fErr4.addParam("pStr", "i32");
fErr4.addParam("param1", "i32");
fErr4.addParam("param2", "i32");
fErr4.addParam("param3", "i32");
fErr4.addParam("param4", "i32");
buildWasmFf(module, "Fr", builder.header.P);
builder.pSignals=module.alloc(builder.header.NSignals*builder.sizeFr);
builder.pInputSignalsToTrigger=module.alloc(builder.header.NComponents*4);
if (builder.sanityCheck) {
builder.pSignalsAssigned=module.alloc(builder.header.NSignals*4);
}
buildHash2ComponentEntry();
buildTriggerComponent();
buildInit();
buildGetFromComponentEntry("getSubComponentOffset", 0 /* offset */, builder.TYPE_COMPONENT);
buildGetFromComponentEntry("getSubComponentSizes", 4 /* offset */, builder.TYPE_COMPONENT);
buildGetFromComponentEntry("getSignalOffset", 0 /* offset */, builder.TYPE_SIGNAL);
buildGetFromComponentEntry("getSignalSizes", 4 /* offset */, builder.TYPE_SIGNAL);
buildGetSignal();
buildSetSignal();
buildComponentFinished();
buildCheckConstraint();
buildGetNVars();
buildGetFrLen();
buildGetPWitness();
buildGetPRawPrime();
buildFrToInt();
module.exportFunction("init");
module.exportFunction("getNVars");
module.exportFunction("getFrLen");
module.exportFunction("getSignalOffset32");
module.exportFunction("setSignal");
module.exportFunction("getPWitness");
module.exportFunction("Fr_toInt");
module.exportFunction("getPRawPrime");
};

1003
ports/wasm/builder.js Normal file

File diff suppressed because it is too large Load Diff

10
ports/wasm/errs.js Normal file
View File

@@ -0,0 +1,10 @@
module.exports = {
STACK_OUT_OF_MEM: {code: 1, str: "Stack out of memory"},
STACK_TOO_SMALL: {code: 2, str: "Stack too small"},
HASH_NOT_FOUND: {code: 3, str: "Hash not found"},
INVALID_TYPE: {code: 4, str: "Invalid type"},
ACCESSING_NOT_ASSIGNED_SIGNAL: {code: 5, str: "Accessing a not assigned signal"},
SIGNAL_ASSIGNED_TWICE: {code: 6, str: "Signal assigned twice"},
CONSTRAIN_DOES_NOT_MATCH: {code: 7, str: "Constraint doesn't match"},
MAPISINPUT_DONT_MATCH: {code: 8, str: "MapIsInput don't match"},
};

167
ports/wasm/tester.js Normal file
View File

@@ -0,0 +1,167 @@
const chai = require("chai");
const assert = chai.assert;
const fs = require("fs");
var tmp = require("tmp-promise");
const path = require("path");
const compiler = require("../../src/compiler");
const util = require("util");
const exec = util.promisify(require("child_process").exec);
const stringifyBigInts = require("../../src/utils").stringifyBigInts;
const unstringifyBigInts = require("../../src/utils").unstringifyBigInts;
const bigInt = require("big-integer");
const utils = require("../../src/utils");
const loadR1cs = require("../../src/r1csfile").loadR1cs;
const ZqField = require("fflib").ZqField;
const WitnessCalculator = require("./witness_calculator");
module.exports = wasm_tester;
async function wasm_tester(circomFile, _options) {
tmp.setGracefulCleanup();
const dir = await tmp.dir({prefix: "circom_", unsafeCleanup: true });
// console.log(dir.path);
const baseName = path.basename(circomFile, ".circom");
const options = Object.assign({}, _options);
options.wasmWriteStream = fs.createWriteStream(path.join(dir.path, baseName + ".wasm"));
options.symWriteStream = fs.createWriteStream(path.join(dir.path, baseName + ".sym"));
options.r1csFileName = path.join(dir.path, baseName + ".r1cs");
const promisesArr = [];
promisesArr.push(new Promise(fulfill => options.wasmWriteStream.on("finish", fulfill)));
await compiler(circomFile, options);
await Promise.all(promisesArr);
const wc = await WitnessCalculator.fromFile(path.join(dir.path, baseName + ".wasm"));
return new WasmTester(dir, baseName, wc);
}
class WasmTester {
constructor(dir, baseName, witnessCalculator) {
this.dir=dir;
this.baseName = baseName;
this.witnessCalculator = witnessCalculator;
}
async release() {
await this.dir.cleanup();
}
async calculateWitness(input) {
return await this.witnessCalculator.calculateWitness(input);
}
async loadSymbols() {
if (this.symbols) return;
this.symbols = {};
const symsStr = await fs.promises.readFile(
path.join(this.dir.path, this.baseName + ".sym"),
"utf8"
);
const lines = symsStr.split("\n");
for (let i=0; i<lines.length; i++) {
const arr = lines[i].split(",");
if (arr.length!=3) continue;
this.symbols[arr[2]] = {
idx: Number(arr[0]),
idxWit: Number(arr[1])
};
}
}
async loadConstraints() {
const self = this;
if (this.constraints) return;
const r1cs = await loadR1cs(path.join(this.dir.path, this.baseName + ".r1cs"),true, false);
self.field = new ZqField(r1cs.prime);
self.nWires = r1cs.nWires;
self.constraints = r1cs.constraints;
}
async assertOut(actualOut, expectedOut) {
const self = this;
if (!self.symbols) await self.loadSymbols();
checkObject("main", expectedOut);
function checkObject(prefix, eOut) {
if (Array.isArray(eOut)) {
for (let i=0; i<eOut.length; i++) {
checkObject(prefix + "["+i+"]", eOut[i]);
}
} else if ((typeof eOut == "object")&&(eOut.constructor.name == "Object")) {
for (let k in eOut) {
checkObject(prefix + "."+k, eOut[k]);
}
} else {
if (typeof self.symbols[prefix] == "undefined") {
assert(false, "Output variable not defined: "+ prefix);
}
const ba = bigInt(actualOut[self.symbols[prefix].idxWit]).toString();
const be = bigInt(eOut).toString();
assert.strictEqual(ba, be, prefix);
}
}
}
async getDecoratedOutput(witness) {
const self = this;
const lines = [];
if (!self.symbols) await self.loadSymbols();
for (let n in self.symbols) {
let v;
if (utils.isDefined(witness[self.symbols[n].idxWit])) {
v = witness[self.symbols[n].idxWit].toString();
} else {
v = "undefined";
}
lines.push(`${n} --> ${v}`);
}
return lines.join("\n");
}
async checkConstraints(witness) {
const self = this;
if (!self.constraints) await self.loadConstraints();
for (let i=0; i<self.constraints.length; i++) {
checkConstraint(self.constraints[i]);
}
function checkConstraint(constraint) {
const F = self.field;
const a = evalLC(constraint.a);
const b = evalLC(constraint.b);
const c = evalLC(constraint.c);
assert (F.sub(F.mul(a,b), c).isZero(), "Constraint doesn't match");
}
function evalLC(lc) {
const F = self.field;
let v = F.zero;
for (let w in lc) {
v = F.add(
v,
F.mul( lc[w], witness[w] )
);
}
return v;
}
}
}

View File

@@ -0,0 +1,186 @@
/* globals WebAssembly */
const fs = require("fs");
const utils = require("../../src/utils");
const bigInt = require("big-integer");
module.exports.fromFile = async function(file) {
const code = await fs.promises.readFile(file);
return await module.exports.fromBuffer(code);
};
module.exports.fromBuffer = async function(code) {
const memory = new WebAssembly.Memory({initial:20000});
const wasmModule = await WebAssembly.compile(code);
const instance = await WebAssembly.instantiate(wasmModule, {
env: {
"memory": memory
},
runtime: {
err: function(code, pstr) {
console.log("ERROR", code, p2str(pstr));
},
err1: function(code, pstr, a) {
console.log("ERROR: ", code, p2str(pstr), a);
},
err2: function(code, pstr, a, b) {
console.log("ERROR: ", code, p2str(pstr), a, b);
},
err3: function(code, pstr, a, b, c) {
console.log("ERROR: ", code, p2str(pstr), a, b, c);
},
err4: function(code, pstr, a,b,c,d) {
console.log("ERROR: ", code, p2str(pstr), a, b, c, d);
},
}
});
return new WitnessCalculator(memory, instance);
function p2str(p) {
return "TODO"+p;
}
};
class WitnessCalculator {
constructor(memory, instance) {
this.memory = memory;
this.i32 = new Uint32Array(memory.buffer);
this.instance = instance;
this.n32 = (this.instance.exports.getFrLen() >> 2) - 2;
const pRawPrime = this.instance.exports.getPRawPrime();
this.prime = bigInt(0);
for (let i=this.n32-1; i>=0; i--) {
this.prime = this.prime.shiftLeft(32);
this.prime = this.prime.add(bigInt(this.i32[(pRawPrime >> 2) + i]));
}
this.mask32 = bigInt("FFFFFFFF", 16);
this.NVars = this.instance.exports.getNVars();
this.n64 = Math.floor((this.prime.bitLength() - 1) / 64)+1;
this.R = bigInt.one.shiftLeft(this.n64*64);
this.RInv = this.R.modInv(this.prime);
}
async calculateWitness(input) {
const w = [];
const old0 = this.i32[0];
this.instance.exports.init();
const pSigOffset = this.allocInt();
const pFr = this.allocFr();
for (let k in input) {
const h = utils.fnvHash(k);
const hMSB = parseInt(h.slice(0,8), 16);
const hLSB = parseInt(h.slice(8,16), 16);
this.instance.exports.getSignalOffset32(pSigOffset, 0, hMSB, hLSB);
const sigOffset = this.getInt(pSigOffset);
const fArr = utils.flatArray(input[k]);
for (let i=0; i<fArr.length; i++) {
this.setFr(pFr, fArr[i]);
this.instance.exports.setSignal(0, 0, sigOffset + i, pFr);
}
}
for (let i=0; i<this.NVars; i++) {
const pWitness = this.instance.exports.getPWitness(i);
w.push(this.getFr(pWitness));
}
this.i32[0] = old0;
return w;
}
allocInt() {
const p = this.i32[0];
this.i32[0] = p+8;
return p;
}
allocFr() {
const p = this.i32[0];
this.i32[0] = p+this.n32*4 + 8;
return p;
}
getInt(p) {
return this.i32[p>>2];
}
setInt(p, v) {
this.i32[p>>2] = v;
}
getFr(p) {
const idx = (p>>2);
if (this.i32[idx + 1] & 0x80000000) {
let res= bigInt(0);
for (let i=this.n32-1; i>=0; i--) {
res = res.shiftLeft(32);
res = res.add(bigInt(this.i32[idx+2+i]));
}
if (this.i32[idx + 1] & 0x40000000) {
return fromMontgomery(res);
} else {
return res;
}
} else {
if (this.i32[idx] & 0x80000000) {
return this.prime.add( bigInt(this.i32[idx]).minus(bigInt(0x100000000)) );
} else {
return bigInt(this.i32[idx]);
}
}
function fromMontgomery(n) {
return n.times(this.RInv).mod(this.prime);
}
}
setFr(p, v) {
const self = this;
v = bigInt(v);
if (v.lt(bigInt("80000000", 16)) ) {
return setShortPositive(v);
}
if (v.geq(self.prime.minus(bigInt("80000000", 16))) ) {
return setShortNegative(v);
}
return setLongNormal(v);
function setShortPositive(a) {
self.i32[(p >> 2)] = parseInt(a);
self.i32[(p >> 2) + 1] = 0;
}
function setShortNegative(a) {
const b = bigInt("80000000", 16 ).add(a.minus( self.prime.minus(bigInt("80000000", 16 ))));
self.i32[(p >> 2)] = parseInt(b);
self.i32[(p >> 2) + 1] = 0;
}
function setLongNormal(a) {
self.i32[(p >> 2)] = 0;
self.i32[(p >> 2) + 1] = 0x80000000;
for (let i=0; i<self.n32; i++) {
self.i32[(p >> 2) + 2 + i] = a.shiftRight(i*32).and(self.mask32);
}
}
}
}