package circuitcompiler
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/stretchr/testify/assert"
|
|
"math/big"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// Fixture types used by the circuit-compiler tests in this file.
type (
	// InOut pairs one input vector for a compiled program with the value the
	// test expects to read back from the computed witness.
	InOut struct {
		inputs []*big.Int // values fed to CalculateWitness (presumably main's arguments in order — confirm)
		result *big.Int   // expected program output
	}

	// TraceCorrectnessTest couples a DSL source program with the input/output
	// cases it is checked against.
	TraceCorrectnessTest struct {
		code string  // program source handed to NewParser
		io   []InOut // cases evaluated against the compiled program
	}
)
|
|
|
|
// Expected program outputs that do not fit in an int64, parsed from base-10
// literals once at package initialisation.
var (
	bigNumberResult1 = mustParseBigDecimal("2297704271284150716235246193843898764109352875")
	bigNumberResult2 = mustParseBigDecimal("75263346540254220740876250")
)

// mustParseBigDecimal parses s as a base-10 integer.
//
// It panics when s is not a valid decimal literal. big.Int.SetString returns
// nil on failure, so discarding its ok flag would silently turn a typo in a
// fixture into a nil expected value; panicking makes such a typo fail loudly.
func mustParseBigDecimal(s string) *big.Int {
	v, ok := new(big.Int).SetString(s, 10)
	if !ok {
		panic(fmt.Sprintf("invalid decimal literal %q", s))
	}
	return v
}
|
|
|
|
// correctnesTest lists the DSL programs exercised by TestCorrectness. Each
// entry holds the program source (code) and one or more input vectors paired
// with the result TestCorrectness expects to find in the computed witness (io).
//
// The DSL sources are raw string literals and are kept byte-for-byte; the
// variable name (including its spelling) is also left unchanged.
var correctnesTest = []TraceCorrectnessTest{
	// Nested function calls: main calls do and add, add calls do and mul,
	// and do calls mul.
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
			result: big.NewInt(int64(1729500084900343)),
		}, {
			inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
			// Too large for int64, hence the package-level big.Int.
			result: bigNumberResult1,
		}},
		code: `


def main( x , z ) :


out = do(z) + add(x,x)




def do(x):


e = x * 5


b = e * 6


c = b * 7


f = c * 1


d = c * f


out = d * mul(d,e)




def add(x ,k):


z = k * x


out = do(x) + mul(x,z)






def mul(a,b):


out = a * b


`,
	},
	// Subtraction, constant operands and division. For a = 7:
	// b = 49, c = -45, d = -225, so out = 10125 / 2401. The expected 4 is the
	// truncated quotient — NOTE(review): confirm `/` is meant to truncate here.
	{io: []InOut{{
		inputs: []*big.Int{big.NewInt(int64(7))},
		result: big.NewInt(int64(4)),
	}},
		code: `


def mul(a,b):


out = a * b




def main(a):


b = a * a


c = 4 - b


d = 5 * c


out = mul(d,c) / mul(b,b)


`,
	},
	// Single function mixing +, * and -. For (7, 11): d = 22, c = 154,
	// e = 147, out = 147 * 154 = 22638.
	{io: []InOut{{
		inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
		result: big.NewInt(int64(22638)),
	}, {
		inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
		// Too large for int64, hence the package-level big.Int.
		result: bigNumberResult2,
	}},
		code: `


def main(a,b):


d = b + b


c = a * d


e = c - a


out = e * c


`,
	},
	// Additions only, except for the final product: c - a cancels back to b,
	// so out = (2b + 2) * a.
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
			result: big.NewInt(int64(98441327276)),
		}, {
			inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
			result: big.NewInt(int64(8675445947220)),
		}},
		code: `


def main(a,b):


c = a + b


e = c - a


f = e + b


g = f + 2


out = g * a


`,
	},
	// Exact division: g = e * f, so h = g / e = f. For (3, 5, 7, 11):
	// e = 15, f = 77, g = 1155, i = 385, out = 1155 * 385 = 444675.
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
			result: big.NewInt(int64(444675)),
		}},
		code: `


def main(a,b,c,d):


e = a * b


f = c * d


g = e * f


h = g / e


i = h * 5


out = g * i


`,
	},
	// Purely linear program: out = 3a + 7b + 11c + 13d = 9 + 35 + 77 + 143 = 264
	// for (3, 5, 7, 11).
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
			result: big.NewInt(int64(264)),
		}},
		code: `


def main(a,b,c,d):


e = a * 3


f = b * 7


g = c * 11


h = d * 13


i = e + f


j = g + h


k = i + j


out = k * 1


`,
	},
}
|
|
|
|
func TestCorrectness(t *testing.T) {
|
|
|
|
for _, test := range correctnesTest {
|
|
parser := NewParser(strings.NewReader(test.code))
|
|
program, err := parser.Parse()
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println("\n unreduced")
|
|
fmt.Println(test.code)
|
|
|
|
program.BuildConstraintTrees()
|
|
for k, v := range program.functions {
|
|
fmt.Println(k)
|
|
PrintTree(v.root)
|
|
}
|
|
|
|
fmt.Println("\nReduced gates")
|
|
//PrintTree(froots["mul"])
|
|
gates := program.ReduceCombinedTree()
|
|
for _, g := range gates {
|
|
fmt.Printf("\n %v", g)
|
|
}
|
|
|
|
fmt.Println("\n generating R1CS")
|
|
r1cs := program.GenerateReducedR1CS(gates)
|
|
fmt.Println(r1cs.A)
|
|
fmt.Println(r1cs.B)
|
|
fmt.Println(r1cs.C)
|
|
|
|
for _, io := range test.io {
|
|
inputs := io.inputs
|
|
fmt.Println("input")
|
|
fmt.Println(inputs)
|
|
w := CalculateWitness(inputs, r1cs)
|
|
fmt.Println("witness")
|
|
fmt.Println(w)
|
|
assert.Equal(t, io.result, w[program.GlobalInputCount()])
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
//test to check gate optimisation
|
|
//mess around the code s.t. results is unchanged. number of gates should remain the same in any case
|
|
func TestGateOptimisation(t *testing.T) {
|
|
|
|
io := InOut{
|
|
inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
|
|
|
|
result: bigNumberResult1,
|
|
}
|
|
|
|
equalCodes := []string{
|
|
`
|
|
def main( x , z ) :
|
|
out = do(z) + add(x,x)
|
|
|
|
def do(x):
|
|
e = x * 5
|
|
b = e * 6
|
|
c = 7 * b
|
|
f = c * 1
|
|
d = f * c
|
|
out = d * mul(d,e)
|
|
|
|
def add(x ,k):
|
|
z = k * x
|
|
out = do(x) + mul(x,z)
|
|
|
|
|
|
def mul(a,b):
|
|
out = b * a
|
|
`, //switching order
|
|
`
|
|
def main( x , z ) :
|
|
out = do(z) + add(x,x)
|
|
|
|
def do(x):
|
|
e = x * 5
|
|
b = e * 6
|
|
c = b * 7
|
|
f = c * 1
|
|
d = c * f
|
|
out = d * mul(d,e)
|
|
|
|
def add(x ,k):
|
|
z = k * x
|
|
out = do(x) + mul(x,z)
|
|
|
|
|
|
def mul(a,b):
|
|
out = a * b
|
|
`, //switching order
|
|
`
|
|
def main( x , z ) :
|
|
out = do(z) + add(x,x)
|
|
|
|
def do(x):
|
|
e = x * 5
|
|
j = e * 3
|
|
k = e * 3
|
|
b = j+k
|
|
c = b * 7
|
|
f = c * 1
|
|
d = c * f
|
|
g = d * 1
|
|
out = g * mul(d,e)
|
|
|
|
def add(k ,x):
|
|
z = k * x
|
|
out = do(x) + mul(x,z)
|
|
|
|
|
|
def mul(b,a):
|
|
out = a * b
|
|
`, `
|
|
def main( x , z ) :
|
|
out = add(x,x)+do(z)
|
|
|
|
def do(x):
|
|
e = x * 5
|
|
j = 3 * e
|
|
k = e * 3
|
|
b = j+k
|
|
c = b * 7
|
|
f = c * 1
|
|
d = c * f
|
|
g = d * 1
|
|
out = mul(d,e) * g
|
|
|
|
def add(k ,x):
|
|
z = k * x
|
|
out = mul(x,z) + do(x)
|
|
|
|
|
|
def mul(b,a):
|
|
out = a * b
|
|
`,
|
|
}
|
|
|
|
var r1css = make([]R1CS, len(equalCodes))
|
|
|
|
for i, c := range equalCodes {
|
|
parser := NewParser(strings.NewReader(c))
|
|
program, err := parser.Parse()
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
program.BuildConstraintTrees()
|
|
|
|
gates := program.ReduceCombinedTree()
|
|
for _, g := range gates {
|
|
fmt.Printf("\n %v", g)
|
|
}
|
|
fmt.Println("\n generating R1CS")
|
|
r1cs := program.GenerateReducedR1CS(gates)
|
|
r1css[i] = r1cs
|
|
fmt.Println(r1cs.A)
|
|
fmt.Println(r1cs.B)
|
|
fmt.Println(r1cs.C)
|
|
}
|
|
|
|
for i := 0; i < len(equalCodes)-1; i++ {
|
|
assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
|
|
}
|
|
|
|
for i := 0; i < len(equalCodes); i++ {
|
|
//assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
|
|
w := CalculateWitness(io.inputs, r1css[i])
|
|
fmt.Println("witness")
|
|
fmt.Println(w)
|
|
assert.Equal(t, io.result, w[3])
|
|
|
|
}
|
|
|
|
}
|