diff --git a/circuitcompiler/test.py b/circuitcompiler/test.py index 5ce384c..62efd1a 100644 --- a/circuitcompiler/test.py +++ b/circuitcompiler/test.py @@ -78,8 +78,3 @@ def main(): if __name__ == '__main__': #pascal(8) main() - - - [[0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1 0 0]] - [[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 5 0]] - [[0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 1]] \ No newline at end of file diff --git a/snark.go b/snark.go index 4220d92..0f3f71d 100644 --- a/snark.go +++ b/snark.go @@ -357,7 +357,9 @@ func VerifyProof(setup Setup, proof Proof, publicSignals []*big.Int, debug bool) } //TODO this is just a workaround to place the output after the input signals. Will be removed once the handling of private variables is already considered in the lexer -func RelocateOutput(numberOfInputs int, r1cs circuitcompiler.R1CS, witness []*big.Int) (r circuitcompiler.R1CS, w []*big.Int) { +func moveOutputToBegining(r1cs circuitcompiler.R1CS) (r circuitcompiler.R1CS) { + return r1cs + // activating this part, causes a huge messup I want to deal with a bit later tmpA, tmpB, tmpC := [][]*big.Int{}, [][]*big.Int{}, [][]*big.Int{} tmpA = append(tmpA, r1cs.A[len(r1cs.A)-1]) @@ -369,8 +371,15 @@ func RelocateOutput(numberOfInputs int, r1cs circuitcompiler.R1CS, witness []*bi tmpC = append(tmpC, r1cs.C[len(r1cs.C)-1]) tmpC = append(tmpC, r1cs.C[:len(r1cs.C)-1]...) + return circuitcompiler.R1CS{A: tmpA, B: tmpB, C: tmpC} +} + +//TODO this is just a workaround to place the output after the input signals. Will be removed once the handling of private variables is already considered in the lexer +func moveWitnessOutputAfterInputs(numberOfInputs int, witness []*big.Int) (w []*big.Int) { + return witness + // activating this part, causes a huge messup I want to deal with a bit later wtmp := append(witness[:numberOfInputs], witness[len(witness)-1]) wtmp = append(wtmp, witness[numberOfInputs:len(witness)-2]...) - return circuitcompiler.R1CS{A: tmpA, B: tmpB, C: tmpC}, wtmp + return wtmp } diff --git a/snark_test.go b/snark_test.go index 0e952c2..5cb969c 100644 --- a/snark_test.go +++ b/snark_test.go @@ -8,11 +8,106 @@ import ( "math/big" "strings" "testing" + "time" ) -func TestNewProgramm(t *testing.T) { +type InOut struct { + inputs []*big.Int + result *big.Int +} + +type TraceCorrectnessTest struct { + code string + io []InOut +} + +var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10) +var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10) - flat := ` +var correctnesTest = []TraceCorrectnessTest{ + //{ + // io: []InOut{{ + // inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, + // result: big.NewInt(int64(1729500084900343)), + // }, { + // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + // + // result: bigNumberResult1, + // }}, + // code: ` + //def main( x , z ) : + // out = do(z) + add(x,x) + // + //def do(x): + // e = x * 5 + // b = e * 6 + // c = b * 7 + // f = c * 1 + // d = c * f + // out = d * mul(d,e) + // + //def add(x ,k): + // z = k * x + // out = do(x) + mul(x,z) + // + // + //def mul(a,b): + // out = a * b + //`, + //}, + //{io: []InOut{{ + // inputs: []*big.Int{big.NewInt(int64(7))}, + // result: big.NewInt(int64(4)), + //}}, + // code: ` + //def mul(a,b): + // out = a * b + // + //def main(a): + // b = a * a + // c = 4 - b + // d = 5 * c + // out = mul(d,c) / mul(b,b) + //`, + //}, + //{io: []InOut{{ + // inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, + // result: big.NewInt(int64(22638)), + //}, { + // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + // result: bigNumberResult2, + //}}, + // code: ` + //def main(a,b): + // d = b + b + // c = a * d + // e = c - a + // out = e * c + //`, + //}, + //{ + // io: []InOut{{ + // inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))}, + // result: big.NewInt(int64(98441327276)), + // }, { + // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + // result: big.NewInt(int64(8675445947220)), + // }}, + // code: ` + //def main(a,b): + // c = a + b + // e = c - a + // f = e + b + // g = f + 2 + // out = g * a + //`, + //}, + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(444675)), + }}, + code: ` def main(a,b,c,d): e = a * b f = c * d @@ -20,106 +115,115 @@ func TestNewProgramm(t *testing.T) { h = g / e i = h * 5 out = g * i - ` + `, + }, +} - parser := circuitcompiler.NewParser(strings.NewReader(flat)) - program, err := parser.Parse() +func TestGenerateAndVerifyProof(t *testing.T) { + + for _, test := range correctnesTest { + + parser := circuitcompiler.NewParser(strings.NewReader(test.code)) + program, err := parser.Parse() + + if err != nil { + panic(err) + } + fmt.Println("\n unreduced") + fmt.Println(test.code) + + program.BuildConstraintTrees() + program.PrintContraintTrees() + fmt.Println("\nReduced gates") + //PrintTree(froots["mul"]) + gates := program.ReduceCombinedTree() + for _, g := range gates { + fmt.Println(g) + } + + fmt.Println("generating R1CS") + //NOTE MOVE DOES NOTHING CURRENTLY + r1cs := moveOutputToBegining(program.GenerateReducedR1CS(gates)) + //[[0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1 0 0]] + //[[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 5 0]] + //[[0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 1]] + + a, b, c := r1cs.A, r1cs.B, r1cs.C + fmt.Println(a) + fmt.Println(b) + fmt.Println(c) + + // R1CS to QAP + alphas, betas, gammas, domain := Utils.PF.R1CSToQAP(a, b, c) + fmt.Println("QAP array lengths") + fmt.Println("alphas", len(alphas)) + fmt.Println("betas", len(betas)) + fmt.Println("gammas", len(gammas)) + fmt.Println("domain polynomial ", len(domain)) + + before := time.Now() + //calculate trusted setup + setup, err := GenerateTrustedSetup(len(alphas[0]), alphas, betas, gammas) + fmt.Println("Generate CRS time elapsed:", time.Since(before)) + assert.Nil(t, err) + fmt.Println("\nt:", setup.Toxic.T) + + for _, io := range test.io { + + inputs := io.inputs + fmt.Println("input") + fmt.Println(inputs) + w := circuitcompiler.CalculateWitness(inputs, r1cs) + fmt.Println("\nwitness", w) + //NOTE MOVE DOES NOTHING + w = moveWitnessOutputAfterInputs(program.GlobalInputCount(), w) + fmt.Println("\nwitness Reordered ", w) + + assert.Equal(t, io.result, w[len(w)-1]) + + ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) + fmt.Println("ax length", len(ax)) + fmt.Println("bx length", len(bx)) + fmt.Println("cx length", len(cx)) + fmt.Println("px length", len(px)) + + hxQAP := Utils.PF.DivisorPolynomial(px, domain) + fmt.Println("hx length", len(hxQAP)) + + // hx==px/zx so px==hx*zx + assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain)) + + // p(x) = a(x) * b(x) - c(x) == h(x) * z(x) + abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx) + assert.Equal(t, abc, px) + + div, rem := Utils.PF.Div(px, domain) + assert.Equal(t, hxQAP, div) //not necessary, since DivisorPolynomial is Div, just discarding 'rem' + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(len(px)-len(domain))) + + //// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation + //assert.Equal(t, domain, setup.Pk.Z) + + hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) + + // assert.Equal(t, hxQAP, hx) + assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain)) + assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) + assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) + assert.Equal(t, len(hxQAP), len(px)-len(domain)+1) + + before := time.Now() + // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) + proof, err := GenerateProofs(setup, program.GlobalInputCount(), w, px) + fmt.Println("proof generation time elapsed:", time.Since(before)) + assert.Nil(t, err) + + before = time.Now() + assert.True(t, VerifyProof(setup, proof, append(w[1:program.GlobalInputCount()], w[len(w)-1]), true)) + fmt.Println("verify proof time elapsed:", time.Since(before)) + + } - if err != nil { - panic(err) - } - fmt.Println("\n unreduced") - fmt.Println(flat) - - program.BuildConstraintTrees() - program.PrintContraintTrees() - fmt.Println("\nReduced gates") - //PrintTree(froots["mul"]) - gates := program.ReduceCombinedTree() - for _, g := range gates { - fmt.Println(g) } - fmt.Println("generating R1CS") - r1cs := program.GenerateReducedR1CS(gates) - //[[0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1 0 0]] - //[[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 5 0]] - //[[0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 1]] - a1 := big.NewInt(int64(6)) - a2 := big.NewInt(int64(5)) - inputs := []*big.Int{a1, a2, a1, a2} - w := circuitcompiler.CalculateWitness(inputs, r1cs) - - r1csReordered, wReordered := RelocateOutput(program.GlobalInputCount(), r1cs, w) - - a, b, c := r1csReordered.A, r1csReordered.B, r1csReordered.C - fmt.Println(a) - fmt.Println(b) - fmt.Println(c) - - // R1CS to QAP - alphas, betas, gammas, domain := Utils.PF.R1CSToQAP(a, b, c) - fmt.Println("QAP array lengths") - fmt.Println("alphas", len(alphas)) - fmt.Println("betas", len(betas)) - fmt.Println("gammas", len(gammas)) - fmt.Println("domain polynomial ", len(domain)) - - ax, bx, cx, px := Utils.PF.CombinePolynomials(wReordered, alphas, betas, gammas) - fmt.Println("ax length", len(ax)) - fmt.Println("bx length", len(bx)) - fmt.Println("cx length", len(cx)) - fmt.Println("px length", len(px)) - - hxQAP := Utils.PF.DivisorPolynomial(px, domain) - fmt.Println("hx length", hxQAP) - - // hx==px/zx so px==hx*zx - assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain)) - - // p(x) = a(x) * b(x) - c(x) == h(x) * z(x) - abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx) - assert.Equal(t, abc, px) - - div, rem := Utils.PF.Div(px, domain) - assert.Equal(t, hxQAP, div) //not necessary, since DivisorPolynomial is Div, just discarding 'rem' - assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(len(px)-len(domain))) - - //calculate trusted setup - setup, err := GenerateTrustedSetup(len(w), alphas, betas, gammas) - assert.Nil(t, err) - fmt.Println("\nt:", setup.Toxic.T) - // - //// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation - //assert.Equal(t, domain, setup.Pk.Z) - - fmt.Println("hx pk.z", hxQAP) - hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) - fmt.Println("hx pk.z", hx) - // assert.Equal(t, hxQAP, hx) - assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain)) - assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) - - assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) - assert.Equal(t, len(hxQAP), len(px)-len(domain)+1) - // fmt.Println("pk.Z", len(setup.Pk.Z)) - // fmt.Println("zxQAP", len(zxQAP)) - - // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) - proof, err := GenerateProofs(setup, program.GlobalInputCount(), wReordered, px) - assert.Nil(t, err) - - // fmt.Println("\n proofs:") - // fmt.Println(proof) - - // fmt.Println("public signals:", proof.PublicSignals) - fmt.Println("\nwitness", w) - fmt.Println("\nwitness Reordered ", wReordered) - // b1 := big.NewInt(int64(1)) - //b35 := big.NewInt(int64(35)) - //// publicSignals := []*big.Int{b1, b35} - //publicSignals := []*big.Int{b35} - //before := time.Now() - assert.True(t, VerifyProof(setup, proof, wReordered[:program.GlobalInputCount()], true)) - //fmt.Println("verify proof time elapsed:", time.Since(before)) }