diff --git a/field/field.go b/field/field.go index 5db5859..b30571d 100644 --- a/field/field.go +++ b/field/field.go @@ -10,6 +10,7 @@ import ( type EmulatedField = emulated.Goldilocks type F = emulated.Element[EmulatedField] type QuadraticExtension = [2]F +type QEAlgebra = [2]QuadraticExtension type Hash = [4]F var TEST_CURVE = ecc.BN254 diff --git a/field/quadratic_extension.go b/field/quadratic_extension.go index 42c3b5c..8c27212 100644 --- a/field/quadratic_extension.go +++ b/field/quadratic_extension.go @@ -154,6 +154,20 @@ func (c *QuadraticExtensionAPI) AssertIsEqual(a, b QuadraticExtension) { } } +func (c *QuadraticExtensionAPI) InnerProductExtension(constant F, startingAcc QuadraticExtension, pairs [][2]QuadraticExtension) QuadraticExtension { + acc := startingAcc + + for i := 0; i < len(pairs); i++ { + a := pairs[i][0] + b := pairs[i][1] + mul := c.ScalarMulExtension(a, constant) + mul = c.MulExtension(mul, b) + acc = c.AddExtension(acc, mul) + } + + return acc +} + func (c *QuadraticExtensionAPI) Println(a QuadraticExtension) { fmt.Print("Degree 0 coefficient") c.fieldAPI.Println(a[0]) @@ -161,3 +175,53 @@ func (c *QuadraticExtensionAPI) Println(a QuadraticExtension) { fmt.Print("Degree 1 coefficient") c.fieldAPI.Println(a[1]) } + +func (c *QuadraticExtensionAPI) MulExtensionAlgebra(a, b QEAlgebra) QEAlgebra { + var inner [2][][2]QuadraticExtension + var inner_w [2][][2]QuadraticExtension + for i := 0; i < 2; i++ { + for j := 0; j < 2-i; j++ { + idx := (i + j) % 2 + inner[idx] = append(inner[idx], [2]QuadraticExtension{a[i], b[j]}) + } + for j := 2 - i; j < 2; j++ { + idx := (i + j) % 2 + inner_w[idx] = append(inner_w[idx], [2]QuadraticExtension{a[i], b[j]}) + } + } + + var product QEAlgebra + for i := 0; i < 2; i++ { + acc := c.InnerProductExtension(NewFieldElement(7), c.ZERO_QE, inner_w[i]) + product[i] = c.InnerProductExtension(ONE_F, acc, inner[i]) + } + + return product +} + +func (c *QuadraticExtensionAPI) ScalarMulExtensionAlgebra(a QuadraticExtension, b QEAlgebra) QEAlgebra { + var product QEAlgebra + for i := 0; i < 2; i++ { + product[i] = c.MulExtension(a, b[i]) + } + + return product +} + +func (c *QuadraticExtensionAPI) AddExtensionAlgebra(a, b QEAlgebra) QEAlgebra { + var sum QEAlgebra + for i := 0; i < 2; i++ { + sum[i] = c.AddExtension(a[i], b[i]) + } + + return sum +} + +func (c *QuadraticExtensionAPI) SubExtensionAlgebra(a, b QEAlgebra) QEAlgebra { + var diff QEAlgebra + for i := 0; i < 2; i++ { + diff[i] = c.SubExtension(a[i], b[i]) + } + + return diff +} diff --git a/plonky2_verifier/arithmetic_extension_gate.go b/plonky2_verifier/arithmetic_extension_gate.go new file mode 100644 index 0000000..6cb4c2f --- /dev/null +++ b/plonky2_verifier/arithmetic_extension_gate.go @@ -0,0 +1,62 @@ +package plonky2_verifier + +import ( + "fmt" + . "gnark-plonky2-verifier/field" +) + +// Ideally, this should be serialized in the plonky2 repo +const d = 2 + +type ArithmeticExtensionGate struct { + numOps uint64 +} + +func NewArithmeticExtensionGate(numOps uint64) *ArithmeticExtensionGate { + return &ArithmeticExtensionGate{ + numOps: numOps, + } +} + +func (g *ArithmeticExtensionGate) Id() string { + return fmt.Sprintf("ArithmeticExtensionGate { num_ops: %d }", g.numOps) +} + +func (g *ArithmeticExtensionGate) wiresIthMultiplicand0(i uint64) Range { + return Range{4 * d * i, 4*d*i + d} +} + +func (g *ArithmeticExtensionGate) wiresIthMultiplicand1(i uint64) Range { + return Range{4*d*i + d, 4*d*i + 2*d} +} + +func (g *ArithmeticExtensionGate) wiresIthAddend(i uint64) Range { + return Range{4*d*i + 2*d, 4*d*i + 3*d} +} + +func (g *ArithmeticExtensionGate) wiresIthOutput(i uint64) Range { + return Range{4*d*i + 3*d, 4*d*i + 4*d} +} + +func (g *ArithmeticExtensionGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { + const0 := vars.localConstants[0] + const1 := vars.localConstants[1] + + constraints := []QuadraticExtension{} + for i := uint64(0); i < g.numOps; i++ { + multiplicand0 := vars.GetLocalExtAlgebra(g.wiresIthMultiplicand0(i)) + multiplicand1 := vars.GetLocalExtAlgebra(g.wiresIthMultiplicand1(i)) + addend := vars.GetLocalExtAlgebra(g.wiresIthAddend(i)) + output := vars.GetLocalExtAlgebra(g.wiresIthOutput(i)) + + mul := p.qeAPI.MulExtensionAlgebra(multiplicand0, multiplicand1) + scaled_mul := p.qeAPI.ScalarMulExtensionAlgebra(const0, mul) + computed_output := p.qeAPI.ScalarMulExtensionAlgebra(const1, addend) + computed_output = p.qeAPI.AddExtensionAlgebra(computed_output, scaled_mul) + + diff := p.qeAPI.SubExtensionAlgebra(output, computed_output) + constraints = append(constraints, diff[0], diff[1]) + } + + return constraints +} diff --git a/plonky2_verifier/gate.go b/plonky2_verifier/gate.go index 6d062bc..5332634 100644 --- a/plonky2_verifier/gate.go +++ b/plonky2_verifier/gate.go @@ -57,22 +57,12 @@ func GateInstanceFromId(gateId string) gate { } matches := getRegExMatches(r, gateId) - numLimbsStr, hasNumLimbs := matches["numLimbs"] - baseStr, hasBase := matches["base"] + numLimbs, hasNumLimbs := matches["numLimbs"] + base, hasBase := matches["base"] if !hasNumLimbs || !hasBase { panic("Invalid BaseSumGate ID") } - numLimbs, err := strconv.Atoi(numLimbsStr) - if err != nil { - panic("Invalid BaseSumGate ID: " + err.Error()) - } - - base, err := strconv.Atoi(baseStr) - if err != nil { - panic("Invalid BaseSumGate ID: " + err.Error()) - } - return NewBaseSumGate(uint64(numLimbs), uint64(base)) } @@ -86,41 +76,48 @@ func GateInstanceFromId(gateId string) gate { } matches := getRegExMatches(r, gateId) - bitsStr, hasBits := matches["bits"] - numCopiesStr, hasNumCopies := matches["numCopies"] - numExtraConstantsStr, hasNumExtraConstants := matches["numExtraConstants"] + bits, hasBits := matches["bits"] + numCopies, hasNumCopies := matches["numCopies"] + numExtraConstants, hasNumExtraConstants := matches["numExtraConstants"] if !hasBits || !hasNumCopies || !hasNumExtraConstants { panic("Invalid RandomAccessGate ID") } - bits, err := strconv.Atoi(bitsStr) - if err != nil { - panic("Invalid RandomAccessGate ID: " + err.Error()) - } + return NewRandomAccessGate(uint64(bits), uint64(numCopies), uint64(numExtraConstants)) + } - numCopies, err := strconv.Atoi(numCopiesStr) + if strings.HasPrefix(gateId, "ArithmeticExtension") { + // Has the format "ArithmeticExtensionGate { num_ops: 10 }" + + regEx := "ArithmeticExtensionGate { num_ops: (?P[0-9]+) }" + r, err := regexp.Compile(regEx) if err != nil { - panic("Invalid RandomAccessGate ID: " + err.Error()) + panic("Invalid ArithmeticExtensionGate regular expression") } - numExtraConstants, err := strconv.Atoi(numExtraConstantsStr) - if err != nil { - panic("Invalid RandomAccessGate ID: " + err.Error()) + matches := getRegExMatches(r, gateId) + numOps, hasNumOps := matches["numOps"] + if !hasNumOps { + panic("Invalid ArithmeticExtensionGate ID") } - return NewRandomAccessGate(uint64(bits), uint64(numCopies), uint64(numExtraConstants)) + return NewArithmeticExtensionGate(uint64(numOps)) } return nil //panic(fmt.Sprintf("Unknown gate ID %s", gateId)) } -func getRegExMatches(r *regexp.Regexp, gateId string) map[string]string { +func getRegExMatches(r *regexp.Regexp, gateId string) map[string]int { matches := r.FindStringSubmatch(gateId) - result := make(map[string]string) + result := make(map[string]int) for i, name := range r.SubexpNames() { if i != 0 && name != "" { - result[name] = matches[i] + value, err := strconv.Atoi(matches[i]) + if err != nil { + panic("Invalid field value for \"name\": " + err.Error()) + } + result[name] = value } } diff --git a/plonky2_verifier/gate_test.go b/plonky2_verifier/gate_test.go index c48e50c..1c31b10 100644 --- a/plonky2_verifier/gate_test.go +++ b/plonky2_verifier/gate_test.go @@ -767,6 +767,7 @@ func TestGates(t *testing.T) { {&BaseSumGate{numLimbs: 63, base: 2}, baseSumGateExpectedConstraints}, {&RandomAccessGate{bits: 4, numCopies: 4, numExtraConstants: 2}, randomAccessGateExpectedConstraints}, {&PoseidonGate{}, poseidonGateExpectedConstraints}, + {&ArithmeticExtensionGate{numOps: 10}, arithmeticExtensionGateExpectedConstraints}, } for _, test := range gateTests { diff --git a/plonky2_verifier/vars.go b/plonky2_verifier/vars.go index 1c78511..d2e6aa7 100644 --- a/plonky2_verifier/vars.go +++ b/plonky2_verifier/vars.go @@ -13,3 +13,12 @@ type EvaluationVars struct { func (e *EvaluationVars) RemovePrefix(numSelectors uint64) { e.localConstants = e.localConstants[numSelectors:] } + +func (e *EvaluationVars) GetLocalExtAlgebra(wireRange Range) QEAlgebra { + // For now, only support degree 2 + if wireRange.end-wireRange.start != 2 { + panic("Only degree 2 supported") + } + + return QEAlgebra{e.localWires[wireRange.start], e.localWires[wireRange.end-1]} +}