package circuitcompiler import ( "fmt" "github.com/stretchr/testify/assert" "math/big" "strings" "testing" ) type InOut struct { inputs []*big.Int result *big.Int } type TraceCorrectnessTest struct { code string io []InOut } var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10) var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10) var correctnesTest = []TraceCorrectnessTest{ { io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, result: big.NewInt(int64(1729500084900343)), }, { inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, result: bigNumberResult1, }}, code: ` def main( x , z ) : out = do(z) + add(x,x) def do(x): e = x * 5 b = e * 6 c = b * 7 f = c * 1 d = c * f out = d * mul(d,e) def add(x ,k): z = k * x out = do(x) + mul(x,z) def mul(a,b): out = a * b `, }, {io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(7))}, result: big.NewInt(int64(4)), }}, code: ` def mul(a,b): out = a * b def main(a): b = a * a c = 4 - b d = 5 * c out = mul(d,c) / mul(b,b) `, }, {io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, result: big.NewInt(int64(22638)), }, { inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, result: bigNumberResult2, }}, code: ` def main(a,b): d = b + b c = a * d e = c - a out = e * c `, }, { io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))}, result: big.NewInt(int64(98441327276)), }, { inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, result: big.NewInt(int64(8675445947220)), }}, code: ` def main(a,b): c = a + b e = c - a f = e + b g = f + 2 out = g * a `, }, { io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, result: big.NewInt(int64(444675)), }}, code: ` def main(a,b,c,d): e = a * b f = c * d g = e * f h = g / e i = h * 5 out = g * i `, }, { io: []InOut{{ inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, result: big.NewInt(int64(264)), }}, code: ` def main(a,b,c,d): e = a * 3 f = b * 7 g = c * 11 h = d * 13 i = e + f j = g + h k = i + j out = k * 1 `, }, } func TestCorrectness(t *testing.T) { for _, test := range correctnesTest { parser := NewParser(strings.NewReader(test.code)) program, err := parser.Parse() if err != nil { panic(err) } fmt.Println("\n unreduced") fmt.Println(test.code) program.BuildConstraintTrees() for k, v := range program.functions { fmt.Println(k) PrintTree(v.root) } fmt.Println("\nReduced gates") //PrintTree(froots["mul"]) gates := program.ReduceCombinedTree() for _, g := range gates { fmt.Printf("\n %v", g) } fmt.Println("\n generating R1CS") r1cs := program.GenerateReducedR1CS(gates) fmt.Println(r1cs.A) fmt.Println(r1cs.B) fmt.Println(r1cs.C) for _, io := range test.io { inputs := io.inputs fmt.Println("input") fmt.Println(inputs) w := CalculateWitness(inputs, r1cs) fmt.Println("witness") fmt.Println(w) assert.Equal(t, io.result, w[program.GlobalInputCount()]) } } } //test to check gate optimisation //mess around the code s.t. results is unchanged. number of gates should remain the same in any case func TestGateOptimisation(t *testing.T) { io := InOut{ inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, result: bigNumberResult1, } equalCodes := []string{ ` def main( x , z ) : out = do(z) + add(x,x) def do(x): e = x * 5 b = e * 6 c = 7 * b f = c * 1 d = f * c out = d * mul(d,e) def add(x ,k): z = k * x out = do(x) + mul(x,z) def mul(a,b): out = b * a `, //switching order ` def main( x , z ) : out = do(z) + add(x,x) def do(x): e = x * 5 b = e * 6 c = b * 7 f = c * 1 d = c * f out = d * mul(d,e) def add(x ,k): z = k * x out = do(x) + mul(x,z) def mul(a,b): out = a * b `, //switching order ` def main( x , z ) : out = do(z) + add(x,x) def do(x): e = x * 5 j = e * 3 k = e * 3 b = j+k c = b * 7 f = c * 1 d = c * f g = d * 1 out = g * mul(d,e) def add(k ,x): z = k * x out = do(x) + mul(x,z) def mul(b,a): out = a * b `, ` def main( x , z ) : out = add(x,x)+do(z) def do(x): e = x * 5 j = 3 * e k = e * 3 b = j+k c = b * 7 f = c * 1 d = c * f g = d * 1 out = mul(d,e) * g def add(k ,x): z = k * x out = mul(x,z) + do(x) def mul(b,a): out = a * b `, } var r1css = make([]R1CS, len(equalCodes)) for i, c := range equalCodes { parser := NewParser(strings.NewReader(c)) program, err := parser.Parse() if err != nil { panic(err) } program.BuildConstraintTrees() gates := program.ReduceCombinedTree() for _, g := range gates { fmt.Printf("\n %v", g) } fmt.Println("\n generating R1CS") r1cs := program.GenerateReducedR1CS(gates) r1css[i] = r1cs fmt.Println(r1cs.A) fmt.Println(r1cs.B) fmt.Println(r1cs.C) } for i := 0; i < len(equalCodes)-1; i++ { assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A)) } for i := 0; i < len(equalCodes); i++ { //assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A)) w := CalculateWitness(io.inputs, r1css[i]) fmt.Println("witness") fmt.Println(w) assert.Equal(t, io.result, w[3]) } }