Browse Source

reducing multiplication gates with constants

pull/8/head
mottla 5 years ago
parent
commit
de0e081494
6 changed files with 130 additions and 70 deletions
  1. +1
    -1
      README.md
  2. +102
    -13
      circuitcompiler/Programm.go
  3. +11
    -10
      circuitcompiler/Programm_test.go
  4. +3
    -30
      circuitcompiler/circuit.go
  5. +4
    -0
      r1csqap/r1csqap.go
  6. +9
    -16
      snark_test.go

+ 1
- 1
README.md

@ -11,7 +11,7 @@ UNDER CONSTRUCTION!
Implementation of the zkSNARK [Pinocchio protocol](https://eprint.iacr.org/2013/279.pdf) from scratch in Go to understand the concepts. Do not use in production. Implementation of the zkSNARK [Pinocchio protocol](https://eprint.iacr.org/2013/279.pdf) from scratch in Go to understand the concepts. Do not use in production.
This forked aims to extend its functionalities s.t. one can prove set-membership in zero knowledge.
This fork aims to extend its functionalities s.t. one can prove set-membership in zero knowledge.
Current implementation status: Current implementation status:
- [x] Finite Fields (1, 2, 6, 12) operations - [x] Finite Fields (1, 2, 6, 12) operations

+ 102
- 13
circuitcompiler/Programm.go

@ -28,7 +28,7 @@ func (p *Program) BuildConstraintTrees() {
functionRootMap := make(map[string]*gate) functionRootMap := make(map[string]*gate)
for _, circuit := range p.functions { for _, circuit := range p.functions {
circuit.addConstraint(p.oneConstraint())
//circuit.addConstraint(p.oneConstraint())
fName := composeNewFunction(circuit.Name, circuit.Inputs) fName := composeNewFunction(circuit.Name, circuit.Inputs)
root := &gate{value: circuit.constraintMap[fName]} root := &gate{value: circuit.constraintMap[fName]}
functionRootMap[fName] = root functionRootMap[fName] = root
@ -123,19 +123,20 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool,
//if root == nil { //if root == nil {
// return // return
//} //}
//fmt.Printf("\n%p",mGatesUsed)
if root.OperationType() == FUNC { if root.OperationType() == FUNC {
//if a input has already been built, we let this subroutine know //if a input has already been built, we let this subroutine know
newMap := make(map[string]bool)
//newMap := make(map[string]bool)
for _, in := range root.funcInputs { for _, in := range root.funcInputs {
if _, ex := mGatesUsed[in.value.Out]; ex { if _, ex := mGatesUsed[in.value.Out]; ex {
newMap[in.value.Out] = true
//newMap[in.value.Out] = true
} else { } else {
traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse) traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
} }
} }
//mGatesUsed[root.value.Out] = true //mGatesUsed[root.value.Out] = true
traverseCombinedMultiplicationGates(functionRenamer(root.value), newMap, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
traverseCombinedMultiplicationGates(functionRenamer(root.value), mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
} else { } else {
if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 { if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse) traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
@ -148,12 +149,24 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool,
if root.OperationType() == MULTIPLY { if root.OperationType() == MULTIPLY {
_, n, _ := isFunction(root.value.Out)
if (root.left.OperationType()|root.right.OperationType())&CONST != 0 && n != "main" {
return
}
root.leftIns = make(map[string]int) root.leftIns = make(map[string]int)
collectAtomsInSubtree(root.left, root.leftIns, functionRootMap, negate, inverse)
collectAtomsInSubtree(root.left, mGatesUsed, 1, root.leftIns, functionRootMap, negate, inverse)
root.rightIns = make(map[string]int) root.rightIns = make(map[string]int)
collectAtomsInSubtree(root.right, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
//if root.left.value.Out== root.right.value.Out{
// //note this is not a full copy, but shouldnt be a problem
// root.rightIns= root.leftIns
//}else{
// collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
//}
collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
root.index = len(mGatesUsed) root.index = len(mGatesUsed)
mGatesUsed[root.value.Out] = true mGatesUsed[root.value.Out] = true
rootGate := cloneGate(root) rootGate := cloneGate(root)
*orderedmGates = append(*orderedmGates, *rootGate) *orderedmGates = append(*orderedmGates, *rootGate)
} }
@ -161,6 +174,63 @@ func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool,
//TODO optimize if output is not a multipication gate //TODO optimize if output is not a multipication gate
} }
func collectAtomsInSubtree(g *gate, mGatesUsed map[string]bool, multiplicative int, in map[string]int, functionRootMap map[string]*gate, negate bool, invert bool) {
if g == nil {
return
}
if _, ex := mGatesUsed[g.value.Out]; ex {
addToMap(g.value.Out, multiplicative, in, negate)
return
}
if g.OperationType()&(IN|CONST) != 0 {
addToMap(g.value.Out, multiplicative, in, negate)
return
}
if g.OperationType()&(MULTIPLY) != 0 {
b1, v1 := isValue(g.value.V1)
b2, v2 := isValue(g.value.V2)
if b1 && !b2 {
multiplicative *= v1
collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert)
return
} else if !b1 && b2 {
multiplicative *= v2
collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
return
} else if b1 && b2 {
panic("multiply constants not supported yet")
} else {
panic("werird")
}
}
if g.OperationType() == FUNC {
if b, name, _ := isFunction(g.value.Out); b {
collectAtomsInSubtree(functionRootMap[name], mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
} else {
panic("function expected")
}
}
collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert)
}
func addOneToMap(value string, in map[string]int, negate bool) {
addToMap(value, 1, in, negate)
}
func addToMap(value string, val int, in map[string]int, negate bool) {
if negate {
in[value] = (in[value] - 1) * val
} else {
in[value] = (in[value] + 1) * val
}
}
//copies a gate neglecting its references to other gates //copies a gate neglecting its references to other gates
func cloneGate(in *gate) (out *gate) { func cloneGate(in *gate) (out *gate) {
constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1} constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
@ -184,15 +254,16 @@ func (p *Program) addGlobalInput(c *Constraint) {
} }
func NewProgramm() *Program { func NewProgramm() *Program {
return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: CONST, Out: "one"}}}
//return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: PLUS, V1:"1",V2:"0", Out: "one"}}}
return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: IN, Out: "one"}}}
} }
func (p *Program) oneConstraint() *Constraint {
if p.globalInputs[0].Out != "one" {
panic("'one' should be first global input")
}
return p.globalInputs[0]
}
//func (p *Program) oneConstraint() *Constraint {
// if p.globalInputs[0].Out != "one" {
// panic("'one' should be first global input")
// }
// return p.globalInputs[0]
//}
func (p *Program) addSignal(name string) { func (p *Program) addSignal(name string) {
p.signals = append(p.signals, name) p.signals = append(p.signals, name)
@ -262,7 +333,25 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
bConstraint := r1csqap.ArrayOfBigZeros(size) bConstraint := r1csqap.ArrayOfBigZeros(size)
cConstraint := r1csqap.ArrayOfBigZeros(size) cConstraint := r1csqap.ArrayOfBigZeros(size)
//if len(gate.leftIns)>=len(gate.rightIns){
// for leftInput, _ := range gate.leftIns {
// if v, ex := gate.rightIns[leftInput]; ex {
// gate.leftIns[leftInput] *= v
// gate.rightIns[leftInput] = 1
//
// }
// }
//}else{
// for rightInput, _ := range gate.rightIns {
// if v, ex := gate.leftIns[rightInput]; ex {
// gate.rightIns[rightInput] *= v
// gate.leftIns[rightInput] = 1
// }
// }
//}
for leftInput, val := range gate.leftIns { for leftInput, val := range gate.leftIns {
insertVar3(aConstraint, val, leftInput, indexMap[leftInput]) insertVar3(aConstraint, val, leftInput, indexMap[leftInput])
} }
for rightInput, val := range gate.rightIns { for rightInput, val := range gate.rightIns {

+ 11
- 10
circuitcompiler/Programm_test.go

@ -17,29 +17,30 @@ func TestNewProgramm(t *testing.T) {
flat := ` flat := `
func do(x): func do(x):
e = x * x
b = e * e
c = b * b
d = c * c
e = x * 5
b = e * 6
c = b * 7
f = c * 1
d = c / f
out = d * 1 out = d * 1
func add(x ,k): func add(x ,k):
z = k * x z = k * x
out = do(x) + mul(x,z) out = do(x) + mul(x,z)
func main(a,b):
out = do(5) + 4
func main(x,z):
out = do(x) + 4
func mul(a,b): func mul(a,b):
out = a * b out = a * b
` `
//flat := ` //flat := `
//func do(x):
// b = x - 2
// out = x * b
//func main(a,b): //func main(a,b):
// out = do(a) + 4
// e = 4 * a
// c = a * e
// d = c * 70
// out = a * d
//` //`
parser := NewParser(strings.NewReader(flat)) parser := NewParser(strings.NewReader(flat))
program, err := parser.Parse() program, err := parser.Parse()

+ 3
- 30
circuitcompiler/circuit.go

@ -109,13 +109,15 @@ func (circ *Circuit) addConstraint(constraint *Constraint) {
//the main functions output must be a multiplication gate //the main functions output must be a multiplication gate
//if its not, then we simple create one where outNew = 1 * outOld //if its not, then we simple create one where outNew = 1 * outOld
if constraint.Op&(MINUS|PLUS) != 0 { if constraint.Op&(MINUS|PLUS) != 0 {
newOut := &Constraint{Out: constraint.Out, V1: "one", V2: "out2", Op: MULTIPLY}
newOut := &Constraint{Out: constraint.Out, V1: "1", V2: "out2", Op: MULTIPLY}
//TODO reachable?
delete(circ.constraintMap, constraint.Out) delete(circ.constraintMap, constraint.Out)
circ.addConstraint(newOut) circ.addConstraint(newOut)
constraint.Out = "out2" constraint.Out = "out2"
circ.addConstraint(constraint) circ.addConstraint(constraint)
} }
} }
} }
addConstantsAndFunctions := func(constraint string) { addConstantsAndFunctions := func(constraint string) {
@ -274,35 +276,6 @@ func printTree(g *gate, d int) {
} }
} }
func addToMap(value string, in map[string]int, negate bool) {
if negate {
in[value] = in[value] - 1
} else {
in[value] = in[value] + 1
}
}
func collectAtomsInSubtree(g *gate, in map[string]int, functionRootMap map[string]*gate, negate bool, invert bool) {
if g == nil {
return
}
if g.OperationType()&(MULTIPLY|IN|CONST) != 0 {
addToMap(g.value.Out, in, negate)
return
}
if g.OperationType() == FUNC {
if b, name, _ := isFunction(g.value.Out); b {
collectAtomsInSubtree(functionRootMap[name], in, functionRootMap, negate, invert)
} else {
panic("function expected")
}
}
collectAtomsInSubtree(g.left, in, functionRootMap, negate, invert)
collectAtomsInSubtree(g.right, in, functionRootMap, Xor(negate, g.value.negate), invert)
}
func Xor(a, b bool) bool { func Xor(a, b bool) bool {
return (a && !b) || (!a && b) return (a && !b) || (!a && b)
} }

+ 4
- 0
r1csqap/r1csqap.go

@ -2,6 +2,7 @@ package r1csqap
import ( import (
"bytes" "bytes"
"fmt"
"math/big" "math/big"
"github.com/mottla/go-snark/fields" "github.com/mottla/go-snark/fields"
@ -162,6 +163,9 @@ func (pf PolynomialField) R1CSToQAP(a, b, c [][]*big.Int) ([][]*big.Int, [][]*bi
aT := Transpose(a) aT := Transpose(a)
bT := Transpose(b) bT := Transpose(b)
cT := Transpose(c) cT := Transpose(c)
fmt.Println(aT)
fmt.Println(bT)
fmt.Println(cT)
var alphas [][]*big.Int var alphas [][]*big.Int
for i := 0; i < len(aT); i++ { for i := 0; i < len(aT); i++ {
alphas = append(alphas, pf.LagrangeInterpolation(aT[i])) alphas = append(alphas, pf.LagrangeInterpolation(aT[i]))

+ 9
- 16
snark_test.go

@ -3,7 +3,6 @@ package snark
import ( import (
"fmt" "fmt"
"github.com/mottla/go-snark/circuitcompiler" "github.com/mottla/go-snark/circuitcompiler"
"github.com/mottla/go-snark/r1csqap"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"math/big" "math/big"
"strings" "strings"
@ -13,16 +12,10 @@ import (
func TestNewProgramm(t *testing.T) { func TestNewProgramm(t *testing.T) {
flat := ` flat := `
func add(x ,k):
z = k * x
out = x + mul(x,z)
func main(a,b):
out = add(a,b) * a
func mul(a,b):
out = a * b
func main(a,b,c,d):
e = a + b
f = c * d
out = e * f
` `
parser := circuitcompiler.NewParser(strings.NewReader(flat)) parser := circuitcompiler.NewParser(strings.NewReader(flat))
@ -35,7 +28,7 @@ func TestNewProgramm(t *testing.T) {
fmt.Println(flat) fmt.Println(flat)
program.BuildConstraintTrees() program.BuildConstraintTrees()
program.PrintConstraintTrees()
program.PrintContraintTrees()
fmt.Println("\nReduced gates") fmt.Println("\nReduced gates")
//PrintTree(froots["mul"]) //PrintTree(froots["mul"])
gates := program.ReduceCombinedTree() gates := program.ReduceCombinedTree()
@ -50,7 +43,7 @@ func TestNewProgramm(t *testing.T) {
fmt.Println(c) fmt.Println(c)
a1 := big.NewInt(int64(6)) a1 := big.NewInt(int64(6))
a2 := big.NewInt(int64(5)) a2 := big.NewInt(int64(5))
inputs := []*big.Int{a1, a2}
inputs := []*big.Int{a1, a2, a1, a2}
w := program.CalculateWitness(inputs) w := program.CalculateWitness(inputs)
fmt.Println("witness") fmt.Println("witness")
fmt.Println(w) fmt.Println(w)
@ -82,9 +75,9 @@ func TestNewProgramm(t *testing.T) {
hzQAP := Utils.PF.Mul(hxQAP, zxQAP) hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
assert.Equal(t, abc, hzQAP) assert.Equal(t, abc, hzQAP)
div, rem := Utils.PF.Div(px, zxQAP)
assert.Equal(t, hxQAP, div)
assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
//div, rem := Utils.PF.Div(px, zxQAP)
//assert.Equal(t, hxQAP, div)
//assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
// calculate trusted setup // calculate trusted setup
//setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas) //setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)

Loading…
Cancel
Save