You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

382 lines
9.5 KiB

package circuitcompiler
import (
"fmt"
"github.com/mottla/go-snark/r1csqap"
"math/big"
)
type Program struct {
functions map[string]*Circuit
signals []string
globalInputs []*Constraint
R1CS struct {
A [][]*big.Int
B [][]*big.Int
C [][]*big.Int
}
}
func (p *Program) PrintContraintTrees() {
for k, v := range p.functions {
fmt.Println(k)
PrintTree(v.root)
}
}
func (p *Program) BuildConstraintTrees() {
functionRootMap := make(map[string]*gate)
for _, circuit := range p.functions {
circuit.addConstraint(p.oneConstraint())
fName := composeNewFunction(circuit.Name, circuit.Inputs)
root := &gate{value: circuit.constraintMap[fName]}
functionRootMap[fName] = root
circuit.root = root
}
for _, circuit := range p.functions {
buildTree(circuit.constraintMap, circuit.root)
}
return
}
func buildTree(con map[string]*Constraint, g *gate) {
if _, ex := con[g.value.Out]; ex {
if g.OperationType()&(IN|CONST) != 0 {
return
}
} else {
panic(fmt.Sprintf("undefined variable %s", g.value.Out))
}
if g.OperationType() == FUNC {
g.funcInputs = []*gate{}
for _, in := range g.value.Inputs {
if constr, ex := con[in]; ex {
newGate := &gate{value: constr}
g.funcInputs = append(g.funcInputs, newGate)
buildTree(con, newGate)
} else {
panic(fmt.Sprintf("undefined value %s", g.value.V1))
}
}
return
}
if constr, ex := con[g.value.V1]; ex {
g.addLeft(constr)
buildTree(con, g.left)
} else {
panic(fmt.Sprintf("undefined value %s", g.value.V1))
}
if constr, ex := con[g.value.V2]; ex {
g.addRight(constr)
buildTree(con, g.right)
} else {
panic(fmt.Sprintf("undefined value %s", g.value.V2))
}
}
func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
mGatesUsed := make(map[string]bool)
orderedmGates = []gate{}
functionRootMap := make(map[string]*gate)
for k, v := range p.functions {
functionRootMap[k] = v.root
}
functionRenamer := func(c *Constraint) *gate {
if c.Op != FUNC {
panic("not a function")
}
if b, name, in := isFunction(c.Out); b {
if k, v := p.functions[name]; v {
//fmt.Println("unrenamed thing")
//PrintTree(k.root)
k.renameInputs(in)
//fmt.Println("renamed thing")
//PrintTree(k.root)
return k.root
}
} else {
panic("not a function dude")
}
return nil
}
traverseCombinedMultiplicationGates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRootMap, functionRenamer, false, false)
//for _, g := range mGates {
// orderedmGates[len(orderedmGates)-1-g.index] = g
//}
return orderedmGates
}
func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, functionRootMap map[string]*gate, functionRenamer func(c *Constraint) *gate, negate bool, inverse bool) {
//if root == nil {
// return
//}
if root.OperationType() == FUNC {
//if a input has already been built, we let this subroutine know
newMap := make(map[string]bool)
for _, in := range root.funcInputs {
if _, ex := mGatesUsed[in.value.Out]; ex {
newMap[in.value.Out] = true
} else {
traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
}
}
//mGatesUsed[root.value.Out] = true
traverseCombinedMultiplicationGates(functionRenamer(root.value), newMap, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
} else {
if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
}
if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
traverseCombinedMultiplicationGates(root.right, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
}
}
if root.OperationType() == MULTIPLY {
root.leftIns = make(map[string]int)
collectAtomsInSubtree(root.left, root.leftIns, functionRootMap, negate, inverse)
root.rightIns = make(map[string]int)
collectAtomsInSubtree(root.right, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
root.index = len(mGatesUsed)
mGatesUsed[root.value.Out] = true
rootGate := cloneGate(root)
*orderedmGates = append(*orderedmGates, *rootGate)
}
//TODO optimize if output is not a multipication gate
}
//copies a gate neglecting its references to other gates
func cloneGate(in *gate) (out *gate) {
constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
nRightins := make(map[string]int)
nLeftInst := make(map[string]int)
for k, v := range in.rightIns {
nRightins[k] = v
}
for k, v := range in.leftIns {
nLeftInst[k] = v
}
return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
}
func (p *Program) getMainCircut() *Circuit {
return p.functions["main"]
}
func (p *Program) addGlobalInput(c *Constraint) {
p.globalInputs = append(p.globalInputs, c)
}
func NewProgramm() *Program {
return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: CONST, Out: "one"}}}
}
func (p *Program) oneConstraint() *Constraint {
if p.globalInputs[0].Out != "one" {
panic("'one' should be first global input")
}
return p.globalInputs[0]
}
func (p *Program) addSignal(name string) {
p.signals = append(p.signals, name)
}
func (p *Program) addFunction(constraint *Constraint) (c *Circuit) {
name := constraint.Out
fmt.Println("try to add function ", name)
b, name2, _ := isFunction(name)
if !b {
panic(fmt.Sprintf("not a function: %v", constraint))
}
name = name2
if _, ex := p.functions[name]; ex {
panic("function already declared")
}
c = newCircuit(name)
p.functions[name] = c
//if constraint.Literal == "main" {
for _, in := range constraint.Inputs {
newConstr := &Constraint{
Op: IN,
Out: in,
}
if name == "main" {
p.addGlobalInput(newConstr)
}
c.addConstraint(newConstr)
}
c.Inputs = constraint.Inputs
return
}
// GenerateR1CS generates the R1CS polynomials from the Circuit
func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
// from flat code to R1CS
offset := len(p.globalInputs)
// one + in1 +in2+... + gate1 + gate2 .. + out
size := offset + len(mGates)
indexMap := make(map[string]int)
//circ.Signals = []string{"one"}
for i, v := range p.globalInputs {
indexMap[v.Out] = i
//circ.Signals = append(circ.Signals, v)
}
for i, v := range mGates {
indexMap[v.value.Out] = i + offset
//circ.Signals = append(circ.Signals, v.value.Out)
}
//circ.NVars = len(circ.Signals)
//circ.NSignals = len(circ.Signals)
for _, gate := range mGates {
if gate.OperationType() == MULTIPLY {
aConstraint := r1csqap.ArrayOfBigZeros(size)
bConstraint := r1csqap.ArrayOfBigZeros(size)
cConstraint := r1csqap.ArrayOfBigZeros(size)
for leftInput, val := range gate.leftIns {
insertVar3(aConstraint, val, leftInput, indexMap[leftInput])
}
for rightInput, val := range gate.rightIns {
insertVar3(bConstraint, val, rightInput, indexMap[rightInput])
}
cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
if gate.value.invert {
a = append(a, cConstraint)
b = append(b, bConstraint)
c = append(c, aConstraint)
} else {
a = append(a, aConstraint)
b = append(b, bConstraint)
c = append(c, cConstraint)
}
} else {
panic("not a m gate")
}
}
p.R1CS.A = a
p.R1CS.B = b
p.R1CS.C = c
return a, b, c
}
func insertVar3(arr []*big.Int, val int, input string, index int) {
isVal, value := isValue(input)
var valueBigInt *big.Int
if isVal {
valueBigInt = big.NewInt(int64(value))
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
} else {
//if !indexMap[leftInput] {
// panic(errors.New("using variable before it's set"))
//}
valueBigInt = big.NewInt(int64(val))
arr[index] = new(big.Int).Add(arr[index], valueBigInt)
}
}
func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
if len(p.globalInputs)-1 != len(input) {
panic("input do not match the required inputs")
}
witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
set := make([]bool, len(witness))
witness[0] = big.NewInt(int64(1))
set[0] = true
for i := range input {
witness[i+1] = input[i]
set[i+1] = true
}
zero := big.NewInt(int64(0))
for i := 0; i < len(p.R1CS.A); i++ {
gatesLeftInputs := p.R1CS.A[i]
gatesRightInputs := p.R1CS.B[i]
gatesOutputs := p.R1CS.C[i]
sumLeft := big.NewInt(int64(0))
sumRight := big.NewInt(int64(0))
sumOut := big.NewInt(int64(0))
index := -1
division := false
for j, val := range gatesLeftInputs {
if val.Cmp(zero) != 0 {
if !set[j] {
index = j
division = true
break
}
sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
}
}
for j, val := range gatesRightInputs {
if val.Cmp(zero) != 0 {
sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
}
}
for j, val := range gatesOutputs {
if val.Cmp(zero) != 0 {
if !set[j] {
if index != -1 {
panic("invalid R1CS form")
}
index = j
break
}
sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
}
}
if !division {
set[index] = true
witness[index] = new(big.Int).Mul(sumLeft, sumRight)
} else {
b := sumRight.Int64()
c := sumOut.Int64()
set[index] = true
witness[index] = big.NewInt(c / b)
}
}
return
}