diff --git a/plonky2_verifier/arithmetic_gate.go b/plonky2_verifier/arithmetic_gate.go new file mode 100644 index 0000000..bc4d54e --- /dev/null +++ b/plonky2_verifier/arithmetic_gate.go @@ -0,0 +1,58 @@ +package plonky2_verifier + +import ( + "fmt" + . "gnark-plonky2-verifier/field" +) + +type ArithmeticGate struct { + numOps uint64 +} + +func NewArithmeticGate(numOps uint64) *ArithmeticGate { + return &ArithmeticGate{ + numOps: numOps, + } +} + +func (g *ArithmeticGate) Id() string { + return fmt.Sprintf("ArithmeticGate { num_ops: %d }", g.numOps) +} + +func (g *ArithmeticGate) WireIthMultiplicand0(i uint64) uint64 { + return 4 * i +} + +func (g *ArithmeticGate) WireIthMultiplicand1(i uint64) uint64 { + return 4*i + 1 +} + +func (g *ArithmeticGate) WireIthAddend(i uint64) uint64 { + return 4*i + 2 +} + +func (g *ArithmeticGate) WireIthOutput(i uint64) uint64 { + return 4*i + 3 +} + +func (g *ArithmeticGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { + const0 := vars.localConstants[0] + const1 := vars.localConstants[1] + + constraints := []QuadraticExtension{} + for i := uint64(0); i < g.numOps; i++ { + multiplicand0 := vars.localWires[g.WireIthMultiplicand0(i)] + multiplicand1 := vars.localWires[g.WireIthMultiplicand1(i)] + addend := vars.localWires[g.WireIthAddend(i)] + output := vars.localWires[g.WireIthOutput(i)] + + computedOutput := p.qeAPI.AddExtension( + p.qeAPI.MulExtension(p.qeAPI.MulExtension(multiplicand0, multiplicand1), const0), + p.qeAPI.MulExtension(addend, const1), + ) + + constraints = append(constraints, p.qeAPI.SubExtension(computedOutput, output)) + } + + return constraints +} diff --git a/plonky2_verifier/constant_gate.go b/plonky2_verifier/constant_gate.go new file mode 100644 index 0000000..2d12609 --- /dev/null +++ b/plonky2_verifier/constant_gate.go @@ -0,0 +1,44 @@ +package plonky2_verifier + +import ( + "fmt" + . "gnark-plonky2-verifier/field" +) + +type ConstantGate struct { + numConsts uint64 +} + +func NewConstantGate(numConsts uint64) *ConstantGate { + return &ConstantGate{ + numConsts: numConsts, + } +} + +func (g *ConstantGate) Id() string { + return fmt.Sprintf("ConstantGate { num_consts: %d }", g.numConsts) +} + +func (g *ConstantGate) ConstInput(i uint64) uint64 { + if i > g.numConsts { + panic("Invalid constant index") + } + return i +} + +func (g *ConstantGate) WireOutput(i uint64) uint64 { + if i > g.numConsts { + panic("Invalid wire index") + } + return i +} + +func (g *ConstantGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { + constraints := []QuadraticExtension{} + + for i := uint64(0); i < g.numConsts; i++ { + constraints = append(constraints, p.qeAPI.SubExtension(vars.localConstants[g.ConstInput(i)], vars.localWires[g.WireOutput(i)])) + } + + return constraints +} diff --git a/plonky2_verifier/deserialize.go b/plonky2_verifier/deserialize.go index 407e167..205c5e9 100644 --- a/plonky2_verifier/deserialize.go +++ b/plonky2_verifier/deserialize.go @@ -130,7 +130,8 @@ type CommonCircuitDataRaw struct { DegreeBits uint64 `json:"degree_bits"` ReductionArityBits []uint64 `json:"reduction_arity_bits"` } `json:"fri_params"` - DegreeBits uint64 `json:"degree_bits"` + Gates []string `json:"gates"` + DegreeBits uint64 `json:"degree_bits"` SelectorsInfo struct { SelectorIndices []uint64 `json:"selector_indices"` Groups []struct { @@ -346,6 +347,11 @@ func DeserializeCommonCircuitData(path string) CommonCircuitData { commonCircuitData.FriParams.Config.NumQueryRounds = raw.FriParams.Config.NumQueryRounds commonCircuitData.FriParams.ReductionArityBits = raw.FriParams.ReductionArityBits + commonCircuitData.Gates = []gate{} + for _, gate := range raw.Gates { + commonCircuitData.Gates = append(commonCircuitData.Gates, GateInstanceFromId(gate)) + } + commonCircuitData.DegreeBits = raw.DegreeBits commonCircuitData.QuotientDegreeFactor = raw.QuotientDegreeFactor commonCircuitData.NumGateConstraints = raw.NumGateConstraints diff --git a/plonky2_verifier/gate.go b/plonky2_verifier/gate.go index bbaeed1..5e22677 100644 --- a/plonky2_verifier/gate.go +++ b/plonky2_verifier/gate.go @@ -2,10 +2,51 @@ package plonky2_verifier import ( . "gnark-plonky2-verifier/field" + "strconv" + "strings" ) type gate interface { - EvalUnfiltered(vars EvaluationVars) []QuadraticExtension + Id() string + EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension +} + +func GateInstanceFromId(gateId string) gate { + if strings.HasPrefix(gateId, "ArithmeticGate") { + numOpsRaw := strings.Split(gateId, ":")[1] + numOpsRaw = strings.Split(numOpsRaw, "}")[0] + numOpsRaw = strings.TrimSpace(numOpsRaw) + numOps, err := strconv.Atoi(numOpsRaw) + if err != nil { + panic("Invalid gate ID for ArithmeticGate") + } + return NewArithmeticGate(uint64(numOps)) + } + + if gateId == "ConstantGate" { + numConstsRaw := strings.Split(gateId, ":")[1] + numConstsRaw = strings.Split(numConstsRaw, "}")[0] + numConstsRaw = strings.TrimSpace(numConstsRaw) + numConsts, err := strconv.Atoi(numConstsRaw) + if err != nil { + panic("Invalid gate ID") + } + return NewConstantGate(uint64(numConsts)) + } + + if gateId == "NoopGate" { + return NewNoopGate() + } + + if gateId == "PublicInputGate" { + return NewPublicInputGate() + } + + if strings.HasPrefix(gateId, "PoseidonGate") { + return NewPoseidonGate() + } + + panic("Unknown gate ID") } func (p *PlonkChip) computeFilter( @@ -42,7 +83,7 @@ func (p *PlonkChip) evalFiltered( vars.RemovePrefix(numSelectors) - unfiltered := g.EvalUnfiltered(vars) + unfiltered := g.EvalUnfiltered(p, vars) for i := range unfiltered { unfiltered[i] = p.qeAPI.MulExtension(unfiltered[i], filter) } diff --git a/plonky2_verifier/noop_gate.go b/plonky2_verifier/noop_gate.go index 3855197..fb48ae5 100644 --- a/plonky2_verifier/noop_gate.go +++ b/plonky2_verifier/noop_gate.go @@ -7,6 +7,14 @@ import ( type NoopGate struct { } -func (p *NoopGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []QuadraticExtension { +func NewNoopGate() *NoopGate { + return &NoopGate{} +} + +func (g *NoopGate) Id() string { + return "NoopGate" +} + +func (g *NoopGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { return []QuadraticExtension{} } diff --git a/plonky2_verifier/plonk.go b/plonky2_verifier/plonk.go index 05b2535..3e9546a 100644 --- a/plonky2_verifier/plonk.go +++ b/plonky2_verifier/plonk.go @@ -131,8 +131,10 @@ func (p *PlonkChip) evaluateGateConstraints(vars EvaluationVars) []QuadraticExte p.commonData.SelectorsInfo.NumSelectors(), ) - for _, constraint := range gateConstraints { - // assert j < commonData.NumGateConstraints + for j, constraint := range gateConstraints { + if uint64(j) >= p.commonData.NumGateConstraints { + panic("num_constraints() gave too low of a number") + } constraints[i] = p.qeAPI.AddExtension(constraints[i], constraint) } } diff --git a/plonky2_verifier/poseidon_gate.go b/plonky2_verifier/poseidon_gate.go index 4e68c0b..8f6c44d 100644 --- a/plonky2_verifier/poseidon_gate.go +++ b/plonky2_verifier/poseidon_gate.go @@ -5,6 +5,17 @@ import ( "gnark-plonky2-verifier/poseidon" ) +type PoseidonGate struct { +} + +func NewPoseidonGate() *PoseidonGate { + return &PoseidonGate{} +} + +func (g *PoseidonGate) Id() string { + return "PoseidonGate" +} + func (g *PoseidonGate) WireInput(i uint64) uint64 { return i } @@ -68,40 +79,37 @@ func (g *PoseidonGate) WiresEnd() uint64 { return START_FULL_1 + poseidon.HALF_N_FULL_ROUNDS*poseidon.SPONGE_WIDTH } -type PoseidonGate struct { -} - -func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []QuadraticExtension { +func (g *PoseidonGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { constraints := []QuadraticExtension{} - poseidonChip := poseidon.NewPoseidonChip(pc.api, NewFieldAPI(pc.api), pc.qeAPI) + poseidonChip := poseidon.NewPoseidonChip(p.api, NewFieldAPI(p.api), p.qeAPI) // Assert that `swap` is binary. - swap := vars.localWires[p.WireSwap()] - notSwap := pc.qeAPI.SubExtension(pc.qeAPI.FieldToQE(ONE_F), swap) - constraints = append(constraints, pc.qeAPI.MulExtension(swap, notSwap)) + swap := vars.localWires[g.WireSwap()] + notSwap := p.qeAPI.SubExtension(p.qeAPI.FieldToQE(ONE_F), swap) + constraints = append(constraints, p.qeAPI.MulExtension(swap, notSwap)) // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i := uint64(0); i < 4; i++ { - inputLhs := vars.localWires[p.WireInput(i)] - inputRhs := vars.localWires[p.WireInput(i+4)] - deltaI := vars.localWires[p.WireDelta(i)] - diff := pc.qeAPI.SubExtension(inputRhs, inputLhs) - expectedDeltaI := pc.qeAPI.MulExtension(swap, diff) - constraints = append(constraints, pc.qeAPI.SubExtension(expectedDeltaI, deltaI)) + inputLhs := vars.localWires[g.WireInput(i)] + inputRhs := vars.localWires[g.WireInput(i+4)] + deltaI := vars.localWires[g.WireDelta(i)] + diff := p.qeAPI.SubExtension(inputRhs, inputLhs) + expectedDeltaI := p.qeAPI.MulExtension(swap, diff) + constraints = append(constraints, p.qeAPI.SubExtension(expectedDeltaI, deltaI)) } // Compute the possibly-swapped input layer. var state [poseidon.SPONGE_WIDTH]QuadraticExtension for i := uint64(0); i < 4; i++ { - deltaI := vars.localWires[p.WireDelta(i)] - inputLhs := vars.localWires[p.WireInput(i)] - inputRhs := vars.localWires[p.WireInput(i+4)] - state[i] = pc.qeAPI.AddExtension(inputLhs, deltaI) - state[i+4] = pc.qeAPI.SubExtension(inputRhs, deltaI) + deltaI := vars.localWires[g.WireDelta(i)] + inputLhs := vars.localWires[g.WireInput(i)] + inputRhs := vars.localWires[g.WireInput(i+4)] + state[i] = p.qeAPI.AddExtension(inputLhs, deltaI) + state[i+4] = p.qeAPI.SubExtension(inputRhs, deltaI) } for i := uint64(8); i < poseidon.SPONGE_WIDTH; i++ { - state[i] = vars.localWires[p.WireInput(i)] + state[i] = vars.localWires[g.WireInput(i)] } round_ctr := 0 @@ -111,8 +119,8 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad state = poseidonChip.ConstantLayerExtension(state, &round_ctr) if r != 0 { for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ { - sBoxIn := vars.localWires[p.WireFullSBox0(r, i)] - constraints = append(constraints, pc.qeAPI.SubExtension(state[i], sBoxIn)) + sBoxIn := vars.localWires[g.WireFullSBox0(r, i)] + constraints = append(constraints, p.qeAPI.SubExtension(state[i], sBoxIn)) state[i] = sBoxIn } } @@ -125,14 +133,14 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad state = poseidonChip.PartialFirstConstantLayerExtension(state) state = poseidonChip.MdsPartialLayerInitExtension(state) for r := uint64(0); r < poseidon.N_PARTIAL_ROUNDS-1; r++ { - sBoxIn := vars.localWires[p.WirePartialSBox(r)] - constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn)) + sBoxIn := vars.localWires[g.WirePartialSBox(r)] + constraints = append(constraints, p.qeAPI.SubExtension(state[0], sBoxIn)) state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn) - state[0] = pc.qeAPI.AddExtension(state[0], pc.qeAPI.FieldToQE(NewFieldElement(poseidon.FAST_PARTIAL_ROUND_CONSTANTS[r]))) + state[0] = p.qeAPI.AddExtension(state[0], p.qeAPI.FieldToQE(NewFieldElement(poseidon.FAST_PARTIAL_ROUND_CONSTANTS[r]))) state = poseidonChip.MdsPartialLayerFastExtension(state, int(r)) } - sBoxIn := vars.localWires[p.WirePartialSBox(poseidon.N_PARTIAL_ROUNDS-1)] - constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn)) + sBoxIn := vars.localWires[g.WirePartialSBox(poseidon.N_PARTIAL_ROUNDS-1)] + constraints = append(constraints, p.qeAPI.SubExtension(state[0], sBoxIn)) state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn) state = poseidonChip.MdsPartialLayerFastExtension(state, poseidon.N_PARTIAL_ROUNDS-1) round_ctr += poseidon.N_PARTIAL_ROUNDS @@ -141,8 +149,8 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad for r := uint64(0); r < poseidon.HALF_N_FULL_ROUNDS; r++ { poseidonChip.ConstantLayerExtension(state, &round_ctr) for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ { - sBoxIn := vars.localWires[p.WireFullSBox1(r, i)] - constraints = append(constraints, pc.qeAPI.SubExtension(state[i], sBoxIn)) + sBoxIn := vars.localWires[g.WireFullSBox1(r, i)] + constraints = append(constraints, p.qeAPI.SubExtension(state[i], sBoxIn)) state[i] = sBoxIn } state = poseidonChip.MdsLayerExtension(state) @@ -151,7 +159,7 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad } for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ { - constraints = append(constraints, pc.qeAPI.SubExtension(state[i], vars.localWires[p.WireOutput(i)])) + constraints = append(constraints, p.qeAPI.SubExtension(state[i], vars.localWires[g.WireOutput(i)])) } return constraints diff --git a/plonky2_verifier/public_input_gate.go b/plonky2_verifier/public_input_gate.go index 1804560..30fa95c 100644 --- a/plonky2_verifier/public_input_gate.go +++ b/plonky2_verifier/public_input_gate.go @@ -7,20 +7,28 @@ import ( type PublicInputGate struct { } +func NewPublicInputGate() *PublicInputGate { + return &PublicInputGate{} +} + +func (g *PublicInputGate) Id() string { + return "PublicInputGate" +} + func (g *PublicInputGate) WiresPublicInputsHash() []uint64 { return []uint64{0, 1, 2, 3} } -func (p *PublicInputGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []QuadraticExtension { +func (g *PublicInputGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { constraints := []QuadraticExtension{} - wires := p.WiresPublicInputsHash() + wires := g.WiresPublicInputsHash() hash_parts := vars.publicInputsHash for i := 0; i < 4; i++ { wire := wires[i] hash_part := hash_parts[i] - diff := pc.qeAPI.SubExtension(vars.localWires[wire], pc.qeAPI.FieldToQE(hash_part)) + diff := p.qeAPI.SubExtension(vars.localWires[wire], p.qeAPI.FieldToQE(hash_part)) constraints = append(constraints, diff) }