From 2fc30d452ab4850fc542be261966f6042d915b7f Mon Sep 17 00:00:00 2001 From: mottla Date: Sat, 1 Jun 2019 19:17:11 +0200 Subject: [PATCH] multiplication gate reduction via reusing old gates with equivalent behaviour multiplication gate reduction via reusing old gates with equivalent behaviour --- circuitcompiler/Programm.go | 146 +++++++++++++++++-------------- circuitcompiler/Programm_test.go | 34 +++---- circuitcompiler/test.py | 68 ++++++++++++++ snark.go | 9 +- snark_test.go | 69 +++++++-------- 5 files changed, 202 insertions(+), 124 deletions(-) create mode 100644 circuitcompiler/test.py diff --git a/circuitcompiler/Programm.go b/circuitcompiler/Programm.go index eb6ff8c..c122c3b 100644 --- a/circuitcompiler/Programm.go +++ b/circuitcompiler/Programm.go @@ -17,17 +17,28 @@ type utils struct { PF r1csqap.PolynomialField } +type R1CS struct { + A [][]*big.Int + B [][]*big.Int + C [][]*big.Int +} type Program struct { functions map[string]*Circuit globalInputs []string arithmeticEnvironment utils //find a better name sha256Hasher hash.Hash - computedInContext map[string]map[string]string - R1CS struct { - A [][]*big.Int - B [][]*big.Int - C [][]*big.Int - } + + //key 1: the hash chain indicating from where the variable is called H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function. + //value 1 : map + // with key variable name + // with value variable name + hash Chain + //this datastructure is nice but maybe ill replace it later with something less confusing + //it serves the elementary purpose of not computing a variable a second time + computedInContext map[string]map[string]string + + //to reduce the number of multiplication gates, we store each factor signature, and the variable name, + //so each time a variable is computed, that happens to have the very same factors, we reuse the former gate + computedFactors map[string]string } func (p *Program) PrintContraintTrees() { @@ -48,9 +59,6 @@ func (p *Program) BuildConstraintTrees() { p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot } - //for _, in := range p.getMainCircuit().Inputs { - // p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs)) - //} for _, in := range p.getMainCircuit().Inputs { p.globalInputs = append(p.globalInputs, in) } @@ -58,10 +66,11 @@ func (p *Program) BuildConstraintTrees() { for _, circuit := range p.functions { wg.Add(1) - func() { - circuit.buildTree(circuit.root) + //interesting: if circuit is not passed as argument, the program fails. duno why.. + go func(c *Circuit) { + c.buildTree(c.root) wg.Done() - }() + }(circuit) } wg.Wait() @@ -109,12 +118,17 @@ func (p *Program) ReduceCombinedTree() (orderedmGates []gate) { //mGatesUsed := make(map[string]bool) orderedmGates = []gate{} p.computedInContext = make(map[string]map[string]string) + p.computedFactors = make(map[string]string) rootHash := []byte{} p.computedInContext[string(rootHash)] = make(map[string]string) p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false) return orderedmGates } +//recursively walks through the parse tree to create a list of all +//multiplication gates needed for the QAP construction +//Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so +//TODO if the same variable that has been computed in context A, is needed again but from a different context b, will be recomputed and not reused func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) { if node.OperationType() == CONST { @@ -163,6 +177,12 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr //if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root { return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons } + sig := factorsSignature(leftFactors, rightFactors) + if out, ex := p.computedFactors[sig]; ex { + return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true + + } + rootGate := cloneGate(node) rootGate.index = len(*orderedmGates) rootGate.leftIns = leftFactors @@ -172,6 +192,7 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10]) rootGate.value.Out = rootGate.value.Out + string(out[:10]) p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out + p.computedFactors[sig] = rootGate.value.Out *orderedmGates = append(*orderedmGates, *rootGate) hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out)) @@ -186,7 +207,6 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr panic("unexpected gate") } - //TODO optimize if output is not a multipication gate } type factor struct { @@ -214,6 +234,19 @@ func mul2DVector(a, b [2]int) [2]int { return [2]int{a[0] * b[0], a[1] * b[1]} } +func factorsSignature(leftFactors, rightFactors []factor) string { + hasher.Reset() + //not using a kommutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a) + //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard + for _, facLeft := range leftFactors { + hasher.Write([]byte(facLeft.String())) + } + for _, Righ := range rightFactors { + hasher.Write([]byte(Righ.String())) + } + return string(hasher.Sum(nil))[:16] +} + func mulFactors(leftFactors, rightFactors []factor) (result []factor) { for _, facLeft := range leftFactors { @@ -248,7 +281,6 @@ func mulFactors(leftFactors, rightFactors []factor) (result []factor) { //continue } - fmt.Println("dsf") panic("unexpected") } @@ -340,11 +372,6 @@ func (p *Program) getMainCircuit() *Circuit { return p.functions["main"] } -//func (p *Program) addGlobalInput(c Constraint) { -// c.Out = "main@" + c.Out -// p.globalInputs = append(p.globalInputs, c) -//} - func prepareUtils() utils { bn, err := bn128.NewBn128() if err != nil { @@ -424,7 +451,7 @@ func NewProgram() (p *Program) { } // GenerateR1CS generates the R1CS polynomials from the Circuit -func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) { +func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) { // from flat code to R1CS offset := len(p.globalInputs) @@ -440,51 +467,52 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) { indexMap[v.value.Out] = i + offset } - for _, gate := range mGates { + for _, g := range mGates { - if gate.OperationType() == MULTIPLY { + if g.OperationType() == MULTIPLY { aConstraint := r1csqap.ArrayOfBigZeros(size) bConstraint := r1csqap.ArrayOfBigZeros(size) cConstraint := r1csqap.ArrayOfBigZeros(size) - for _, val := range gate.leftIns { + insertValue := func(val factor, arr []*big.Int) { if val.typ != CONST { if _, ex := indexMap[val.name]; !ex { panic(fmt.Sprintf("%v index not found!!!", val.name)) } } - convertAndInsertFactorAt(aConstraint, val, indexMap[val.name]) + value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative)) + if val.negate { + value.Neg(value) + } + //not that index is 0 if its a constant, since 0 is the map default if no entry was found + arr[indexMap[val.name]] = value } - for _, val := range gate.rightIns { - if val.typ != CONST { - if _, ex := indexMap[val.name]; !ex { - panic(fmt.Sprintf("%v index not found!!!", val.name)) - } - } + for _, val := range g.leftIns { + insertValue(val, aConstraint) + } - convertAndInsertFactorAt(bConstraint, val, indexMap[val.name]) + for _, val := range g.rightIns { + insertValue(val, bConstraint) } - cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1)) + cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1)) - if gate.value.invert { + if g.value.invert { tmp := aConstraint aConstraint = cConstraint cConstraint = tmp } - a = append(a, aConstraint) - b = append(b, bConstraint) - c = append(c, cConstraint) + r1CS.A = append(r1CS.A, aConstraint) + r1CS.B = append(r1CS.B, bConstraint) + r1CS.C = append(r1CS.C, cConstraint) } else { panic("not a m gate") } } - p.R1CS.A = a - p.R1CS.B = b - p.R1CS.C = c - return a, b, c + + return } var Utils = prepareUtils() @@ -494,25 +522,15 @@ func fractionToField(in [2]int) *big.Int { } -func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) { - value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative)) - - if val.negate { - value.Neg(value) - } - - //not that index is 0 if its a constant, since 0 is the map default if no entry was found - arr[index] = value +//Calculates the witness (program trace) given some input +//asserts that R1CS has been computed and is stored in the program p memory calling this function +func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) { -} - -func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) { - - if len(p.globalInputs)-1 != len(input) { - panic("input do not match the required inputs") - } + //if len(p.globalInputs)-1 != len(input) { + // panic("input do not match the required inputs") + //} - witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0])) + witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0])) set := make([]bool, len(witness)) witness[0] = big.NewInt(int64(1)) set[0] = true @@ -524,10 +542,10 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) { zero := big.NewInt(int64(0)) - for i := 0; i < len(p.R1CS.A); i++ { - gatesLeftInputs := p.R1CS.A[i] - gatesRightInputs := p.R1CS.B[i] - gatesOutputs := p.R1CS.C[i] + for i := 0; i < len(r1cs.A); i++ { + gatesLeftInputs := r1cs.A[i] + gatesRightInputs := r1cs.B[i] + gatesOutputs := r1cs.C[i] sumLeft := big.NewInt(int64(0)) sumRight := big.NewInt(int64(0)) @@ -583,12 +601,6 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) { var hasher = sha256.New() -func hashFactorWithContext(f factor, currentCircuit *Circuit) []byte { - hasher.Reset() - hasher.Write([]byte(f.name)) - hasher.Write([]byte(currentCircuit.currentOutputName())) - return hasher.Sum(nil) -} func hashTogether(a, b []byte) []byte { hasher.Reset() hasher.Write(a) diff --git a/circuitcompiler/Programm_test.go b/circuitcompiler/Programm_test.go index 8604846..0bdaa1b 100644 --- a/circuitcompiler/Programm_test.go +++ b/circuitcompiler/Programm_test.go @@ -101,18 +101,18 @@ var correctnesTest = []TraceCorrectnessTest{ } func TestNewProgramm(t *testing.T) { - flat := ` - - def doSomething(x ,k): - z = k * x - out = 6 + mul(x,z) - - def main(x,z): - out = mul(x,z) - doSomething(x,x) - - def mul(a,b): - out = a * b - ` + //flat := ` + // + //def doSomething(x ,k): + // z = k * x + // out = 6 + mul(x,z) + // + //def main(x,z): + // out = mul(x,z) - doSomething(x,x) + // + //def mul(a,b): + // out = a * b + //` for _, test := range correctnesTest { parser := NewParser(strings.NewReader(test.code)) @@ -138,16 +138,16 @@ func TestNewProgramm(t *testing.T) { } fmt.Println("\n generating R1CS") - a, b, c := program.GenerateReducedR1CS(gates) - fmt.Println(a) - fmt.Println(b) - fmt.Println(c) + r1cs := program.GenerateReducedR1CS(gates) + fmt.Println(r1cs.A) + fmt.Println(r1cs.B) + fmt.Println(r1cs.C) for _, io := range test.io { inputs := io.inputs fmt.Println("input") fmt.Println(inputs) - w := program.CalculateWitness(inputs) + w := CalculateWitness(inputs, r1cs) fmt.Println("witness") fmt.Println(w) assert.Equal(t, io.result, w[len(w)-1]) diff --git a/circuitcompiler/test.py b/circuitcompiler/test.py new file mode 100644 index 0000000..c5d079f --- /dev/null +++ b/circuitcompiler/test.py @@ -0,0 +1,68 @@ +# def do(x): +# e = x * 5 +# b = e * 6 +# c = b * 7 +# f = c * 1 +# d = c * f +# return d * mul(d,e) +# +# def add(x ,k): +# z = k * x +# return do(x) + mul(x,z) +# +# +# def mul(a,b): +# return a * b +# +# def main(): +# x=365235 +# z=11876525 +# print(do(z) + add(x,x)) + +################################ + +# def add(x ,k): +# z = k * x +# return 6 + mul(x,z) + +# def asdf(a,b): +# d = b + b +# c = a * d +# e = c - a +# return e * c +# +# def asdf(a,b): +# c = a + b +# e = c - a +# f = e + b +# g = f + 2 +# return g * a + +def doSomething(x ,k): + z = k * x + return 6 + mul(x,z) + +# def main(x,z): +# out = + +def mul(a,b): + return a * b + +def main(): + x=64341 + z=76548465 + + print(mul(x,z) - doSomething(x,x)) +# +# def mul(a,b): +# return a * b +# +# def asdf(a): +# b = a * a +# c = 4 - b +# d = 5 * c +# return mul(d,c) / mul(b,b) + +if __name__ == '__main__': + #pascal(8) + main() \ No newline at end of file diff --git a/snark.go b/snark.go index d5a7750..dca9752 100644 --- a/snark.go +++ b/snark.go @@ -6,7 +6,6 @@ import ( "os" "github.com/mottla/go-snark/bn128" - "github.com/mottla/go-snark/circuitcompiler" "github.com/mottla/go-snark/fields" "github.com/mottla/go-snark/r1csqap" ) @@ -246,7 +245,7 @@ func GenerateTrustedSetup(witnessLength int, alphas, betas, gammas [][]*big.Int) } // GenerateProofs generates all the parameters to proof the zkSNARK from the Circuit, Setup and the Witness -func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int, px []*big.Int) (Proof, error) { +func GenerateProofs(setup Setup, nInputs int, w []*big.Int, px []*big.Int) (Proof, error) { var proof Proof proof.PiA = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} proof.PiAp = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} @@ -257,12 +256,12 @@ func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int, proof.PiH = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} proof.PiKp = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} - for i := circuit.NPublic + 1; i < circuit.NVars; i++ { + for i := nInputs; i < len(w)-1; i++ { proof.PiA = Utils.Bn.G1.Add(proof.PiA, Utils.Bn.G1.MulScalar(setup.Pk.A[i], w[i])) proof.PiAp = Utils.Bn.G1.Add(proof.PiAp, Utils.Bn.G1.MulScalar(setup.Pk.Ap[i], w[i])) } - for i := 0; i < circuit.NVars; i++ { + for i := 0; i < len(w); i++ { proof.PiB = Utils.Bn.G2.Add(proof.PiB, Utils.Bn.G2.MulScalar(setup.Pk.B[i], w[i])) proof.PiBp = Utils.Bn.G1.Add(proof.PiBp, Utils.Bn.G1.MulScalar(setup.Pk.Bp[i], w[i])) @@ -284,7 +283,7 @@ func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int, } // VerifyProof verifies over the BN128 the Pairings of the Proof -func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publicSignals []*big.Int, debug bool) bool { +func VerifyProof(setup Setup, proof Proof, publicSignals []*big.Int, debug bool) bool { // e(piA, Va) == e(piA', g2) pairingPiaVa := Utils.Bn.Pairing(proof.PiA, setup.Vk.Vka) pairingPiapG2 := Utils.Bn.Pairing(proof.PiAp, Utils.Bn.G2.G) diff --git a/snark_test.go b/snark_test.go index 2b9fc8f..cffbaa4 100644 --- a/snark_test.go +++ b/snark_test.go @@ -57,7 +57,7 @@ func TestGenerateProofs(t *testing.T) { func TestNewProgramm(t *testing.T) { flat := ` - func main(a,b,c,d): + def main(a,b,c,d): e = a * b f = c * d g = e * f @@ -85,14 +85,15 @@ func TestNewProgramm(t *testing.T) { } fmt.Println("generating R1CS") - a, b, c := program.GenerateReducedR1CS(gates) + r1cs := program.GenerateReducedR1CS(gates) + a, b, c := r1cs.A, r1cs.B, r1cs.C fmt.Println(a) fmt.Println(b) fmt.Println(c) a1 := big.NewInt(int64(6)) a2 := big.NewInt(int64(5)) inputs := []*big.Int{a1, a2, a1, a2} - w := program.CalculateWitness(inputs) + w := circuitcompiler.CalculateWitness(inputs, r1cs) fmt.Println("witness") fmt.Println(w) @@ -120,47 +121,45 @@ func TestNewProgramm(t *testing.T) { // p(x) = a(x) * b(x) - c(x) == h(x) * z(x) abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx) assert.Equal(t, abc, px) - hzQAP := Utils.PF.Mul(hxQAP, domain) - assert.Equal(t, abc, hzQAP) div, rem := Utils.PF.Div(px, domain) - assert.Equal(t, hxQAP, div) //not necessary + assert.Equal(t, hxQAP, div) //not necessary, since DivisorPolynomial is Div, just discarding 'rem' assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(len(px)-len(domain))) //calculate trusted setup - //setup, err := GenerateTrustedSetup(len(w),alphas, betas, gammas) - //assert.Nil(t, err) - //fmt.Println("\nt:", setup.Toxic.T) - //// - ////// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation - //assert.Equal(t, domain, setup.Pk.Z) - // - //fmt.Println("hx pk.z", hxQAP) - //hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) - //fmt.Println("hx pk.z", hx) - //// assert.Equal(t, hxQAP, hx) - //assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP)) - //assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) - // - //assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) - //assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1) - //// fmt.Println("pk.Z", len(setup.Pk.Z)) - //// fmt.Println("zxQAP", len(zxQAP)) + setup, err := GenerateTrustedSetup(len(w), alphas, betas, gammas) + assert.Nil(t, err) + fmt.Println("\nt:", setup.Toxic.T) // - //// piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) - //proof, err := GenerateProofs(*circuit, setup, w, px) - //assert.Nil(t, err) - // - //// fmt.Println("\n proofs:") - //// fmt.Println(proof) - // - //// fmt.Println("public signals:", proof.PublicSignals) - //fmt.Println("\nwitness", w) - //// b1 := big.NewInt(int64(1)) + //// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation + //assert.Equal(t, domain, setup.Pk.Z) + + fmt.Println("hx pk.z", hxQAP) + hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) + fmt.Println("hx pk.z", hx) + // assert.Equal(t, hxQAP, hx) + assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain)) + assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) + + assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) + assert.Equal(t, len(hxQAP), len(px)-len(domain)+1) + // fmt.Println("pk.Z", len(setup.Pk.Z)) + // fmt.Println("zxQAP", len(zxQAP)) + + // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) + proof, err := GenerateProofs(setup, 5, w, px) + assert.Nil(t, err) + + // fmt.Println("\n proofs:") + // fmt.Println(proof) + + // fmt.Println("public signals:", proof.PublicSignals) + fmt.Println("\nwitness", w) + // b1 := big.NewInt(int64(1)) //b35 := big.NewInt(int64(35)) //// publicSignals := []*big.Int{b1, b35} //publicSignals := []*big.Int{b35} //before := time.Now() - //assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true)) + assert.True(t, VerifyProof(setup, proof, w[:5], true)) //fmt.Println("verify proof time elapsed:", time.Since(before)) }