From de0e08149442360390d1cd7c2162c7d5ba723e07 Mon Sep 17 00:00:00 2001 From: mottla Date: Tue, 14 May 2019 16:35:51 +0200 Subject: [PATCH] reducing multiplication gates with constants --- README.md | 2 +- circuitcompiler/Programm.go | 115 +++++++++++++++++++++++++++---- circuitcompiler/Programm_test.go | 21 +++--- circuitcompiler/circuit.go | 33 +-------- r1csqap/r1csqap.go | 4 ++ snark_test.go | 25 +++---- 6 files changed, 130 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index 8c1794f..d46eeb2 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ UNDER CONSTRUCTION! Implementation of the zkSNARK [Pinocchio protocol](https://eprint.iacr.org/2013/279.pdf) from scratch in Go to understand the concepts. Do not use in production. -This forked aims to extend its functionalities s.t. one can prove set-membership in zero knowledge. +This fork aims to extend its functionalities s.t. one can prove set-membership in zero knowledge. Current implementation status: - [x] Finite Fields (1, 2, 6, 12) operations diff --git a/circuitcompiler/Programm.go b/circuitcompiler/Programm.go index 0369bc9..4085abb 100644 --- a/circuitcompiler/Programm.go +++ b/circuitcompiler/Programm.go @@ -28,7 +28,7 @@ func (p *Program) BuildConstraintTrees() { functionRootMap := make(map[string]*gate) for _, circuit := range p.functions { - circuit.addConstraint(p.oneConstraint()) + //circuit.addConstraint(p.oneConstraint()) fName := composeNewFunction(circuit.Name, circuit.Inputs) root := &gate{value: circuit.constraintMap[fName]} functionRootMap[fName] = root @@ -123,19 +123,20 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, //if root == nil { // return //} + //fmt.Printf("\n%p",mGatesUsed) if root.OperationType() == FUNC { //if a input has already been built, we let this subroutine know - newMap := make(map[string]bool) + //newMap := make(map[string]bool) for _, in := range root.funcInputs { if _, ex := mGatesUsed[in.value.Out]; ex { - newMap[in.value.Out] = true + //newMap[in.value.Out] = true } else { traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse) } } //mGatesUsed[root.value.Out] = true - traverseCombinedMultiplicationGates(functionRenamer(root.value), newMap, orderedmGates, functionRootMap, functionRenamer, negate, inverse) + traverseCombinedMultiplicationGates(functionRenamer(root.value), mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse) } else { if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 { traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse) @@ -148,12 +149,24 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, if root.OperationType() == MULTIPLY { + _, n, _ := isFunction(root.value.Out) + if (root.left.OperationType()|root.right.OperationType())&CONST != 0 && n != "main" { + return + } + root.leftIns = make(map[string]int) - collectAtomsInSubtree(root.left, root.leftIns, functionRootMap, negate, inverse) + collectAtomsInSubtree(root.left, mGatesUsed, 1, root.leftIns, functionRootMap, negate, inverse) root.rightIns = make(map[string]int) - collectAtomsInSubtree(root.right, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) + //if root.left.value.Out== root.right.value.Out{ + // //note this is not a full copy, but shouldnt be a problem + // root.rightIns= root.leftIns + //}else{ + // collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) + //} + collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) root.index = len(mGatesUsed) mGatesUsed[root.value.Out] = true + rootGate := cloneGate(root) *orderedmGates = append(*orderedmGates, *rootGate) } @@ -161,6 +174,63 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, //TODO optimize if output is not a multipication gate } +func collectAtomsInSubtree(g *gate, mGatesUsed map[string]bool, multiplicative int, in map[string]int, functionRootMap map[string]*gate, negate bool, invert bool) { + if g == nil { + return + } + if _, ex := mGatesUsed[g.value.Out]; ex { + addToMap(g.value.Out, multiplicative, in, negate) + return + } + + if g.OperationType()&(IN|CONST) != 0 { + addToMap(g.value.Out, multiplicative, in, negate) + return + } + + if g.OperationType()&(MULTIPLY) != 0 { + b1, v1 := isValue(g.value.V1) + b2, v2 := isValue(g.value.V2) + + if b1 && !b2 { + multiplicative *= v1 + collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert) + return + } else if !b1 && b2 { + multiplicative *= v2 + collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert) + return + } else if b1 && b2 { + panic("multiply constants not supported yet") + } else { + panic("werird") + } + } + if g.OperationType() == FUNC { + if b, name, _ := isFunction(g.value.Out); b { + collectAtomsInSubtree(functionRootMap[name], mGatesUsed, multiplicative, in, functionRootMap, negate, invert) + + } else { + panic("function expected") + } + + } + collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert) + collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert) + +} + +func addOneToMap(value string, in map[string]int, negate bool) { + addToMap(value, 1, in, negate) +} +func addToMap(value string, val int, in map[string]int, negate bool) { + if negate { + in[value] = (in[value] - 1) * val + } else { + in[value] = (in[value] + 1) * val + } +} + //copies a gate neglecting its references to other gates func cloneGate(in *gate) (out *gate) { constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1} @@ -184,15 +254,16 @@ func (p *Program) addGlobalInput(c *Constraint) { } func NewProgramm() *Program { - return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: CONST, Out: "one"}}} + //return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: PLUS, V1:"1",V2:"0", Out: "one"}}} + return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: IN, Out: "one"}}} } -func (p *Program) oneConstraint() *Constraint { - if p.globalInputs[0].Out != "one" { - panic("'one' should be first global input") - } - return p.globalInputs[0] -} +//func (p *Program) oneConstraint() *Constraint { +// if p.globalInputs[0].Out != "one" { +// panic("'one' should be first global input") +// } +// return p.globalInputs[0] +//} func (p *Program) addSignal(name string) { p.signals = append(p.signals, name) @@ -262,7 +333,25 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) { bConstraint := r1csqap.ArrayOfBigZeros(size) cConstraint := r1csqap.ArrayOfBigZeros(size) + //if len(gate.leftIns)>=len(gate.rightIns){ + // for leftInput, _ := range gate.leftIns { + // if v, ex := gate.rightIns[leftInput]; ex { + // gate.leftIns[leftInput] *= v + // gate.rightIns[leftInput] = 1 + // + // } + // } + //}else{ + // for rightInput, _ := range gate.rightIns { + // if v, ex := gate.leftIns[rightInput]; ex { + // gate.rightIns[rightInput] *= v + // gate.leftIns[rightInput] = 1 + // } + // } + //} + for leftInput, val := range gate.leftIns { + insertVar3(aConstraint, val, leftInput, indexMap[leftInput]) } for rightInput, val := range gate.rightIns { diff --git a/circuitcompiler/Programm_test.go b/circuitcompiler/Programm_test.go index e25895d..2a84951 100644 --- a/circuitcompiler/Programm_test.go +++ b/circuitcompiler/Programm_test.go @@ -17,29 +17,30 @@ func TestNewProgramm(t *testing.T) { flat := ` func do(x): - e = x * x - b = e * e - c = b * b - d = c * c + e = x * 5 + b = e * 6 + c = b * 7 + f = c * 1 + d = c / f out = d * 1 func add(x ,k): z = k * x out = do(x) + mul(x,z) - func main(a,b): - out = do(5) + 4 + func main(x,z): + out = do(x) + 4 func mul(a,b): out = a * b ` //flat := ` - //func do(x): - // b = x - 2 - // out = x * b //func main(a,b): - // out = do(a) + 4 + // e = 4 * a + // c = a * e + // d = c * 70 + // out = a * d //` parser := NewParser(strings.NewReader(flat)) program, err := parser.Parse() diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index 7353eab..791ce0f 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -109,13 +109,15 @@ func (circ *Circuit) addConstraint(constraint *Constraint) { //the main functions output must be a multiplication gate //if its not, then we simple create one where outNew = 1 * outOld if constraint.Op&(MINUS|PLUS) != 0 { - newOut := &Constraint{Out: constraint.Out, V1: "one", V2: "out2", Op: MULTIPLY} + newOut := &Constraint{Out: constraint.Out, V1: "1", V2: "out2", Op: MULTIPLY} + //TODO reachable? delete(circ.constraintMap, constraint.Out) circ.addConstraint(newOut) constraint.Out = "out2" circ.addConstraint(constraint) } } + } addConstantsAndFunctions := func(constraint string) { @@ -274,35 +276,6 @@ func printTree(g *gate, d int) { } } -func addToMap(value string, in map[string]int, negate bool) { - if negate { - in[value] = in[value] - 1 - } else { - in[value] = in[value] + 1 - } -} - -func collectAtomsInSubtree(g *gate, in map[string]int, functionRootMap map[string]*gate, negate bool, invert bool) { - if g == nil { - return - } - if g.OperationType()&(MULTIPLY|IN|CONST) != 0 { - addToMap(g.value.Out, in, negate) - return - } - if g.OperationType() == FUNC { - if b, name, _ := isFunction(g.value.Out); b { - collectAtomsInSubtree(functionRootMap[name], in, functionRootMap, negate, invert) - } else { - panic("function expected") - } - - } - - collectAtomsInSubtree(g.left, in, functionRootMap, negate, invert) - collectAtomsInSubtree(g.right, in, functionRootMap, Xor(negate, g.value.negate), invert) -} - func Xor(a, b bool) bool { return (a && !b) || (!a && b) } diff --git a/r1csqap/r1csqap.go b/r1csqap/r1csqap.go index 824a247..f16fd01 100644 --- a/r1csqap/r1csqap.go +++ b/r1csqap/r1csqap.go @@ -2,6 +2,7 @@ package r1csqap import ( "bytes" + "fmt" "math/big" "github.com/mottla/go-snark/fields" @@ -162,6 +163,9 @@ func (pf PolynomialField) R1CSToQAP(a, b, c [][]*big.Int) ([][]*big.Int, [][]*bi aT := Transpose(a) bT := Transpose(b) cT := Transpose(c) + fmt.Println(aT) + fmt.Println(bT) + fmt.Println(cT) var alphas [][]*big.Int for i := 0; i < len(aT); i++ { alphas = append(alphas, pf.LagrangeInterpolation(aT[i])) diff --git a/snark_test.go b/snark_test.go index 32f9d55..c3e4956 100644 --- a/snark_test.go +++ b/snark_test.go @@ -3,7 +3,6 @@ package snark import ( "fmt" "github.com/mottla/go-snark/circuitcompiler" - "github.com/mottla/go-snark/r1csqap" "github.com/stretchr/testify/assert" "math/big" "strings" @@ -13,16 +12,10 @@ import ( func TestNewProgramm(t *testing.T) { flat := ` - - func add(x ,k): - z = k * x - out = x + mul(x,z) - - func main(a,b): - out = add(a,b) * a - - func mul(a,b): - out = a * b + func main(a,b,c,d): + e = a + b + f = c * d + out = e * f ` parser := circuitcompiler.NewParser(strings.NewReader(flat)) @@ -35,7 +28,7 @@ func TestNewProgramm(t *testing.T) { fmt.Println(flat) program.BuildConstraintTrees() - program.PrintConstraintTrees() + program.PrintContraintTrees() fmt.Println("\nReduced gates") //PrintTree(froots["mul"]) gates := program.ReduceCombinedTree() @@ -50,7 +43,7 @@ func TestNewProgramm(t *testing.T) { fmt.Println(c) a1 := big.NewInt(int64(6)) a2 := big.NewInt(int64(5)) - inputs := []*big.Int{a1, a2} + inputs := []*big.Int{a1, a2, a1, a2} w := program.CalculateWitness(inputs) fmt.Println("witness") fmt.Println(w) @@ -82,9 +75,9 @@ func TestNewProgramm(t *testing.T) { hzQAP := Utils.PF.Mul(hxQAP, zxQAP) assert.Equal(t, abc, hzQAP) - div, rem := Utils.PF.Div(px, zxQAP) - assert.Equal(t, hxQAP, div) - assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) + //div, rem := Utils.PF.Div(px, zxQAP) + //assert.Equal(t, hxQAP, div) + //assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) // calculate trusted setup //setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)