Browse Source

multiplication gate reduction via reusing old gates with equivalent behaviour

multiplication gate reduction via reusing old gates with equivalent behaviour
pull/8/head
mottla 5 years ago
parent
commit
2fc30d452a
5 changed files with 202 additions and 124 deletions
  1. +79
    -67
      circuitcompiler/Programm.go
  2. +17
    -17
      circuitcompiler/Programm_test.go
  3. +68
    -0
      circuitcompiler/test.py
  4. +4
    -5
      snark.go
  5. +34
    -35
      snark_test.go

+ 79
- 67
circuitcompiler/Programm.go

@ -17,17 +17,28 @@ type utils struct {
PF r1csqap.PolynomialField
}
type R1CS struct {
A [][]*big.Int
B [][]*big.Int
C [][]*big.Int
}
type Program struct {
functions map[string]*Circuit
globalInputs []string
arithmeticEnvironment utils //find a better name
sha256Hasher hash.Hash
computedInContext map[string]map[string]string
R1CS struct {
A [][]*big.Int
B [][]*big.Int
C [][]*big.Int
}
//key 1: the hash chain indicating from where the variable is called H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function.
//value 1 : map
// with key variable name
// with value variable name + hash Chain
//this datastructure is nice but maybe ill replace it later with something less confusing
//it serves the elementary purpose of not computing a variable a second time
computedInContext map[string]map[string]string
//to reduce the number of multiplication gates, we store each factor signature, and the variable name,
//so each time a variable is computed, that happens to have the very same factors, we reuse the former gate
computedFactors map[string]string
}
func (p *Program) PrintContraintTrees() {
@ -48,9 +59,6 @@ func (p *Program) BuildConstraintTrees() {
p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
}
//for _, in := range p.getMainCircuit().Inputs {
// p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs))
//}
for _, in := range p.getMainCircuit().Inputs {
p.globalInputs = append(p.globalInputs, in)
}
@ -58,10 +66,11 @@ func (p *Program) BuildConstraintTrees() {
for _, circuit := range p.functions {
wg.Add(1)
func() {
circuit.buildTree(circuit.root)
//interesting: if circuit is not passed as argument, the program fails. duno why..
go func(c *Circuit) {
c.buildTree(c.root)
wg.Done()
}()
}(circuit)
}
wg.Wait()
@ -109,12 +118,17 @@ func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
//mGatesUsed := make(map[string]bool)
orderedmGates = []gate{}
p.computedInContext = make(map[string]map[string]string)
p.computedFactors = make(map[string]string)
rootHash := []byte{}
p.computedInContext[string(rootHash)] = make(map[string]string)
p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
return orderedmGates
}
//recursively walks through the parse tree to create a list of all
//multiplication gates needed for the QAP construction
//Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so
//TODO if the same variable that has been computed in context A, is needed again but from a different context b, will be recomputed and not reused
func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
if node.OperationType() == CONST {
@ -163,6 +177,12 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
//if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
}
sig := factorsSignature(leftFactors, rightFactors)
if out, ex := p.computedFactors[sig]; ex {
return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
}
rootGate := cloneGate(node)
rootGate.index = len(*orderedmGates)
rootGate.leftIns = leftFactors
@ -172,6 +192,7 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
rootGate.value.Out = rootGate.value.Out + string(out[:10])
p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
p.computedFactors[sig] = rootGate.value.Out
*orderedmGates = append(*orderedmGates, *rootGate)
hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
@ -186,7 +207,6 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
panic("unexpected gate")
}
//TODO optimize if output is not a multipication gate
}
type factor struct {
@ -214,6 +234,19 @@ func mul2DVector(a, b [2]int) [2]int {
return [2]int{a[0] * b[0], a[1] * b[1]}
}
func factorsSignature(leftFactors, rightFactors []factor) string {
hasher.Reset()
//not using a kommutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
//could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
for _, facLeft := range leftFactors {
hasher.Write([]byte(facLeft.String()))
}
for _, Righ := range rightFactors {
hasher.Write([]byte(Righ.String()))
}
return string(hasher.Sum(nil))[:16]
}
func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
for _, facLeft := range leftFactors {
@ -248,7 +281,6 @@ func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
//continue
}
fmt.Println("dsf")
panic("unexpected")
}
@ -340,11 +372,6 @@ func (p *Program) getMainCircuit() *Circuit {
return p.functions["main"]
}
//func (p *Program) addGlobalInput(c Constraint) {
// c.Out = "main@" + c.Out
// p.globalInputs = append(p.globalInputs, c)
//}
func prepareUtils() utils {
bn, err := bn128.NewBn128()
if err != nil {
@ -424,7 +451,7 @@ func NewProgram() (p *Program) {
}
// GenerateR1CS generates the R1CS polynomials from the Circuit
func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
// from flat code to R1CS
offset := len(p.globalInputs)
@ -440,51 +467,52 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
indexMap[v.value.Out] = i + offset
}
for _, gate := range mGates {
for _, g := range mGates {
if gate.OperationType() == MULTIPLY {
if g.OperationType() == MULTIPLY {
aConstraint := r1csqap.ArrayOfBigZeros(size)
bConstraint := r1csqap.ArrayOfBigZeros(size)
cConstraint := r1csqap.ArrayOfBigZeros(size)
for _, val := range gate.leftIns {
insertValue := func(val factor, arr []*big.Int) {
if val.typ != CONST {
if _, ex := indexMap[val.name]; !ex {
panic(fmt.Sprintf("%v index not found!!!", val.name))
}
}
convertAndInsertFactorAt(aConstraint, val, indexMap[val.name])
value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
if val.negate {
value.Neg(value)
}
//not that index is 0 if its a constant, since 0 is the map default if no entry was found
arr[indexMap[val.name]] = value
}
for _, val := range gate.rightIns {
if val.typ != CONST {
if _, ex := indexMap[val.name]; !ex {
panic(fmt.Sprintf("%v index not found!!!", val.name))
}
}
for _, val := range g.leftIns {
insertValue(val, aConstraint)
}
convertAndInsertFactorAt(bConstraint, val, indexMap[val.name])
for _, val := range g.rightIns {
insertValue(val, bConstraint)
}
cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
if gate.value.invert {
if g.value.invert {
tmp := aConstraint
aConstraint = cConstraint
cConstraint = tmp
}
a = append(a, aConstraint)
b = append(b, bConstraint)
c = append(c, cConstraint)
r1CS.A = append(r1CS.A, aConstraint)
r1CS.B = append(r1CS.B, bConstraint)
r1CS.C = append(r1CS.C, cConstraint)
} else {
panic("not a m gate")
}
}
p.R1CS.A = a
p.R1CS.B = b
p.R1CS.C = c
return a, b, c
return
}
var Utils = prepareUtils()
@ -494,25 +522,15 @@ func fractionToField(in [2]int) *big.Int {
}
func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) {
value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
if val.negate {
value.Neg(value)
}
//not that index is 0 if its a constant, since 0 is the map default if no entry was found
arr[index] = value
//Calculates the witness (program trace) given some input
//asserts that R1CS has been computed and is stored in the program p memory calling this function
func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
}
func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
if len(p.globalInputs)-1 != len(input) {
panic("input do not match the required inputs")
}
//if len(p.globalInputs)-1 != len(input) {
// panic("input do not match the required inputs")
//}
witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
set := make([]bool, len(witness))
witness[0] = big.NewInt(int64(1))
set[0] = true
@ -524,10 +542,10 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
zero := big.NewInt(int64(0))
for i := 0; i < len(p.R1CS.A); i++ {
gatesLeftInputs := p.R1CS.A[i]
gatesRightInputs := p.R1CS.B[i]
gatesOutputs := p.R1CS.C[i]
for i := 0; i < len(r1cs.A); i++ {
gatesLeftInputs := r1cs.A[i]
gatesRightInputs := r1cs.B[i]
gatesOutputs := r1cs.C[i]
sumLeft := big.NewInt(int64(0))
sumRight := big.NewInt(int64(0))
@ -583,12 +601,6 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
var hasher = sha256.New()
func hashFactorWithContext(f factor, currentCircuit *Circuit) []byte {
hasher.Reset()
hasher.Write([]byte(f.name))
hasher.Write([]byte(currentCircuit.currentOutputName()))
return hasher.Sum(nil)
}
func hashTogether(a, b []byte) []byte {
hasher.Reset()
hasher.Write(a)

+ 17
- 17
circuitcompiler/Programm_test.go

@ -101,18 +101,18 @@ var correctnesTest = []TraceCorrectnessTest{
}
func TestNewProgramm(t *testing.T) {
flat := `
def doSomething(x ,k):
z = k * x
out = 6 + mul(x,z)
def main(x,z):
out = mul(x,z) - doSomething(x,x)
def mul(a,b):
out = a * b
`
//flat := `
//
//def doSomething(x ,k):
// z = k * x
// out = 6 + mul(x,z)
//
//def main(x,z):
// out = mul(x,z) - doSomething(x,x)
//
//def mul(a,b):
// out = a * b
//`
for _, test := range correctnesTest {
parser := NewParser(strings.NewReader(test.code))
@ -138,16 +138,16 @@ func TestNewProgramm(t *testing.T) {
}
fmt.Println("\n generating R1CS")
a, b, c := program.GenerateReducedR1CS(gates)
fmt.Println(a)
fmt.Println(b)
fmt.Println(c)
r1cs := program.GenerateReducedR1CS(gates)
fmt.Println(r1cs.A)
fmt.Println(r1cs.B)
fmt.Println(r1cs.C)
for _, io := range test.io {
inputs := io.inputs
fmt.Println("input")
fmt.Println(inputs)
w := program.CalculateWitness(inputs)
w := CalculateWitness(inputs, r1cs)
fmt.Println("witness")
fmt.Println(w)
assert.Equal(t, io.result, w[len(w)-1])

+ 68
- 0
circuitcompiler/test.py

@ -0,0 +1,68 @@
# def do(x):
# e = x * 5
# b = e * 6
# c = b * 7
# f = c * 1
# d = c * f
# return d * mul(d,e)
#
# def add(x ,k):
# z = k * x
# return do(x) + mul(x,z)
#
#
# def mul(a,b):
# return a * b
#
# def main():
# x=365235
# z=11876525
# print(do(z) + add(x,x))
################################
# def add(x ,k):
# z = k * x
# return 6 + mul(x,z)
# def asdf(a,b):
# d = b + b
# c = a * d
# e = c - a
# return e * c
#
# def asdf(a,b):
# c = a + b
# e = c - a
# f = e + b
# g = f + 2
# return g * a
def doSomething(x ,k):
z = k * x
return 6 + mul(x,z)
# def main(x,z):
# out =
def mul(a,b):
return a * b
def main():
x=64341
z=76548465
print(mul(x,z) - doSomething(x,x))
#
# def mul(a,b):
# return a * b
#
# def asdf(a):
# b = a * a
# c = 4 - b
# d = 5 * c
# return mul(d,c) / mul(b,b)
if __name__ == '__main__':
#pascal(8)
main()

+ 4
- 5
snark.go

@ -6,7 +6,6 @@ import (
"os"
"github.com/mottla/go-snark/bn128"
"github.com/mottla/go-snark/circuitcompiler"
"github.com/mottla/go-snark/fields"
"github.com/mottla/go-snark/r1csqap"
)
@ -246,7 +245,7 @@ func GenerateTrustedSetup(witnessLength int, alphas, betas, gammas [][]*big.Int)
}
// GenerateProofs generates all the parameters to proof the zkSNARK from the Circuit, Setup and the Witness
func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int, px []*big.Int) (Proof, error) {
func GenerateProofs(setup Setup, nInputs int, w []*big.Int, px []*big.Int) (Proof, error) {
var proof Proof
proof.PiA = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()}
proof.PiAp = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()}
@ -257,12 +256,12 @@ func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int,
proof.PiH = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()}
proof.PiKp = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()}
for i := circuit.NPublic + 1; i < circuit.NVars; i++ {
for i := nInputs; i < len(w)-1; i++ {
proof.PiA = Utils.Bn.G1.Add(proof.PiA, Utils.Bn.G1.MulScalar(setup.Pk.A[i], w[i]))
proof.PiAp = Utils.Bn.G1.Add(proof.PiAp, Utils.Bn.G1.MulScalar(setup.Pk.Ap[i], w[i]))
}
for i := 0; i < circuit.NVars; i++ {
for i := 0; i < len(w); i++ {
proof.PiB = Utils.Bn.G2.Add(proof.PiB, Utils.Bn.G2.MulScalar(setup.Pk.B[i], w[i]))
proof.PiBp = Utils.Bn.G1.Add(proof.PiBp, Utils.Bn.G1.MulScalar(setup.Pk.Bp[i], w[i]))
@ -284,7 +283,7 @@ func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int,
}
// VerifyProof verifies over the BN128 the Pairings of the Proof
func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publicSignals []*big.Int, debug bool) bool {
func VerifyProof(setup Setup, proof Proof, publicSignals []*big.Int, debug bool) bool {
// e(piA, Va) == e(piA', g2)
pairingPiaVa := Utils.Bn.Pairing(proof.PiA, setup.Vk.Vka)
pairingPiapG2 := Utils.Bn.Pairing(proof.PiAp, Utils.Bn.G2.G)

+ 34
- 35
snark_test.go

@ -57,7 +57,7 @@ func TestGenerateProofs(t *testing.T) {
func TestNewProgramm(t *testing.T) {
flat := `
func main(a,b,c,d):
def main(a,b,c,d):
e = a * b
f = c * d
g = e * f
@ -85,14 +85,15 @@ func TestNewProgramm(t *testing.T) {
}
fmt.Println("generating R1CS")
a, b, c := program.GenerateReducedR1CS(gates)
r1cs := program.GenerateReducedR1CS(gates)
a, b, c := r1cs.A, r1cs.B, r1cs.C
fmt.Println(a)
fmt.Println(b)
fmt.Println(c)
a1 := big.NewInt(int64(6))
a2 := big.NewInt(int64(5))
inputs := []*big.Int{a1, a2, a1, a2}
w := program.CalculateWitness(inputs)
w := circuitcompiler.CalculateWitness(inputs, r1cs)
fmt.Println("witness")
fmt.Println(w)
@ -120,47 +121,45 @@ func TestNewProgramm(t *testing.T) {
// p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
assert.Equal(t, abc, px)
hzQAP := Utils.PF.Mul(hxQAP, domain)
assert.Equal(t, abc, hzQAP)
div, rem := Utils.PF.Div(px, domain)
assert.Equal(t, hxQAP, div) //not necessary
assert.Equal(t, hxQAP, div) //not necessary, since DivisorPolynomial is Div, just discarding 'rem'
assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(len(px)-len(domain)))
//calculate trusted setup
//setup, err := GenerateTrustedSetup(len(w),alphas, betas, gammas)
//assert.Nil(t, err)
//fmt.Println("\nt:", setup.Toxic.T)
////
////// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
//assert.Equal(t, domain, setup.Pk.Z)
//
//fmt.Println("hx pk.z", hxQAP)
//hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
//fmt.Println("hx pk.z", hx)
//// assert.Equal(t, hxQAP, hx)
//assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
//assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
//
//assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
//assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
//// fmt.Println("pk.Z", len(setup.Pk.Z))
//// fmt.Println("zxQAP", len(zxQAP))
setup, err := GenerateTrustedSetup(len(w), alphas, betas, gammas)
assert.Nil(t, err)
fmt.Println("\nt:", setup.Toxic.T)
//
//// piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
//proof, err := GenerateProofs(*circuit, setup, w, px)
//assert.Nil(t, err)
//
//// fmt.Println("\n proofs:")
//// fmt.Println(proof)
//
//// fmt.Println("public signals:", proof.PublicSignals)
//fmt.Println("\nwitness", w)
//// b1 := big.NewInt(int64(1))
//// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
//assert.Equal(t, domain, setup.Pk.Z)
fmt.Println("hx pk.z", hxQAP)
hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
fmt.Println("hx pk.z", hx)
// assert.Equal(t, hxQAP, hx)
assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain))
assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
assert.Equal(t, len(hxQAP), len(px)-len(domain)+1)
// fmt.Println("pk.Z", len(setup.Pk.Z))
// fmt.Println("zxQAP", len(zxQAP))
// piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
proof, err := GenerateProofs(setup, 5, w, px)
assert.Nil(t, err)
// fmt.Println("\n proofs:")
// fmt.Println(proof)
// fmt.Println("public signals:", proof.PublicSignals)
fmt.Println("\nwitness", w)
// b1 := big.NewInt(int64(1))
//b35 := big.NewInt(int64(35))
//// publicSignals := []*big.Int{b1, b35}
//publicSignals := []*big.Int{b35}
//before := time.Now()
//assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
assert.True(t, VerifyProof(setup, proof, w[:5], true))
//fmt.Println("verify proof time elapsed:", time.Since(before))
}

Loading…
Cancel
Save