You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

162 lines
3.0 KiB

package circuitcompiler
import (
"fmt"
"github.com/stretchr/testify/assert"
"math/big"
"strings"
"testing"
)
// InOut pairs one set of circuit inputs with the output value the circuit is
// expected to produce for them.
type InOut struct {
	// inputs holds the values handed to the circuit's main function.
	inputs []*big.Int
	// result is the value the test expects the circuit to compute.
	result *big.Int
}
// TraceCorrectnessTest describes one circuit program together with the
// input/output pairs it is checked against.
type TraceCorrectnessTest struct {
	// code is the circuit source text that gets handed to the parser.
	code string
	// io lists the input sets and the result expected for each of them.
	io []InOut
}
// mustDecimalBigInt parses s as a base-10 integer. It panics when s is not a
// valid decimal literal, so a mistyped expected value fails loudly during
// package initialisation instead of silently leaving a nil *big.Int behind.
func mustDecimalBigInt(s string) *big.Int {
	v, ok := new(big.Int).SetString(s, 10)
	if !ok {
		panic("circuitcompiler: invalid decimal literal " + s)
	}
	return v
}

// Expected results that are too large to be written as int64 literals.
var (
	bigNumberResult1 = mustDecimalBigInt("2297704271284150716235246193843898764109352875")
	bigNumberResult2 = mustDecimalBigInt("75263346540254220740876250")
)
// correctnesTest is the table of circuit programs used by TestNewProgramm.
// Every entry carries the program source and one or more input sets with the
// result expected for them.
var correctnesTest = []TraceCorrectnessTest{
	// Functions that call other functions (do, add, mul).
	{
		code: `
def main( x , z ) :
out = do(z) + add(x,x)
def do(x):
e = x * 5
b = e * 6
c = b * 7
f = c * 1
d = c * f
out = d * mul(d,e)
def add(x ,k):
z = k * x
out = do(x) + mul(x,z)
def mul(a,b):
out = a * b
`,
		io: []InOut{
			{
				inputs: []*big.Int{big.NewInt(7), big.NewInt(11)},
				result: big.NewInt(1729500084900343),
			},
			{
				inputs: []*big.Int{big.NewInt(365235), big.NewInt(11876525)},
				result: bigNumberResult1,
			},
		},
	},
	// Subtraction, multiplication by a constant and division.
	{
		code: `
def mul(a,b):
out = a * b
def main(a):
b = a * a
c = 4 - b
d = 5 * c
out = mul(d,c) / mul(b,b)
`,
		io: []InOut{
			{
				inputs: []*big.Int{big.NewInt(7)},
				result: big.NewInt(4),
			},
		},
	},
	// Addition, multiplication and subtraction in a single function.
	{
		code: `
def main(a,b):
d = b + b
c = a * d
e = c - a
out = e * c
`,
		io: []InOut{
			{
				inputs: []*big.Int{big.NewInt(7), big.NewInt(11)},
				result: big.NewInt(22638),
			},
			{
				inputs: []*big.Int{big.NewInt(365235), big.NewInt(11876525)},
				result: bigNumberResult2,
			},
		},
	},
	// Chain of additions/subtractions followed by one multiplication.
	{
		code: `
def main(a,b):
c = a + b
e = c - a
f = e + b
g = f + 2
out = g * a
`,
		io: []InOut{
			{
				inputs: []*big.Int{big.NewInt(643), big.NewInt(76548465)},
				result: big.NewInt(98441327276),
			},
			{
				inputs: []*big.Int{big.NewInt(365235), big.NewInt(11876525)},
				result: big.NewInt(8675445947220),
			},
		},
	},
	// Four inputs with a division by an intermediate product.
	{
		code: `
def main(a,b,c,d):
e = a * b
f = c * d
g = e * f
h = g / e
i = h * 5
out = g * i
`,
		io: []InOut{
			{
				inputs: []*big.Int{big.NewInt(3), big.NewInt(5), big.NewInt(7), big.NewInt(11)},
				result: big.NewInt(444675),
			},
		},
	},
}
// TestNewProgramm runs every entry of correctnesTest through the full
// pipeline: parse the source, build the constraint trees, reduce them to
// gates, generate the R1CS and compute a witness for each input set. The
// witness entry at index program.GlobalInputCount() must equal the expected
// result.
//
// Each circuit runs in its own subtest and a parse error is reported through
// t.Fatalf, so one broken circuit fails only its subtest instead of panicking
// and aborting the whole test binary.
func TestNewProgramm(t *testing.T) {
	for i, test := range correctnesTest {
		t.Run(fmt.Sprintf("circuit_%d", i), func(t *testing.T) {
			parser := NewParser(strings.NewReader(test.code))
			program, err := parser.Parse()
			if err != nil {
				t.Fatalf("parsing circuit %d: %v", i, err)
			}

			fmt.Println("\n unreduced")
			fmt.Println(test.code)

			program.BuildConstraintTrees()
			// Map iteration order is random, so the functions are printed in
			// no particular order; this is diagnostic output only.
			for k, v := range program.functions {
				fmt.Println(k)
				PrintTree(v.root)
			}

			fmt.Println("\nReduced gates")
			gates := program.ReduceCombinedTree()
			for _, g := range gates {
				fmt.Printf("\n %v", g)
			}

			fmt.Println("\n generating R1CS")
			r1cs := program.GenerateReducedR1CS(gates)
			fmt.Println(r1cs.A)
			fmt.Println(r1cs.B)
			fmt.Println(r1cs.C)

			for _, io := range test.io {
				inputs := io.inputs
				fmt.Println("input")
				fmt.Println(inputs)

				w := CalculateWitness(inputs, r1cs)
				fmt.Println("witness")
				fmt.Println(w)

				// The expected output is read from the witness slot at index
				// program.GlobalInputCount().
				assert.Equal(t, io.result, w[program.GlobalInputCount()])
			}
		})
	}
}