From 49fead21973815a725c0c72941db186d90b690a9 Mon Sep 17 00:00:00 2001 From: mottla Date: Mon, 17 Jun 2019 16:32:35 +0200 Subject: [PATCH] New Multiplication Gate reduction algorithm! Extracting coefficients from each output, s.t. each gate has a higher chance of being reused. See new Readme --- README.md | 16 +- circuitcompiler/Programm.go | 223 ++++-------------- circuitcompiler/Programm_test.go | 155 ++++++++++++- circuitcompiler/circuit.go | 4 +- circuitcompiler/factorHandling.go | 300 +++++++++++++++++++++++++ circuitcompiler/factorHandling_test.go | 205 +++++++++++++++++ 6 files changed, 711 insertions(+), 192 deletions(-) create mode 100644 circuitcompiler/factorHandling.go create mode 100644 circuitcompiler/factorHandling_test.go diff --git a/README.md b/README.md index 6cbf654..f072843 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ Fork UNDER CONSTRUCTION! Will ask for merge soon Current implementation status: +- [x] optimized gate reduction!! Reusing gates as often as possible! See the awesome results below :) - [x] extended circuit code compiler - [x] move witness calculation outside the setup phase - [x] fixed hard bugs @@ -33,12 +34,17 @@ def mul(a,b): ``` R1CS Output: ```go -[[0 0 210 0 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0 0 0 0 0]] -[[0 0 210 0 0 0 0 0 0 0 0 0] [0 0 5 0 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0] [0 5 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 1 0 0 1 0 1 0]] -[[0 0 0 1 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 0 0 0 0 1]] +[[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [1 0 0 0 0 0 0 0 0 0]] +[[0 0 1 0 0 0 0 0 0 0] [0 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 9724050000 0 1 9724050000]] +[[0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 0 0 1] [0 0 0 1 0 0 0 0 0 0]] input [7 11] witness -[1 7 11 5336100 293485500 1566067976550000 2160900 75631500 163432108350000 49 343 1729500084900343] +[1 7 11 1729500084900343 121 1331 161051 49 343 16807] +another input +[365235 11876525] +witness +[1 365235 11876525 2297704271284150716235246193843898764109352875 141051846075625 1675205776213312203125 236290867291438012851239954111328125 133396605225 48721109109352875 6499230557984496821593771875] + ``` -Note that we only need 9 multiplication Gates instead of 16 +Note that we only need 7 multiplication Gates instead of 16. The 4th witness value is the programs output. Use python script to check correctness! diff --git a/circuitcompiler/Programm.go b/circuitcompiler/Programm.go index a8eab4e..b0a5b1a 100644 --- a/circuitcompiler/Programm.go +++ b/circuitcompiler/Programm.go @@ -21,6 +21,12 @@ type R1CS struct { B [][]*big.Int C [][]*big.Int } + +type MultiplicationGateSignature struct { + identifier string + commonExtracted [2]int //if the mgate had a extractable factor, it will be stored here +} + type Program struct { functions map[string]*Circuit globalInputs []string @@ -34,12 +40,12 @@ type Program struct { //this datastructure is nice but maybe ill replace it later with something less confusing //it serves the elementary purpose of not computing a variable a second time. //it boosts parse time - computedInContext map[string]map[string]string + computedInContext map[string]map[string]MultiplicationGateSignature //to reduce the number of multiplication gates, we store each factor signature, and the variable name, //so each time a variable is computed, that happens to have the very same factors, we reuse the former //it boost setup and proof time - computedFactors map[string]string + computedFactors map[string]MultiplicationGateSignature } //returns the cardinality of all main inputs + 1 for the "one" signal @@ -129,10 +135,10 @@ func (c *Circuit) buildTree(g *gate) { func (p *Program) ReduceCombinedTree() (orderedmGates []gate) { orderedmGates = []gate{} - p.computedInContext = make(map[string]map[string]string) - p.computedFactors = make(map[string]string) - rootHash := []byte{} - p.computedInContext[string(rootHash)] = make(map[string]string) + p.computedInContext = make(map[string]map[string]MultiplicationGateSignature) + p.computedFactors = make(map[string]MultiplicationGateSignature) + rootHash := make([]byte, 10) + p.computedInContext[string(rootHash)] = make(map[string]MultiplicationGateSignature) p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false) return orderedmGates } @@ -140,7 +146,7 @@ func (p *Program) ReduceCombinedTree() (orderedmGates []gate) { //recursively walks through the parse tree to create a list of all //multiplication gates needed for the QAP construction //Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so -func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) { +func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs factors, hashTraceResult []byte, variableEnd bool) { if node.OperationType() == CONST { b1, v1 := isValue(node.value.Out) @@ -152,7 +158,7 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr mul = [2]int{1, v1} } - return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false + return factors{{typ: CONST, negate: negate, multiplicative: mul}}, hashTraceBuildup, false } if node.OperationType() == FUNC { @@ -161,21 +167,19 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr node = nextContext.root hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName())) if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex { - p.computedInContext[string(hashTraceBuildup)] = make(map[string]string) + p.computedInContext[string(hashTraceBuildup)] = make(map[string]MultiplicationGateSignature) } } if node.OperationType() == IN { - fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}} - hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out)) - return []factor{fac}, hashTraceBuildup, true + fac := &factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}} + return factors{fac}, hashTraceBuildup, true } if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex { - fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}} - hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out)) - return []factor{fac}, hashTraceBuildup, true + fac := &factor{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: out.commonExtracted} + return factors{fac}, hashTraceBuildup, true } leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert) @@ -185,19 +189,25 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr if node.OperationType() == MULTIPLY { if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root { - //if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root { - return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons + return mulFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons + } - sig := factorsSignature(leftFactors, rightFactors) - if out, ex := p.computedFactors[sig]; ex { - return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true + sig, newLef, newRigh := factorsSignature(leftFactors, rightFactors) + if out, ex := p.computedFactors[sig.identifier]; ex { + return factors{{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true } rootGate := cloneGate(node) + //rootGate := node rootGate.index = len(*orderedmGates) - rootGate.leftIns = leftFactors - rootGate.rightIns = rightFactors + if p.getMainCircuit().root == node { + newLef = mulFactors(newLef, factors{&factor{typ: CONST, multiplicative: sig.commonExtracted}}) + } + + rootGate.leftIns = newLef + rootGate.rightIns = newRigh + out := hashTogether(leftHash, rightHash) rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10]) rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10]) @@ -208,183 +218,28 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr rootGate.value.Out = rootGate.value.Out + string(out[:10]) } - p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out + p.computedInContext[string(hashTraceBuildup)][node.value.Out] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted} - p.computedFactors[sig] = rootGate.value.Out + p.computedFactors[sig.identifier] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted} *orderedmGates = append(*orderedmGates, *rootGate) - hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out)) - - return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true + return factors{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true } switch node.OperationType() { case PLUS: - return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons + return addFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons default: panic("unexpected gate") } } -type factor struct { - typ Token - name string - invert, negate bool - multiplicative [2]int -} - -func (f factor) String() string { - if f.typ == CONST { - return fmt.Sprintf("(const fac: %v)", f.multiplicative) - } - str := f.name - if f.invert { - str += "^-1" - } - if f.negate { - str = "-" + str - } - return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative) -} - -func mul2DVector(a, b [2]int) [2]int { - return [2]int{a[0] * b[0], a[1] * b[1]} -} - -func factorsSignature(leftFactors, rightFactors []factor) string { - hasher.Reset() - //using a commutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a) - //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard - for _, facLeft := range leftFactors { - hasher.Write([]byte(facLeft.String())) - } - for _, Righ := range rightFactors { - hasher.Write([]byte(Righ.String())) - } - return string(hasher.Sum(nil))[:16] -} - -//multiplies factor elements and returns the result -//in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right -func mulFactors(leftFactors, rightFactors []factor) (result []factor) { - - for _, facLeft := range leftFactors { - - for i, facRight := range rightFactors { - if facLeft.typ == CONST && facRight.typ == IN { - rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} - continue - } - if facRight.typ == CONST && facLeft.typ == IN { - rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} - continue - } - - if facRight.typ&facLeft.typ == CONST { - rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} - continue - - } - //tricky part here - //this one should only be reached, after a true mgate had its left and right braches computed. here we - //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here. - if facRight.typ&facLeft.typ == IN { - if facLeft.name == facRight.name { - if facRight.invert != facLeft.invert { - rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} - continue - } - } - - //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} - //continue - - } - panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ") - - } - - } - - return rightFactors -} - -//returns the absolute value of a signed int and a flag telling if the input was positive or not -//this implementation is awesome and fast (see Henry S Warren, Hackers's Delight) -func abs(n int) (val int, positive bool) { - y := n >> 63 - return (n ^ y) - y, y == 0 -} - -//returns the reduced sum of two input factor arrays -//if no reduction was done (worst case), it returns the concatenation of the input arrays -func addFactors(leftFactors, rightFactors []factor) []factor { - var found bool - res := make([]factor, 0, len(leftFactors)+len(rightFactors)) - for _, facLeft := range leftFactors { - - found = false - for i, facRight := range rightFactors { - - if facLeft.typ&facRight.typ == CONST { - var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0] - if facLeft.negate { - a0 *= -1 - } - if facRight.negate { - b0 *= -1 - } - absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0) - - rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}} - - found = true - //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}) - break - } - if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name { - var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0] - if facLeft.negate { - a0 *= -1 - } - if facRight.negate { - b0 *= -1 - } - absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0) - - rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}} - - found = true - //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}) - break - } - } - if !found { - res = append(res, facLeft) - } - } - - for _, val := range rightFactors { - if val.multiplicative[0] != 0 { - res = append(res, val) - } - } - - return res -} - //copies a gate neglecting its references to other gates func cloneGate(in *gate) (out *gate) { constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1} - nRightins := make([]factor, len(in.rightIns)) - nLeftInst := make([]factor, len(in.leftIns)) - for k, v := range in.rightIns { - nRightins[k] = v - } - for k, v := range in.leftIns { - nLeftInst[k] = v - } + nRightins := in.rightIns.clone() + nLeftInst := in.leftIns.clone() return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index} } @@ -499,7 +354,7 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) { bConstraint := r1csqap.ArrayOfBigZeros(size) cConstraint := r1csqap.ArrayOfBigZeros(size) - insertValue := func(val factor, arr []*big.Int) { + insertValue := func(val *factor, arr []*big.Int) { if val.typ != CONST { if _, ex := indexMap[val.name]; !ex { panic(fmt.Sprintf("%v index not found!!!", val.name)) diff --git a/circuitcompiler/Programm_test.go b/circuitcompiler/Programm_test.go index 90bdf90..43f09f9 100644 --- a/circuitcompiler/Programm_test.go +++ b/circuitcompiler/Programm_test.go @@ -114,9 +114,26 @@ var correctnesTest = []TraceCorrectnessTest{ out = g * i `, }, + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(264)), + }}, + code: ` + def main(a,b,c,d): + e = a * 3 + f = b * 7 + g = c * 11 + h = d * 13 + i = e + f + j = g + h + k = i + j + out = k * 1 + `, + }, } -func TestNewProgramm(t *testing.T) { +func TestCorrectness(t *testing.T) { for _, test := range correctnesTest { parser := NewParser(strings.NewReader(test.code)) @@ -160,3 +177,139 @@ func TestNewProgramm(t *testing.T) { } } + +//test to check gate optimisation +//mess around the code s.t. results is unchanged. number of gates should remain the same in any case +func TestGateOptimisation(t *testing.T) { + + io := InOut{ + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + + result: bigNumberResult1, + } + + equalCodes := []string{ + ` + def main( x , z ) : + out = do(z) + add(x,x) + + def do(x): + e = x * 5 + b = e * 6 + c = 7 * b + f = c * 1 + d = f * c + out = d * mul(d,e) + + def add(x ,k): + z = k * x + out = do(x) + mul(x,z) + + + def mul(a,b): + out = b * a + `, //switching order + ` + def main( x , z ) : + out = do(z) + add(x,x) + + def do(x): + e = x * 5 + b = e * 6 + c = b * 7 + f = c * 1 + d = c * f + out = d * mul(d,e) + + def add(x ,k): + z = k * x + out = do(x) + mul(x,z) + + + def mul(a,b): + out = a * b + `, //switching order + ` + def main( x , z ) : + out = do(z) + add(x,x) + + def do(x): + e = x * 5 + j = e * 3 + k = e * 3 + b = j+k + c = b * 7 + f = c * 1 + d = c * f + g = d * 1 + out = g * mul(d,e) + + def add(k ,x): + z = k * x + out = do(x) + mul(x,z) + + + def mul(b,a): + out = a * b + `, ` + def main( x , z ) : + out = add(x,x)+do(z) + + def do(x): + e = x * 5 + j = 3 * e + k = e * 3 + b = j+k + c = b * 7 + f = c * 1 + d = c * f + g = d * 1 + out = mul(d,e) * g + + def add(k ,x): + z = k * x + out = mul(x,z) + do(x) + + + def mul(b,a): + out = a * b + `, + } + + var r1css = make([]R1CS, len(equalCodes)) + + for i, c := range equalCodes { + parser := NewParser(strings.NewReader(c)) + program, err := parser.Parse() + + if err != nil { + panic(err) + } + program.BuildConstraintTrees() + + gates := program.ReduceCombinedTree() + for _, g := range gates { + fmt.Printf("\n %v", g) + } + fmt.Println("\n generating R1CS") + r1cs := program.GenerateReducedR1CS(gates) + r1css[i] = r1cs + fmt.Println(r1cs.A) + fmt.Println(r1cs.B) + fmt.Println(r1cs.C) + } + + for i := 0; i < len(equalCodes)-1; i++ { + assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A)) + } + + for i := 0; i < len(equalCodes); i++ { + //assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A)) + w := CalculateWitness(io.inputs, r1css[i]) + fmt.Println("witness") + fmt.Println(w) + assert.Equal(t, io.result, w[3]) + + } + +} diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index d6e77ae..d804e61 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -27,8 +27,8 @@ type gate struct { right *gate funcInputs []*gate value *Constraint //is a pointer a good thing here?? - leftIns []factor //leftIns and RightIns after addition gates have been reduced. only multiplication gates remain - rightIns []factor + leftIns factors //leftIns and RightIns after addition gates have been reduced. only multiplication gates remain + rightIns factors } func (g gate) String() string { diff --git a/circuitcompiler/factorHandling.go b/circuitcompiler/factorHandling.go new file mode 100644 index 0000000..6e12ed1 --- /dev/null +++ b/circuitcompiler/factorHandling.go @@ -0,0 +1,300 @@ +package circuitcompiler + +import ( + "fmt" + "math/big" + "sort" + "strings" +) + +type factors []*factor + +type factor struct { + typ Token + name string + invert, negate bool + multiplicative [2]int +} + +func (f factors) Len() int { + return len(f) +} + +func (f factors) Swap(i, j int) { + f[i], f[j] = f[j], f[i] +} + +func (f factors) Less(i, j int) bool { + if strings.Compare(f[i].String(), f[j].String()) < 0 { + return false + } + return true +} + +func (f factor) String() string { + if f.typ == CONST { + return fmt.Sprintf("(const fac: %v)", f.multiplicative) + } + str := f.name + if f.invert { + str += "^-1" + } + if f.negate { + str = "-" + str + } + return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative) +} + +func (f factors) clone() (res factors) { + res = make(factors, len(f)) + for k, v := range f { + res[k] = &factor{multiplicative: v.multiplicative, typ: v.typ, name: v.name, invert: v.invert, negate: v.negate} + } + return +} + +func (f factors) normalizeAll() { + for i, _ := range f { + f[i].multiplicative = normalizeFactor(f[i].multiplicative) + } +} + +// find Least Common Multiple (LCM) via GCD +func LCMsmall(a, b int) int { + result := a * b / GCD(a, b) + + return result +} + +func extractFactor(f factors) (factors, [2]int) { + + lcm := f[0].multiplicative[1] + + for i := 1; i < len(f); i++ { + lcm = LCMsmall(f[i].multiplicative[1], lcm) + } + + for i := 0; i < len(f); i++ { + f[i].multiplicative[0] = (lcm / f[i].multiplicative[1]) * f[i].multiplicative[0] + } + + gcd := f[0].multiplicative[0] + for i := 1; i < len(f); i++ { + gcd = GCD(f[i].multiplicative[0], gcd) + } + for i := 0; i < len(f); i++ { + f[i].multiplicative[0] = f[i].multiplicative[0] / gcd + f[i].multiplicative[1] = 1 + } + + return f, [2]int{gcd, lcm} + +} + +func factorsSignature(leftFactors, rightFactors factors) (sig MultiplicationGateSignature, extractedLeftFactors, extractedRightFactors factors) { + leftFactors = leftFactors.clone() + rightFactors = rightFactors.clone() + + leftFactors.normalizeAll() + var extractedFacLeft [2]int + leftFactors, extractedFacLeft = extractFactor(leftFactors) + + sort.Sort(leftFactors) + hasher.Reset() + for _, fac := range leftFactors { + hasher.Write([]byte(fac.String())) + } + leftNum := new(big.Int).SetBytes(hasher.Sum(nil)) + + rightFactors.normalizeAll() + + var extractedFacRight [2]int + rightFactors, extractedFacRight = extractFactor(rightFactors) + sort.Sort(rightFactors) + hasher.Reset() + + for _, fac := range rightFactors { + hasher.Write([]byte(fac.String())) + } + rightNum := new(big.Int).SetBytes(hasher.Sum(nil)) + + //we did all this, because multiplication is commutativ, and we want the signature of a + //mulitplication gate factorsSignature(a,b) == factorsSignature(b,a) + leftNum.Add(leftNum, rightNum) + + res := normalizeFactor(mul2DVector(extractedFacLeft, extractedFacRight)) + + return MultiplicationGateSignature{identifier: leftNum.String()[:16], commonExtracted: res}, leftFactors, rightFactors +} + +func lengthOfLongestSlice(a, b factors) int { + if len(a) > len(b) { + return len(a) + } + return len(b) +} + +//multiplies factor elements and returns the result +//in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right +func mulFactors(leftFactors, rightFactors factors) (result factors) { + + if len(leftFactors) < len(rightFactors) { + tmp := leftFactors + leftFactors = rightFactors + rightFactors = tmp + } + + for i, left := range leftFactors { + + for _, right := range rightFactors { + + if left.typ == CONST && right.typ == IN { + leftFactors[i] = &factor{typ: IN, name: right.name, negate: Xor(left.negate, right.negate), invert: right.invert, multiplicative: mul2DVector(right.multiplicative, left.multiplicative)} + continue + } + + if right.typ == CONST && left.typ == IN { + leftFactors[i] = &factor{typ: IN, name: left.name, negate: Xor(left.negate, right.negate), invert: left.invert, multiplicative: mul2DVector(right.multiplicative, left.multiplicative)} + continue + } + + if right.typ&left.typ == CONST { + leftFactors[i] = &factor{typ: CONST, negate: Xor(right.negate, left.negate), multiplicative: mul2DVector(right.multiplicative, left.multiplicative)} + continue + + } + //tricky part here + //this one should only be reached, after a true mgate had its left and right braches computed. here we + //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here. + if right.typ&left.typ == IN { + if left.name == right.name { + if right.invert != left.invert { + leftFactors[i] = &factor{typ: CONST, negate: Xor(right.negate, left.negate), multiplicative: mul2DVector(right.multiplicative, left.multiplicative)} + continue + } + } + + //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)} + //continue + + } + panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ") + + } + + } + + return leftFactors +} + +//returns the absolute value of a signed int and a flag telling if the input was positive or not +//this implementation is awesome and fast (see Henry S Warren, Hackers's Delight) +func abs(n int) (val int, positive bool) { + y := n >> 63 + return (n ^ y) - y, y == 0 +} + +//adds two factors to one iff they are both are constants or of the same variable +func addFactor(facLeft, facRight factor) (couldAdd bool, sum factor) { + if facLeft.typ&facRight.typ == CONST { + var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0] + if facLeft.negate { + a0 *= -1 + } + if facRight.negate { + b0 *= -1 + } + absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0) + + return true, factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}} + + } + if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name { + var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0] + if facLeft.negate { + a0 *= -1 + } + if facRight.negate { + b0 *= -1 + } + absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0) + + return true, factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}} + + } + return false, factor{} + +} + +//returns the reduced sum of two input factor arrays +//if no reduction was done (worst case), it returns the concatenation of the input arrays +func addFactors(leftFactors, rightFactors factors) factors { + var found bool + res := make(factors, 0, len(leftFactors)+len(rightFactors)) + for _, facLeft := range leftFactors { + + found = false + for i, facRight := range rightFactors { + + var sum factor + found, sum = addFactor(*facLeft, *facRight) + + if found { + rightFactors[i] = &sum + break + } + + } + if !found { + res = append(res, facLeft) + } + } + + for _, val := range rightFactors { + if val.multiplicative[0] != 0 { + res = append(res, val) + } + } + + return res +} + +// -4/-5 -> 4/5 ; 5/-7 -> -5/7 ; 6 /2 -> 3 / 1 +func normalizeFactor(b [2]int) [2]int { + resa, signa := abs(b[0]) + resb, signb := abs(b[1]) + + gcd := GCD(resa, resb) + + if Xor(signa, signb) { + resa = -resa + } + return [2]int{resa / gcd, resb / gcd} +} + +//naive component multiplication +func mul2DVector(a, b [2]int) [2]int { + + return [2]int{a[0] * b[0], a[1] * b[1]} +} + +// find Least Common Multiple (LCM) via GCD +func LCM(a, b int, integers ...int) int { + result := a * b / GCD(a, b) + + for i := 0; i < len(integers); i++ { + result = LCM(result, integers[i]) + } + + return result +} + +//euclidean algo to determine greates common divisor +func GCD(a, b int) int { + for b != 0 { + t := b + b = a % b + a = t + } + return a +} diff --git a/circuitcompiler/factorHandling_test.go b/circuitcompiler/factorHandling_test.go new file mode 100644 index 0000000..ae3f8f1 --- /dev/null +++ b/circuitcompiler/factorHandling_test.go @@ -0,0 +1,205 @@ +package circuitcompiler + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "math/big" + "math/rand" + "strings" + "testing" +) + +//factors are essential to identify, if a specific gate has been computed already +//eg. if we can extract a factor from a gate that is independent of commutativity, multiplicativitz we will do much better, in finding and reusing old outputs do +//minimize the multiplication gate number +// for example the gate a*b == gate b*a hence, we only need to compute one of both. + +func TestFactorSignature(t *testing.T) { + facNeutral := factors{&factor{multiplicative: [2]int{1, 1}}} + + //dont let the random number be to big, cuz of overflow + r1, r2 := rand.Intn(1<<16), rand.Intn(1<<16) + fmt.Println(r1, r2) + equalityGroups := [][]factors{ + []factors{ //test sign and gcd + {&factor{multiplicative: [2]int{r1 * 2, -r2 * 2}}}, + {&factor{multiplicative: [2]int{-r1, r2}}}, + {&factor{multiplicative: [2]int{r1, -r2}}}, + {&factor{multiplicative: [2]int{r1 * 3, -r2 * 3}}}, + {&factor{multiplicative: [2]int{r1 * r1, -r2 * r1}}}, + {&factor{multiplicative: [2]int{r1 * r2, -r2 * r2}}}, + }, []factors{ //test kommutativity + {&factor{multiplicative: [2]int{r1, -r2}}, &factor{multiplicative: [2]int{13, 27}}}, + {&factor{multiplicative: [2]int{13, 27}}, &factor{multiplicative: [2]int{-r1, r2}}}, + }, + } + + for _, equalityGroup := range equalityGroups { + for i := 0; i < len(equalityGroup)-1; i++ { + sig1, _, _ := factorsSignature(facNeutral, equalityGroup[i]) + sig2, _, _ := factorsSignature(facNeutral, equalityGroup[i+1]) + assert.Equal(t, sig1, sig2) + sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral) + sig2, _, _ = factorsSignature(facNeutral, equalityGroup[i+1]) + assert.Equal(t, sig1, sig2) + + sig1, _, _ = factorsSignature(facNeutral, equalityGroup[i]) + sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral) + assert.Equal(t, sig1, sig2) + + sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral) + sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral) + assert.Equal(t, sig1, sig2) + } + } + +} + +func TestGate_ExtractValues(t *testing.T) { + facNeutral := factors{&factor{multiplicative: [2]int{8, 7}}, &factor{multiplicative: [2]int{9, 3}}} + facNeutral2 := factors{&factor{multiplicative: [2]int{9, 1}}, &factor{multiplicative: [2]int{13, 7}}} + fmt.Println(factorsSignature(facNeutral, facNeutral2)) + f, fc := extractFactor(facNeutral) + fmt.Println(f) + fmt.Println(fc) + + f2, _ := extractFactor(facNeutral2) + fmt.Println(f) + fmt.Println(fc) + fmt.Println(factorsSignature(facNeutral, facNeutral2)) + fmt.Println(factorsSignature(f, f2)) +} + +func TestGCD(t *testing.T) { + fmt.Println(LCM(10, 15)) + fmt.Println(LCM(10, 15, 20)) + fmt.Println(LCM(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) +} + +var correctnesTest2 = []TraceCorrectnessTest{ + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))}, + result: big.NewInt(int64(98441327276)), + }, { + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + result: big.NewInt(int64(8675445947220)), + }}, + code: ` + def main(a,b): + c = a + b + e = c - a + f = e + b + g = f + 2 + out = g * a + `, + }, + {io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(7))}, + result: big.NewInt(int64(4)), + }}, + code: ` + def mul(a,b): + out = a * b + + def main(a): + b = a * a + c = 4 - b + d = 5 * c + out = mul(d,c) / mul(b,b) + `, + }, + {io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(22638)), + }, { + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + result: bigNumberResult2, + }}, + code: ` + def main(a,b): + d = b + b + c = a * d + e = c - a + out = e * c + `, + }, + + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(444675)), + }}, + code: ` + def main(a,b,c,d): + e = a * b + f = c * d + g = e * f + h = g / e + i = h * 5 + out = g * i + `, + }, + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(264)), + }}, + code: ` + def main(a,b,c,d): + e = a * 3 + f = b * 7 + g = c * 11 + h = d * 13 + i = e + f + j = g + h + k = i + j + out = k * 1 + `, + }, +} + +func TestCorrectness2(t *testing.T) { + + for _, test := range correctnesTest2 { + parser := NewParser(strings.NewReader(test.code)) + program, err := parser.Parse() + + if err != nil { + panic(err) + } + fmt.Println("\n unreduced") + fmt.Println(test.code) + + program.BuildConstraintTrees() + for k, v := range program.functions { + fmt.Println(k) + PrintTree(v.root) + } + + fmt.Println("\nReduced gates") + //PrintTree(froots["mul"]) + gates := program.ReduceCombinedTree() + for _, g := range gates { + fmt.Printf("\n %v", g) + } + + fmt.Println("\n generating R1CS") + r1cs := program.GenerateReducedR1CS(gates) + fmt.Println(r1cs.A) + fmt.Println(r1cs.B) + fmt.Println(r1cs.C) + + for _, io := range test.io { + inputs := io.inputs + fmt.Println("input") + fmt.Println(inputs) + w := CalculateWitness(inputs, r1cs) + fmt.Println("witness") + fmt.Println(w) + assert.Equal(t, io.result, w[program.GlobalInputCount()]) + } + + } + +}