diff --git a/circuitcompiler/Programm.go b/circuitcompiler/Programm.go index 7e02960..a8eab4e 100644 --- a/circuitcompiler/Programm.go +++ b/circuitcompiler/Programm.go @@ -42,10 +42,16 @@ type Program struct { computedFactors map[string]string } +//returns the cardinality of all main inputs + 1 for the "one" signal func (p *Program) GlobalInputCount() int { return len(p.globalInputs) } +//returns the cardinaltiy of the output signals. Current only 1 output possible +func (p *Program) GlobalOutputCount() int { + return len(p.globalOutput) +} + func (p *Program) PrintContraintTrees() { for k, v := range p.functions { fmt.Println(k) @@ -68,10 +74,7 @@ func (p *Program) BuildConstraintTrees() { for _, in := range p.getMainCircuit().Inputs { p.globalInputs = append(p.globalInputs, in) } - for key, _ := range p.globalOutput { - p.globalInputs = append(p.globalInputs, key) - } - //TODO do the same with the outputs + var wg = sync.WaitGroup{} //we build the parse trees concurrently! because we can! go rocks @@ -200,6 +203,7 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10]) //note we only check for existence, but not for truth. + //global outputs do not require a hash identifier, since they are unique if _, ex := p.globalOutput[rootGate.value.Out]; !ex { rootGate.value.Out = rootGate.value.Out + string(out[:10]) } @@ -470,14 +474,16 @@ func NewProgram() (p *Program) { func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) { // from flat code to R1CS - offset := len(p.globalInputs) - len(p.globalOutput) + offset := len(p.globalInputs) // one + in1 +in2+... + gate1 + gate2 .. + out size := offset + len(mGates) indexMap := make(map[string]int) for i, v := range p.globalInputs { indexMap[v] = i - + } + for k, _ := range p.globalOutput { + indexMap[k] = len(indexMap) } for _, v := range mGates { if _, ex := indexMap[v.value.Out]; !ex { diff --git a/circuitcompiler/Programm_test.go b/circuitcompiler/Programm_test.go index 412e23f..90bdf90 100644 --- a/circuitcompiler/Programm_test.go +++ b/circuitcompiler/Programm_test.go @@ -154,7 +154,7 @@ func TestNewProgramm(t *testing.T) { w := CalculateWitness(inputs, r1cs) fmt.Println("witness") fmt.Println(w) - assert.Equal(t, io.result, w[len(program.globalInputs)-1]) + assert.Equal(t, io.result, w[program.GlobalInputCount()]) } } diff --git a/snark.go b/snark.go index 135dc70..14ad7d2 100644 --- a/snark.go +++ b/snark.go @@ -90,7 +90,7 @@ func prepareUtils() utils { } // GenerateTrustedSetup generates the Trusted Setup from a compiled Circuit. The Setup.Toxic sub data structure must be destroyed -func GenerateTrustedSetup(inputs int, alphas, betas, gammas [][]*big.Int) (Setup, error) { +func GenerateTrustedSetup(numberInOutSignals int, alphas, betas, gammas [][]*big.Int) (Setup, error) { var setup Setup var err error @@ -179,7 +179,7 @@ func GenerateTrustedSetup(inputs int, alphas, betas, gammas [][]*big.Int) (Setup rhoAat := Utils.FqR.Mul(setup.Toxic.RhoA, at) a := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, rhoAat) setup.Pk.A = append(setup.Pk.A, a) - if i < inputs { + if i < numberInOutSignals { setup.Vk.IC = append(setup.Vk.IC, a) } diff --git a/snark_test.go b/snark_test.go index 6653899..643a201 100644 --- a/snark_test.go +++ b/snark_test.go @@ -163,7 +163,7 @@ func TestGenerateAndVerifyProof(t *testing.T) { before := time.Now() //calculate trusted setup - setup, err := GenerateTrustedSetup(program.GlobalInputCount(), alphas, betas, gammas) + setup, err := GenerateTrustedSetup(program.GlobalInputCount()+program.GlobalOutputCount(), alphas, betas, gammas) fmt.Println("Generate CRS time elapsed:", time.Since(before)) assert.Nil(t, err) fmt.Println("\nt:", setup.Toxic.T) @@ -176,7 +176,7 @@ func TestGenerateAndVerifyProof(t *testing.T) { w := circuitcompiler.CalculateWitness(inputs, r1cs) fmt.Println("\nwitness", w) - assert.Equal(t, io.result, w[program.GlobalInputCount()-1]) + assert.Equal(t, io.result, w[program.GlobalInputCount()]) ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) fmt.Println("ax length", len(ax)) @@ -211,12 +211,13 @@ func TestGenerateAndVerifyProof(t *testing.T) { before := time.Now() // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) - proof, err := GenerateProofs(setup, program.GlobalInputCount(), w, px) + proof, err := GenerateProofs(setup, program.GlobalInputCount()+program.GlobalOutputCount(), w, px) fmt.Println("proof generation time elapsed:", time.Since(before)) assert.Nil(t, err) - + fmt.Println(program.GlobalInputCount() + program.GlobalOutputCount()) before = time.Now() - assert.True(t, VerifyProof(setup, proof, w[:program.GlobalInputCount()], true)) + Signals := w[:program.GlobalInputCount()+program.GlobalOutputCount()] + assert.True(t, VerifyProof(setup, proof, Signals, true)) fmt.Println("verify proof time elapsed:", time.Since(before)) }