diff --git a/field/quadratic_extension.go b/field/quadratic_extension.go index f571d3e..42c3b5c 100644 --- a/field/quadratic_extension.go +++ b/field/quadratic_extension.go @@ -128,11 +128,11 @@ func (c *QuadraticExtensionAPI) ReduceWithPowers(terms []QuadraticExtension, sca return sum } -func (c *QuadraticExtensionAPI) Select(b0 frontend.Variable, qe0, qe1 QuadraticExtension) QuadraticExtension { +func (c *QuadraticExtensionAPI) Select(b frontend.Variable, qe0, qe1 QuadraticExtension) QuadraticExtension { var retQE QuadraticExtension for i := 0; i < 2; i++ { - retQE[i] = c.fieldAPI.Select(b0, qe0[i], qe1[i]).(F) + retQE[i] = c.fieldAPI.Select(b, qe0[i], qe1[i]).(F) } return retQE diff --git a/plonky2_verifier/gate.go b/plonky2_verifier/gate.go index e547b99..6d062bc 100644 --- a/plonky2_verifier/gate.go +++ b/plonky2_verifier/gate.go @@ -76,6 +76,41 @@ func GateInstanceFromId(gateId string) gate { return NewBaseSumGate(uint64(numLimbs), uint64(base)) } + if strings.HasPrefix(gateId, "RandomAccessGate") { + // Has the format "RandomAccessGate { bits: 2, num_copies: 13, num_extra_constants: 2, _phantom: PhantomData }" + + regEx := "RandomAccessGate { bits: (?P[0-9]+), num_copies: (?P[0-9]+), num_extra_constants: (?P[0-9]+), _phantom: PhantomData }[0-9]+)>" + r, err := regexp.Compile(regEx) + if err != nil { + panic("Invalid RandomAccessGate regular expression") + } + + matches := getRegExMatches(r, gateId) + bitsStr, hasBits := matches["bits"] + numCopiesStr, hasNumCopies := matches["numCopies"] + numExtraConstantsStr, hasNumExtraConstants := matches["numExtraConstants"] + if !hasBits || !hasNumCopies || !hasNumExtraConstants { + panic("Invalid RandomAccessGate ID") + } + + bits, err := strconv.Atoi(bitsStr) + if err != nil { + panic("Invalid RandomAccessGate ID: " + err.Error()) + } + + numCopies, err := strconv.Atoi(numCopiesStr) + if err != nil { + panic("Invalid RandomAccessGate ID: " + err.Error()) + } + + numExtraConstants, err := strconv.Atoi(numExtraConstantsStr) + if err != nil { + panic("Invalid RandomAccessGate ID: " + err.Error()) + } + + return NewRandomAccessGate(uint64(bits), uint64(numCopies), uint64(numExtraConstants)) + } + return nil //panic(fmt.Sprintf("Unknown gate ID %s", gateId)) } diff --git a/plonky2_verifier/gate_testing_utils.go b/plonky2_verifier/gate_testing_utils.go index 8d836ee..4bc68a7 100644 --- a/plonky2_verifier/gate_testing_utils.go +++ b/plonky2_verifier/gate_testing_utils.go @@ -218,3 +218,32 @@ var arithmeticGateExpectedConstraints = []QuadraticExtension{ {NewFieldElement(16837665759306664052), NewFieldElement(13229282844806523763)}, {NewFieldElement(15646775329525033386), NewFieldElement(7893047165846868816)}, } + +var randomAccessGateExpectedConstraints = []QuadraticExtension{ + {NewFieldElement(14096597590517523067), NewFieldElement(2655169409008419702)}, + {NewFieldElement(5150369100957105913), NewFieldElement(9690142804550213688)}, + {NewFieldElement(12862636728240805500), NewFieldElement(15885653298577690721)}, + {NewFieldElement(5938169082606588253), NewFieldElement(13375731264713600699)}, + {NewFieldElement(16556211561864036325), NewFieldElement(1097770550456310263)}, + {NewFieldElement(7110929822775027665), NewFieldElement(12197631192598781905)}, + {NewFieldElement(3526950454725789222), NewFieldElement(16581256211788295110)}, + {NewFieldElement(6704870069993050342), NewFieldElement(639095910170462201)}, + {NewFieldElement(15722010723337496870), NewFieldElement(1609594866420744764)}, + {NewFieldElement(3770790493236783721), NewFieldElement(1601875894399014690)}, + {NewFieldElement(8940214713698553170), NewFieldElement(1435550334204491513)}, + {NewFieldElement(5765635925648817913), NewFieldElement(17921434922626677797)}, + {NewFieldElement(11284135106973148775), NewFieldElement(12235917185065439961)}, + {NewFieldElement(11684377481625307024), NewFieldElement(4068304938130253402)}, + {NewFieldElement(16430390956600401383), NewFieldElement(5027375440531063469)}, + {NewFieldElement(12346816733826550618), NewFieldElement(8983232740461478925)}, + {NewFieldElement(7315134556374205868), NewFieldElement(10733792242004794605)}, + {NewFieldElement(9676902521951667374), NewFieldElement(17472456522303293623)}, + {NewFieldElement(3391573289049150552), NewFieldElement(13044958098760740211)}, + {NewFieldElement(7161062079224730414), NewFieldElement(14473293246671391425)}, + {NewFieldElement(590698465067972002), NewFieldElement(4791051041728641335)}, + {NewFieldElement(11301242955861918730), NewFieldElement(7313100973676377913)}, + {NewFieldElement(6327059471985261770), NewFieldElement(11232679988877321564)}, + {NewFieldElement(15485954821265539981), NewFieldElement(1201918074719834630)}, + {NewFieldElement(11416240451899794915), NewFieldElement(3127372561985201979)}, + {NewFieldElement(1915429544941884288), NewFieldElement(16698510309904634494)}, +} diff --git a/plonky2_verifier/random_access_gate.go b/plonky2_verifier/random_access_gate.go new file mode 100644 index 0000000..3ea67e7 --- /dev/null +++ b/plonky2_verifier/random_access_gate.go @@ -0,0 +1,140 @@ +package plonky2_verifier + +import ( + "fmt" + . "gnark-plonky2-verifier/field" +) + +type RandomAccessGate struct { + bits uint64 + numCopies uint64 + numExtraConstants uint64 +} + +func NewRandomAccessGate(bits uint64, numCopies uint64, numExtraConstants uint64) *RandomAccessGate { + return &RandomAccessGate{ + bits: bits, + numCopies: numCopies, + numExtraConstants: numExtraConstants, + } +} + +func (g *RandomAccessGate) Id() string { + return fmt.Sprintf("RandomAccessGate { bits: %d, num_copies: %d, num_extra_constants: %d }", g.bits, g.numCopies, g.numExtraConstants) +} + +func (g *RandomAccessGate) vecSize() uint64 { + return 1 << g.bits +} + +func (g *RandomAccessGate) WireAccessIndex(copy uint64) uint64 { + if copy >= g.numCopies { + panic("RandomAccessGate.WireAccessIndex called with copy >= num_copies") + } + return (2 + g.vecSize()) * copy +} + +func (g *RandomAccessGate) WireClaimedElement(copy uint64) uint64 { + if copy >= g.numCopies { + panic("RandomAccessGate.WireClaimedElement called with copy >= num_copies") + } + + return (2+g.vecSize())*copy + 1 +} + +func (g *RandomAccessGate) WireListItem(i uint64, copy uint64) uint64 { + if i >= g.vecSize() { + panic("RandomAccessGate.WireListItem called with i >= vec_size") + } + if copy >= g.numCopies { + panic("RandomAccessGate.WireListItem called with copy >= num_copies") + } + + return (2+g.vecSize())*copy + 2 + i +} + +func (g *RandomAccessGate) startExtraConstants() uint64 { + return (2 + g.vecSize()) * g.numCopies +} + +func (g *RandomAccessGate) wireExtraConstant(i uint64) uint64 { + if i >= g.numExtraConstants { + panic("RandomAccessGate.wireExtraConstant called with i >= num_extra_constants") + } + + return g.startExtraConstants() + i +} + +func (g *RandomAccessGate) NumRoutedWires() uint64 { + return g.startExtraConstants() + g.numExtraConstants +} + +func (g *RandomAccessGate) WireBit(i uint64, copy uint64) uint64 { + if i >= g.bits { + panic("RandomAccessGate.WireBit called with i >= bits") + } + if copy >= g.numCopies { + panic("RandomAccessGate.WireBit called with copy >= num_copies") + } + + return g.NumRoutedWires() + copy*g.bits + i +} + +func (g *RandomAccessGate) EvalUnfiltered(p *PlonkChip, vars EvaluationVars) []QuadraticExtension { + two := QuadraticExtension{NewFieldElement(2), NewFieldElement(0)} + constraints := []QuadraticExtension{} + + for copy := uint64(0); copy < g.numCopies; copy++ { + accessIndex := vars.localWires[g.WireAccessIndex(copy)] + listItems := []QuadraticExtension{} + for i := uint64(0); i < g.vecSize(); i++ { + listItems = append(listItems, vars.localWires[g.WireListItem(i, copy)]) + } + claimedElement := vars.localWires[g.WireClaimedElement(copy)] + bits := []QuadraticExtension{} + for i := uint64(0); i < g.bits; i++ { + bits = append(bits, vars.localWires[g.WireBit(i, copy)]) + } + + // Assert that each bit wire value is indeed boolean. + for _, b := range bits { + bSquared := p.qeAPI.MulExtension(b, b) + constraints = append(constraints, p.qeAPI.SubExtension(bSquared, b)) + } + + // Assert that the binary decomposition was correct. + reconstructedIndex := p.qeAPI.ReduceWithPowers(bits, two) + constraints = append(constraints, p.qeAPI.SubExtension(reconstructedIndex, accessIndex)) + + for _, b := range bits { + listItemsTmp := []QuadraticExtension{} + for i := 0; i < len(listItems); i += 2 { + x := listItems[i] + y := listItems[i+1] + + // This is computing `if b { x } else { y }` + // i.e. `bx - (by-y)`. + mul1 := p.qeAPI.MulExtension(b, x) + sub1 := p.qeAPI.SubExtension(mul1, x) + + mul2 := p.qeAPI.MulExtension(b, y) + sub2 := p.qeAPI.SubExtension(mul2, sub1) + + listItemsTmp = append(listItemsTmp, sub2) + } + listItems = listItemsTmp + } + + if len(listItems) != 1 { + panic("listItems(len) != 1") + } + + constraints = append(constraints, p.qeAPI.SubExtension(listItems[0], claimedElement)) + } + + for i := uint64(0); i < g.numExtraConstants; i++ { + constraints = append(constraints, p.qeAPI.SubExtension(vars.localConstants[i], vars.localWires[g.wireExtraConstant(i)])) + } + + return constraints +} diff --git a/plonky2_verifier/random_access_gate_test.go b/plonky2_verifier/random_access_gate_test.go new file mode 100644 index 0000000..d2ef3b9 --- /dev/null +++ b/plonky2_verifier/random_access_gate_test.go @@ -0,0 +1,49 @@ +package plonky2_verifier + +import ( + "errors" + . "gnark-plonky2-verifier/field" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type TestRandomAccessGateCircuit struct{} + +func (circuit *TestRandomAccessGateCircuit) Define(api frontend.API) error { + commonCircuitData := DeserializeCommonCircuitData("./data/step/common_circuit_data.json") + numSelectors := len(commonCircuitData.SelectorsInfo.groups) + + fieldAPI := NewFieldAPI(api) + qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits) + plonkChip := NewPlonkChip(api, qeAPI, commonCircuitData) + + randomAccessGate := RandomAccessGate{bits: 4, numCopies: 4, numExtraConstants: 2} + vars := EvaluationVars{localConstants: localConstants[numSelectors:], localWires: localWires, publicInputsHash: publicInputsHash} + + constraints := randomAccessGate.EvalUnfiltered(plonkChip, vars) + + if len(constraints) != len(randomAccessGateExpectedConstraints) { + return errors.New("constant gate constraints length mismatch") + } + + for i := 0; i < len(constraints); i++ { + qeAPI.AssertIsEqual(constraints[i], randomAccessGateExpectedConstraints[i]) + } + + return nil +} + +func TestRandomAccessGate(t *testing.T) { + assert := test.NewAssert(t) + + testCase := func() { + circuit := TestRandomAccessGateCircuit{} + witness := TestRandomAccessGateCircuit{} + err := test.IsSolved(&circuit, &witness, TEST_CURVE.ScalarField()) + assert.NoError(err) + } + + testCase() +}