package circuitcompiler
|
|
|
|
import (
	"fmt"
	"math/big"
	"sort"

	"github.com/mottla/go-snark/r1csqap"
)
|
|
|
|
// Program holds all circuits (functions) of a compiled source program
// together with the program-wide state used to build the R1CS and to
// compute witnesses.
type Program struct {
	// functions maps a function name, as parsed by isFunction in
	// addFunction, to that function's circuit.
	functions map[string]*Circuit
	// signals holds the names appended by addSignal.
	signals []string
	// globalInputs are the program-level inputs. NewProgramm seeds index 0
	// with the constant "one" constraint, and addFunction appends the
	// inputs of "main". GenerateReducedR1CS assigns the first R1CS columns
	// in this order.
	globalInputs []*Constraint
	// R1CS holds the matrices last produced by GenerateReducedR1CS, one row
	// per multiplication gate. CalculateWitness reads them.
	R1CS struct {
		A [][]*big.Int
		B [][]*big.Int
		C [][]*big.Int
	}
}
|
|
|
|
func (p *Program) PrintContraintTrees() {
|
|
for k, v := range p.functions {
|
|
fmt.Println(k)
|
|
PrintTree(v.root)
|
|
}
|
|
}
|
|
|
|
func (p *Program) BuildConstraintTrees() {
|
|
|
|
functionRootMap := make(map[string]*gate)
|
|
for _, circuit := range p.functions {
|
|
circuit.addConstraint(p.oneConstraint())
|
|
fName := composeNewFunction(circuit.Name, circuit.Inputs)
|
|
root := &gate{value: circuit.constraintMap[fName]}
|
|
functionRootMap[fName] = root
|
|
circuit.root = root
|
|
}
|
|
|
|
for _, circuit := range p.functions {
|
|
|
|
buildTree(circuit.constraintMap, circuit.root)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
func buildTree(con map[string]*Constraint, g *gate) {
|
|
if _, ex := con[g.value.Out]; ex {
|
|
if g.OperationType()&(IN|CONST) != 0 {
|
|
return
|
|
}
|
|
} else {
|
|
panic(fmt.Sprintf("undefined variable %s", g.value.Out))
|
|
}
|
|
if g.OperationType() == FUNC {
|
|
g.funcInputs = []*gate{}
|
|
for _, in := range g.value.Inputs {
|
|
if constr, ex := con[in]; ex {
|
|
newGate := &gate{value: constr}
|
|
g.funcInputs = append(g.funcInputs, newGate)
|
|
buildTree(con, newGate)
|
|
} else {
|
|
panic(fmt.Sprintf("undefined value %s", g.value.V1))
|
|
}
|
|
}
|
|
return
|
|
}
|
|
if constr, ex := con[g.value.V1]; ex {
|
|
g.addLeft(constr)
|
|
buildTree(con, g.left)
|
|
} else {
|
|
panic(fmt.Sprintf("undefined value %s", g.value.V1))
|
|
}
|
|
|
|
if constr, ex := con[g.value.V2]; ex {
|
|
g.addRight(constr)
|
|
buildTree(con, g.right)
|
|
} else {
|
|
panic(fmt.Sprintf("undefined value %s", g.value.V2))
|
|
}
|
|
}
|
|
|
|
func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
|
|
mGatesUsed := make(map[string]bool)
|
|
orderedmGates = []gate{}
|
|
functionRootMap := make(map[string]*gate)
|
|
for k, v := range p.functions {
|
|
functionRootMap[k] = v.root
|
|
}
|
|
|
|
functionRenamer := func(c *Constraint) *gate {
|
|
|
|
if c.Op != FUNC {
|
|
panic("not a function")
|
|
}
|
|
if b, name, in := isFunction(c.Out); b {
|
|
|
|
if k, v := p.functions[name]; v {
|
|
//fmt.Println("unrenamed thing")
|
|
//PrintTree(k.root)
|
|
k.renameInputs(in)
|
|
//fmt.Println("renamed thing")
|
|
//PrintTree(k.root)
|
|
return k.root
|
|
}
|
|
} else {
|
|
panic("not a function dude")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
traverseCombinedMultiplicationGates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRootMap, functionRenamer, false, false)
|
|
|
|
//for _, g := range mGates {
|
|
// orderedmGates[len(orderedmGates)-1-g.index] = g
|
|
//}
|
|
|
|
return orderedmGates
|
|
}
|
|
|
|
// traverseCombinedMultiplicationGates walks the gate tree below root in
// post-order, so operands come before the gate that uses them. It appends a
// detached copy (cloneGate) of every MULTIPLY gate it reaches to
// *orderedmGates.
//
// mGatesUsed holds the outputs of the multiplication gates already emitted
// in the current scope. A child subtree is skipped when the name it produces
// is already present: V1 or V2 for ordinary gates, or a FUNC argument's
// output.
//
// functionRootMap is only forwarded to collectAtomsInSubtree.
//
// functionRenamer turns a FUNC constraint into the root gate of the called
// function's body.
//
// negate and inverse are accumulated flags. The left child (and every FUNC
// argument and callee body) inherits them unchanged. The right child
// receives them XORed with the current gate's own negate/invert.
//
// For a FUNC gate, every argument whose output is not yet in mGatesUsed is
// first traversed in the caller's scope. The callee body is then traversed
// with a fresh map seeded only with the arguments that were already present
// in mGatesUsed. Because each callee body has its own map, the index
// assigned to a gate (len of the map at emission time) is relative to that
// map.
// NOTE(review): arguments traversed during this call are not added to the
// fresh map — confirm this is intended.
//
// Side effects: on MULTIPLY gates, root.leftIns, root.rightIns and
// root.index are overwritten, and mGatesUsed gains root's output.
func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, functionRootMap map[string]*gate, functionRenamer func(c *Constraint) *gate, negate bool, inverse bool) {
	//if root == nil {
	//	return
	//}
	if root.OperationType() == FUNC {
		// Arguments that were already computed before this call are made
		// known to the callee's scope; the others are computed here first.
		newMap := make(map[string]bool)
		for _, in := range root.funcInputs {

			if _, ex := mGatesUsed[in.value.Out]; ex {
				newMap[in.value.Out] = true
			} else {
				traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
			}
		}
		//mGatesUsed[root.value.Out] = true
		// Descend into the called function's body with its own scope map.
		traverseCombinedMultiplicationGates(functionRenamer(root.value), newMap, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
	} else {
		// IN and CONST gates are leaves: only descend from internal gates,
		// and only into operands not already emitted in this scope.
		if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
		}

		// The right operand carries this gate's own negate/invert flags.
		if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.right, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		}
	}

	if root.OperationType() == MULTIPLY {

		// Record the atoms feeding each side of the product; the right side
		// uses the same flag adjustment as the right-child traversal above.
		root.leftIns = make(map[string]int)
		collectAtomsInSubtree(root.left, root.leftIns, functionRootMap, negate, inverse)
		root.rightIns = make(map[string]int)
		collectAtomsInSubtree(root.right, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		// index is taken before marking, i.e. the count of gates already
		// emitted in this scope.
		root.index = len(mGatesUsed)
		mGatesUsed[root.value.Out] = true
		// Append a detached copy so later mutation of root does not leak in.
		rootGate := cloneGate(root)
		*orderedmGates = append(*orderedmGates, *rootGate)
	}

	//TODO optimize if output is not a multipication gate
}
|
|
|
|
//copies a gate neglecting its references to other gates
|
|
func cloneGate(in *gate) (out *gate) {
|
|
constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
|
|
nRightins := make(map[string]int)
|
|
nLeftInst := make(map[string]int)
|
|
for k, v := range in.rightIns {
|
|
nRightins[k] = v
|
|
}
|
|
for k, v := range in.leftIns {
|
|
nLeftInst[k] = v
|
|
}
|
|
return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
|
|
}
|
|
|
|
func (p *Program) getMainCircut() *Circuit {
|
|
return p.functions["main"]
|
|
}
|
|
|
|
func (p *Program) addGlobalInput(c *Constraint) {
|
|
p.globalInputs = append(p.globalInputs, c)
|
|
}
|
|
|
|
func NewProgramm() *Program {
|
|
return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: CONST, Out: "one"}}}
|
|
}
|
|
|
|
func (p *Program) oneConstraint() *Constraint {
|
|
if p.globalInputs[0].Out != "one" {
|
|
panic("'one' should be first global input")
|
|
}
|
|
return p.globalInputs[0]
|
|
}
|
|
|
|
func (p *Program) addSignal(name string) {
|
|
p.signals = append(p.signals, name)
|
|
}
|
|
|
|
func (p *Program) addFunction(constraint *Constraint) (c *Circuit) {
|
|
name := constraint.Out
|
|
fmt.Println("try to add function ", name)
|
|
|
|
b, name2, _ := isFunction(name)
|
|
if !b {
|
|
panic(fmt.Sprintf("not a function: %v", constraint))
|
|
}
|
|
name = name2
|
|
|
|
if _, ex := p.functions[name]; ex {
|
|
panic("function already declared")
|
|
}
|
|
|
|
c = newCircuit(name)
|
|
|
|
p.functions[name] = c
|
|
|
|
//if constraint.Literal == "main" {
|
|
for _, in := range constraint.Inputs {
|
|
newConstr := &Constraint{
|
|
Op: IN,
|
|
Out: in,
|
|
}
|
|
if name == "main" {
|
|
p.addGlobalInput(newConstr)
|
|
}
|
|
c.addConstraint(newConstr)
|
|
}
|
|
|
|
c.Inputs = constraint.Inputs
|
|
return
|
|
|
|
}
|
|
|
|
// GenerateR1CS generates the R1CS polynomials from the Circuit
|
|
func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
|
|
// from flat code to R1CS
|
|
|
|
offset := len(p.globalInputs)
|
|
// one + in1 +in2+... + gate1 + gate2 .. + out
|
|
size := offset + len(mGates)
|
|
indexMap := make(map[string]int)
|
|
|
|
//circ.Signals = []string{"one"}
|
|
for i, v := range p.globalInputs {
|
|
indexMap[v.Out] = i
|
|
//circ.Signals = append(circ.Signals, v)
|
|
|
|
}
|
|
for i, v := range mGates {
|
|
indexMap[v.value.Out] = i + offset
|
|
//circ.Signals = append(circ.Signals, v.value.Out)
|
|
}
|
|
//circ.NVars = len(circ.Signals)
|
|
//circ.NSignals = len(circ.Signals)
|
|
|
|
for _, gate := range mGates {
|
|
|
|
if gate.OperationType() == MULTIPLY {
|
|
aConstraint := r1csqap.ArrayOfBigZeros(size)
|
|
bConstraint := r1csqap.ArrayOfBigZeros(size)
|
|
cConstraint := r1csqap.ArrayOfBigZeros(size)
|
|
|
|
for leftInput, val := range gate.leftIns {
|
|
insertVar3(aConstraint, val, leftInput, indexMap[leftInput])
|
|
}
|
|
for rightInput, val := range gate.rightIns {
|
|
insertVar3(bConstraint, val, rightInput, indexMap[rightInput])
|
|
}
|
|
cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
|
|
|
|
if gate.value.invert {
|
|
a = append(a, cConstraint)
|
|
b = append(b, bConstraint)
|
|
c = append(c, aConstraint)
|
|
} else {
|
|
a = append(a, aConstraint)
|
|
b = append(b, bConstraint)
|
|
c = append(c, cConstraint)
|
|
}
|
|
|
|
} else {
|
|
panic("not a m gate")
|
|
}
|
|
}
|
|
p.R1CS.A = a
|
|
p.R1CS.B = b
|
|
p.R1CS.C = c
|
|
return a, b, c
|
|
}
|
|
|
|
func insertVar3(arr []*big.Int, val int, input string, index int) {
|
|
isVal, value := isValue(input)
|
|
var valueBigInt *big.Int
|
|
if isVal {
|
|
valueBigInt = big.NewInt(int64(value))
|
|
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
|
|
} else {
|
|
//if !indexMap[leftInput] {
|
|
// panic(errors.New("using variable before it's set"))
|
|
//}
|
|
valueBigInt = big.NewInt(int64(val))
|
|
arr[index] = new(big.Int).Add(arr[index], valueBigInt)
|
|
}
|
|
|
|
}
|
|
|
|
func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
|
|
|
|
if len(p.globalInputs)-1 != len(input) {
|
|
panic("input do not match the required inputs")
|
|
}
|
|
|
|
witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
|
|
set := make([]bool, len(witness))
|
|
witness[0] = big.NewInt(int64(1))
|
|
set[0] = true
|
|
|
|
for i := range input {
|
|
witness[i+1] = input[i]
|
|
set[i+1] = true
|
|
}
|
|
|
|
zero := big.NewInt(int64(0))
|
|
|
|
for i := 0; i < len(p.R1CS.A); i++ {
|
|
gatesLeftInputs := p.R1CS.A[i]
|
|
gatesRightInputs := p.R1CS.B[i]
|
|
gatesOutputs := p.R1CS.C[i]
|
|
|
|
sumLeft := big.NewInt(int64(0))
|
|
sumRight := big.NewInt(int64(0))
|
|
sumOut := big.NewInt(int64(0))
|
|
|
|
index := -1
|
|
division := false
|
|
for j, val := range gatesLeftInputs {
|
|
if val.Cmp(zero) != 0 {
|
|
if !set[j] {
|
|
index = j
|
|
division = true
|
|
break
|
|
}
|
|
sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
|
|
}
|
|
}
|
|
for j, val := range gatesRightInputs {
|
|
if val.Cmp(zero) != 0 {
|
|
sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
|
|
}
|
|
}
|
|
|
|
for j, val := range gatesOutputs {
|
|
if val.Cmp(zero) != 0 {
|
|
if !set[j] {
|
|
if index != -1 {
|
|
panic("invalid R1CS form")
|
|
}
|
|
|
|
index = j
|
|
break
|
|
}
|
|
sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
|
|
}
|
|
}
|
|
|
|
if !division {
|
|
set[index] = true
|
|
witness[index] = new(big.Int).Mul(sumLeft, sumRight)
|
|
|
|
} else {
|
|
b := sumRight.Int64()
|
|
c := sumOut.Int64()
|
|
set[index] = true
|
|
witness[index] = big.NewInt(c / b)
|
|
}
|
|
|
|
}
|
|
|
|
return
|
|
}
|