diff --git a/benchmark.go b/benchmark.go index d4e92bd..be6c137 100644 --- a/benchmark.go +++ b/benchmark.go @@ -28,7 +28,7 @@ func (circuit *BenchmarkPlonky2VerifierCircuit) Define(api frontend.API) error { fieldAPI := NewFieldAPI(api) qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits) hashAPI := NewHashAPI(fieldAPI) - poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI) + poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI) friChip := NewFriChip(api, fieldAPI, qeAPI, hashAPI, poseidonChip, &commonCircuitData.FriParams) plonkChip := NewPlonkChip(api, qeAPI, commonCircuitData) circuit.verifierChip = NewVerifierChip(api, fieldAPI, qeAPI, poseidonChip, plonkChip, friChip) diff --git a/plonky2_verifier/challenger_test.go b/plonky2_verifier/challenger_test.go index 20a86b0..f6914af 100644 --- a/plonky2_verifier/challenger_test.go +++ b/plonky2_verifier/challenger_test.go @@ -20,38 +20,40 @@ type TestChallengerCircuit struct { } func (circuit *TestChallengerCircuit) Define(api frontend.API) error { - field := field.NewFieldAPI(api) - poseidonChip := NewPoseidonChip(api, field) - challengerChip := NewChallengerChip(api, field, poseidonChip) + fieldAPI := field.NewFieldAPI(api) + degreeBits := 3 + qeAPI := NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits)) + poseidonChip := NewPoseidonChip(api, fieldAPI, qeAPI) + challengerChip := NewChallengerChip(api, fieldAPI, poseidonChip) var circuitDigest [4]F for i := 0; i < len(circuitDigest); i++ { - circuitDigest[i] = field.FromBinary(api.ToBinary(circuit.CircuitDigest[i], 64)).(F) + circuitDigest[i] = fieldAPI.FromBinary(api.ToBinary(circuit.CircuitDigest[i], 64)).(F) } var publicInputs [3]F for i := 0; i < len(publicInputs); i++ { - publicInputs[i] = field.FromBinary(api.ToBinary(circuit.PublicInputs[i], 64)).(F) + publicInputs[i] = fieldAPI.FromBinary(api.ToBinary(circuit.PublicInputs[i], 64)).(F) } var wiresCap [16][4]F for i := 0; i < len(wiresCap); i++ { for j := 0; j < len(wiresCap[0]); j++ { - wiresCap[i][j] = field.FromBinary(api.ToBinary(circuit.WiresCap[i][j], 64)).(F) + wiresCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.WiresCap[i][j], 64)).(F) } } var plonkZsPartialProductsCap [16][4]F for i := 0; i < len(plonkZsPartialProductsCap); i++ { for j := 0; j < len(plonkZsPartialProductsCap[0]); j++ { - plonkZsPartialProductsCap[i][j] = field.FromBinary(api.ToBinary(circuit.PlonkZsPartialProductsCap[i][j], 64)).(F) + plonkZsPartialProductsCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.PlonkZsPartialProductsCap[i][j], 64)).(F) } } var quotientPolysCap [16][4]F for i := 0; i < len(quotientPolysCap); i++ { for j := 0; j < len(quotientPolysCap[0]); j++ { - quotientPolysCap[i][j] = field.FromBinary(api.ToBinary(circuit.QuotientPolysCap[i][j], 64)).(F) + quotientPolysCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.QuotientPolysCap[i][j], 64)).(F) } } @@ -72,7 +74,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error { } for i := 0; i < 4; i++ { - field.AssertIsEqual(publicInputHash[i], expectedPublicInputHash[i]) + fieldAPI.AssertIsEqual(publicInputHash[i], expectedPublicInputHash[i]) } expectedPlonkBetas := [2]F{ @@ -86,8 +88,8 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error { } for i := 0; i < 2; i++ { - field.AssertIsEqual(plonkBetas[i], expectedPlonkBetas[i]) - field.AssertIsEqual(plonkGammas[i], expectedPlonkGammas[i]) + fieldAPI.AssertIsEqual(plonkBetas[i], expectedPlonkBetas[i]) + fieldAPI.AssertIsEqual(plonkGammas[i], expectedPlonkGammas[i]) } challengerChip.ObserveCap(plonkZsPartialProductsCap[:]) @@ -99,7 +101,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error { } for i := 0; i < 2; i++ { - field.AssertIsEqual(plonkAlphas[i], expectedPlonkAlphas[i]) + fieldAPI.AssertIsEqual(plonkAlphas[i], expectedPlonkAlphas[i]) } challengerChip.ObserveCap(quotientPolysCap[:]) @@ -111,7 +113,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error { } for i := 0; i < 2; i++ { - field.AssertIsEqual(plonkZeta[i], expectedPlonkZeta[i]) + fieldAPI.AssertIsEqual(plonkZeta[i], expectedPlonkZeta[i]) } return nil diff --git a/plonky2_verifier/fri_test.go b/plonky2_verifier/fri_test.go index ef15043..0cf501d 100644 --- a/plonky2_verifier/fri_test.go +++ b/plonky2_verifier/fri_test.go @@ -29,7 +29,7 @@ func (circuit *TestFriCircuit) Define(api frontend.API) error { fieldAPI := NewFieldAPI(api) qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits) hashAPI := NewHashAPI(fieldAPI) - poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI) + poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI) friChip := NewFriChip(api, fieldAPI, qeAPI, hashAPI, poseidonChip, &commonCircuitData.FriParams) friChallenges := FriChallenges{ diff --git a/plonky2_verifier/plonk_test.go b/plonky2_verifier/plonk_test.go index d710068..75348c1 100644 --- a/plonky2_verifier/plonk_test.go +++ b/plonky2_verifier/plonk_test.go @@ -23,8 +23,8 @@ func (circuit *TestPlonkCircuit) Define(api frontend.API) error { proofWithPis := DeserializeProofWithPublicInputs(circuit.proofWithPIsFilename) commonCircuitData := DeserializeCommonCircuitData(circuit.commonCircuitDataFilename) - field := NewFieldAPI(api) - qe := NewQuadraticExtensionAPI(field, commonCircuitData.DegreeBits) + fieldAPI := NewFieldAPI(api) + qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits) proofChallenges := ProofChallenges{ PlonkBetas: circuit.plonkBetas, @@ -33,9 +33,9 @@ func (circuit *TestPlonkCircuit) Define(api frontend.API) error { PlonkZeta: circuit.plonkZeta, } - plonkChip := NewPlonkChip(api, qe, commonCircuitData) + plonkChip := NewPlonkChip(api, qeAPI, commonCircuitData) - poseidonChip := poseidon.NewPoseidonChip(api, field) + poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI) publicInputsHash := poseidonChip.HashNoPad(proofWithPis.PublicInputs) plonkChip.Verify(proofChallenges, proofWithPis.Proof.Openings, publicInputsHash) diff --git a/plonky2_verifier/poseidon_gate.go b/plonky2_verifier/poseidon_gate.go index 3980fc9..4e68c0b 100644 --- a/plonky2_verifier/poseidon_gate.go +++ b/plonky2_verifier/poseidon_gate.go @@ -74,6 +74,8 @@ type PoseidonGate struct { func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []QuadraticExtension { constraints := []QuadraticExtension{} + poseidonChip := poseidon.NewPoseidonChip(pc.api, NewFieldAPI(pc.api), pc.qeAPI) + // Assert that `swap` is binary. swap := vars.localWires[p.WireSwap()] notSwap := pc.qeAPI.SubExtension(pc.qeAPI.FieldToQE(ONE_F), swap) @@ -90,7 +92,7 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad } // Compute the possibly-swapped input layer. - state := make([]QuadraticExtension, poseidon.SPONGE_WIDTH) + var state [poseidon.SPONGE_WIDTH]QuadraticExtension for i := uint64(0); i < 4; i++ { deltaI := vars.localWires[p.WireDelta(i)] inputLhs := vars.localWires[p.WireInput(i)] @@ -106,7 +108,7 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad // First set of full rounds. for r := uint64(0); r < poseidon.HALF_N_FULL_ROUNDS; r++ { - // TODO: constantLayerField(state, round_ctr) + state = poseidonChip.ConstantLayerExtension(state, &round_ctr) if r != 0 { for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ { sBoxIn := vars.localWires[p.WireFullSBox0(r, i)] @@ -114,37 +116,37 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad state[i] = sBoxIn } } - // TODO: sboxLayerField(state) - // TODO: state = mdsLayerField(state) + state = poseidonChip.SBoxLayerExtension(state) + state = poseidonChip.MdsLayerExtension(state) round_ctr++ } // Partial rounds. - // TODO: partialFirstConstantLayer(state) - // TODO: state = mdsParitalLayerInit(state) + state = poseidonChip.PartialFirstConstantLayerExtension(state) + state = poseidonChip.MdsPartialLayerInitExtension(state) for r := uint64(0); r < poseidon.N_PARTIAL_ROUNDS-1; r++ { sBoxIn := vars.localWires[p.WirePartialSBox(r)] constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn)) - // TODO: state[0] = sBoxMonomial(sBoxIn) - // TODO: state[0] += NewFieldElement(FAST_PARTIAL_ROUND_CONSTANTS[r]) - // TODO: state = mdsParitalLayerFastField(state, r) + state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn) + state[0] = pc.qeAPI.AddExtension(state[0], pc.qeAPI.FieldToQE(NewFieldElement(poseidon.FAST_PARTIAL_ROUND_CONSTANTS[r]))) + state = poseidonChip.MdsPartialLayerFastExtension(state, int(r)) } sBoxIn := vars.localWires[p.WirePartialSBox(poseidon.N_PARTIAL_ROUNDS-1)] constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn)) - // TODO: state[0] = sBoxMonomial(sBoxIn) - // TODO: state = mdsPartialLayerLastField(state, poseidon.N_PARTIAL_ROUNDS - 1) + state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn) + state = poseidonChip.MdsPartialLayerFastExtension(state, poseidon.N_PARTIAL_ROUNDS-1) round_ctr += poseidon.N_PARTIAL_ROUNDS // Second set of full rounds. for r := uint64(0); r < poseidon.HALF_N_FULL_ROUNDS; r++ { - // TODO: constantLayerField(state, round_ctr) + poseidonChip.ConstantLayerExtension(state, &round_ctr) for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ { sBoxIn := vars.localWires[p.WireFullSBox1(r, i)] constraints = append(constraints, pc.qeAPI.SubExtension(state[i], sBoxIn)) state[i] = sBoxIn } - // TODO: sboxLayerField(state) - // TODO: state = mdsLayerField(state) + state = poseidonChip.MdsLayerExtension(state) + state = poseidonChip.SBoxLayerExtension(state) round_ctr++ } diff --git a/plonky2_verifier/quadratic_extension.go b/plonky2_verifier/quadratic_extension.go deleted file mode 100644 index 80fdc1c..0000000 --- a/plonky2_verifier/quadratic_extension.go +++ /dev/null @@ -1,164 +0,0 @@ -package plonky2_verifier - -import ( - "fmt" - . "gnark-plonky2-verifier/field" - "math/bits" - - "github.com/consensys/gnark/frontend" -) - -type QuadraticExtensionAPI struct { - fieldAPI frontend.API - - W F - DTH_ROOT F - - ONE_QE QuadraticExtension - ZERO_QE QuadraticExtension -} - -func NewQuadraticExtensionAPI(fieldAPI frontend.API, degreeBits uint64) *QuadraticExtensionAPI { - // TODO: Should degreeBits be verified that it fits within the field and that degree is within uint64? - - return &QuadraticExtensionAPI{ - fieldAPI: fieldAPI, - - W: NewFieldElement(7), - DTH_ROOT: NewFieldElement(18446744069414584320), - - ONE_QE: QuadraticExtension{ONE_F, ZERO_F}, - ZERO_QE: QuadraticExtension{ZERO_F, ZERO_F}, - } -} - -func (c *QuadraticExtensionAPI) SquareExtension(a QuadraticExtension) QuadraticExtension { - return c.MulExtension(a, a) -} - -func (c *QuadraticExtensionAPI) MulExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension { - c_0 := c.fieldAPI.Add(c.fieldAPI.Mul(a[0], b[0]).(F), c.fieldAPI.Mul(c.W, a[1], b[1])).(F) - c_1 := c.fieldAPI.Add(c.fieldAPI.Mul(a[0], b[1]).(F), c.fieldAPI.Mul(a[1], b[0])).(F) - return QuadraticExtension{c_0, c_1} -} - -func (c *QuadraticExtensionAPI) AddExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension { - c_0 := c.fieldAPI.Add(a[0], b[0]).(F) - c_1 := c.fieldAPI.Add(a[1], b[1]).(F) - return QuadraticExtension{c_0, c_1} -} - -func (c *QuadraticExtensionAPI) SubExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension { - c_0 := c.fieldAPI.Sub(a[0], b[0]).(F) - c_1 := c.fieldAPI.Sub(a[1], b[1]).(F) - return QuadraticExtension{c_0, c_1} -} - -func (c *QuadraticExtensionAPI) DivExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension { - return c.MulExtension(a, c.InverseExtension(b)) -} - -func (c *QuadraticExtensionAPI) IsZero(a QuadraticExtension) frontend.Variable { - return c.fieldAPI.Mul(c.fieldAPI.IsZero(a[0]), c.fieldAPI.IsZero(a[1])) -} - -// TODO: Instead of calculating the inverse within the circuit, can witness the -// inverse and assert that a_inverse * a = 1. Should reduce # of constraints. -func (c *QuadraticExtensionAPI) InverseExtension(a QuadraticExtension) QuadraticExtension { - // First assert that a doesn't have 0 value coefficients - a0_is_zero := c.fieldAPI.IsZero(a[0]) - a1_is_zero := c.fieldAPI.IsZero(a[1]) - - // assert that a0_is_zero OR a1_is_zero == false - c.fieldAPI.AssertIsEqual(c.fieldAPI.Mul(a0_is_zero, a1_is_zero).(F), ZERO_F) - - a_pow_r_minus_1 := QuadraticExtension{a[0], c.fieldAPI.Mul(a[1], c.DTH_ROOT).(F)} - a_pow_r := c.MulExtension(a_pow_r_minus_1, a) - return c.ScalarMulExtension(a_pow_r_minus_1, c.fieldAPI.Inverse(a_pow_r[0]).(F)) -} - -func (c *QuadraticExtensionAPI) ScalarMulExtension(a QuadraticExtension, scalar F) QuadraticExtension { - return QuadraticExtension{c.fieldAPI.Mul(a[0], scalar).(F), c.fieldAPI.Mul(a[1], scalar).(F)} -} - -func (c *QuadraticExtensionAPI) FieldToQE(a F) QuadraticExtension { - return QuadraticExtension{a, ZERO_F} -} - -// / Exponentiate `base` to the power of a known `exponent`. -func (c *QuadraticExtensionAPI) ExpU64Extension(a QuadraticExtension, exponent uint64) QuadraticExtension { - switch exponent { - case 0: - return c.ONE_QE - case 1: - return a - case 2: - return c.SquareExtension(a) - default: - } - - current := a - product := c.ONE_QE - - for i := 0; i < bits.Len64(exponent); i++ { - if i != 0 { - current = c.SquareExtension(current) - } - - if (exponent >> i & 1) != 0 { - product = c.MulExtension(product, current) - } - } - - return product -} - -func (c *QuadraticExtensionAPI) ReduceWithPowers(terms []QuadraticExtension, scalar QuadraticExtension) QuadraticExtension { - sum := c.ZERO_QE - - for i := len(terms) - 1; i >= 0; i-- { - sum = c.AddExtension( - c.MulExtension( - sum, - scalar, - ), - terms[i], - ) - } - - return sum -} - -func (c *QuadraticExtensionAPI) Select(b0 frontend.Variable, qe0, qe1 QuadraticExtension) QuadraticExtension { - var retQE QuadraticExtension - - for i := 0; i < 2; i++ { - retQE[i] = c.fieldAPI.Select(b0, qe0[i], qe1[i]).(F) - } - - return retQE -} - -func (c *QuadraticExtensionAPI) Lookup2(b0 frontend.Variable, b1 frontend.Variable, qe0, qe1, qe2, qe3 QuadraticExtension) QuadraticExtension { - var retQE QuadraticExtension - - for i := 0; i < 2; i++ { - retQE[i] = c.fieldAPI.Lookup2(b0, b1, qe0[i], qe1[i], qe2[i], qe3[i]).(F) - } - - return retQE -} - -func (c *QuadraticExtensionAPI) AssertIsEqual(a, b QuadraticExtension) { - for i := 0; i < 2; i++ { - c.fieldAPI.AssertIsEqual(a[0], b[0]) - } -} - -func (c *QuadraticExtensionAPI) Println(a QuadraticExtension) { - fmt.Print("Degree 0 coefficient") - c.fieldAPI.Println(a[0]) - - fmt.Print("Degree 1 coefficient") - c.fieldAPI.Println(a[1]) -} diff --git a/plonky2_verifier/quadratic_extension_test.go b/plonky2_verifier/quadratic_extension_test.go index ca047e6..5b61913 100644 --- a/plonky2_verifier/quadratic_extension_test.go +++ b/plonky2_verifier/quadratic_extension_test.go @@ -21,14 +21,14 @@ type TestQuadraticExtensionMulCircuit struct { } func (c *TestQuadraticExtensionMulCircuit) Define(api frontend.API) error { - field := field.NewFieldAPI(api) + fieldAPI := field.NewFieldAPI(api) degreeBits := 3 - c.qeAPI = NewQuadraticExtensionAPI(field, uint64(degreeBits)) + c.qeAPI = NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits)) actualRes := c.qeAPI.MulExtension(c.operand1, c.operand2) - field.AssertIsEqual(actualRes[0], c.expectedResult[0]) - field.AssertIsEqual(actualRes[1], c.expectedResult[1]) + fieldAPI.AssertIsEqual(actualRes[0], c.expectedResult[0]) + fieldAPI.AssertIsEqual(actualRes[1], c.expectedResult[1]) return nil } @@ -55,14 +55,14 @@ type TestQuadraticExtensionDivCircuit struct { } func (c *TestQuadraticExtensionDivCircuit) Define(api frontend.API) error { - field := field.NewFieldAPI(api) + fieldAPI := field.NewFieldAPI(api) degreeBits := 3 - c.qeAPI = NewQuadraticExtensionAPI(field, uint64(degreeBits)) + c.qeAPI = NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits)) actualRes := c.qeAPI.DivExtension(c.operand1, c.operand2) - field.AssertIsEqual(actualRes[0], c.expectedResult[0]) - field.AssertIsEqual(actualRes[1], c.expectedResult[1]) + fieldAPI.AssertIsEqual(actualRes[0], c.expectedResult[0]) + fieldAPI.AssertIsEqual(actualRes[1], c.expectedResult[1]) return nil } diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index 76a83cd..d28e625 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -11,33 +11,35 @@ const N_FULL_ROUNDS_TOTAL = 2 * HALF_N_FULL_ROUNDS const N_PARTIAL_ROUNDS = 22 const N_ROUNDS = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS const MAX_WIDTH = 12 -const WIDTH = 12 const SPONGE_WIDTH = 12 const SPONGE_RATE = 8 -type PoseidonState = [WIDTH]F +type PoseidonState = [SPONGE_WIDTH]F +type PoseidonStateExtension = [SPONGE_WIDTH]QuadraticExtension + type PoseidonChip struct { - api frontend.API `gnark:"-"` - field frontend.API `gnark:"-"` + api frontend.API `gnark:"-"` + fieldAPI frontend.API `gnark:"-"` + qeAPI *QuadraticExtensionAPI `gnark:"-"` } -func NewPoseidonChip(api frontend.API, field frontend.API) *PoseidonChip { - return &PoseidonChip{api: api, field: field} +func NewPoseidonChip(api frontend.API, field frontend.API, qeAPI *QuadraticExtensionAPI) *PoseidonChip { + return &PoseidonChip{api: api, fieldAPI: field} } func (c *PoseidonChip) Poseidon(input PoseidonState) PoseidonState { state := input roundCounter := 0 - state = c.fullRounds(state, &roundCounter) - state = c.partialRounds(state, &roundCounter) - state = c.fullRounds(state, &roundCounter) + state = c.FullRounds(state, &roundCounter) + state = c.PartialRounds(state, &roundCounter) + state = c.FullRounds(state, &roundCounter) return state } func (c *PoseidonChip) HashNToMNoPad(input []F, nbOutputs int) []F { var state PoseidonState - for i := 0; i < WIDTH; i++ { + for i := 0; i < SPONGE_WIDTH; i++ { state[i] = ZERO_F } @@ -69,24 +71,24 @@ func (c *PoseidonChip) HashNoPad(input []F) Hash { return hash } -func (c *PoseidonChip) fullRounds(state PoseidonState, roundCounter *int) PoseidonState { +func (c *PoseidonChip) FullRounds(state PoseidonState, roundCounter *int) PoseidonState { for i := 0; i < HALF_N_FULL_ROUNDS; i++ { - state = c.constantLayer(state, roundCounter) - state = c.sBoxLayer(state) - state = c.mdsLayer(state) + state = c.ConstantLayer(state, roundCounter) + state = c.SBoxLayer(state) + state = c.MdsLayer(state) *roundCounter += 1 } return state } -func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) PoseidonState { - state = c.partialFirstConstantLayer(state) - state = c.mdsPartialLayerInit(state) +func (c *PoseidonChip) PartialRounds(state PoseidonState, roundCounter *int) PoseidonState { + state = c.PartialFirstConstantLayer(state) + state = c.MdsPartialLayerInit(state) for i := 0; i < N_PARTIAL_ROUNDS; i++ { - state[0] = c.sBoxMonomial(state[0]) - state[0] = c.field.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(F) - state = c.mdsPartialLayerFast(state, i) + state[0] = c.SBoxMonomial(state[0]) + state[0] = c.fieldAPI.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(F) + state = c.MdsPartialLayerFast(state, i) } *roundCounter += N_PARTIAL_ROUNDS @@ -94,38 +96,64 @@ func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) Pos return state } -func (c *PoseidonChip) constantLayer(state PoseidonState, roundCounter *int) PoseidonState { +func (c *PoseidonChip) ConstantLayer(state PoseidonState, roundCounter *int) PoseidonState { + for i := 0; i < 12; i++ { + if i < SPONGE_WIDTH { + roundConstant := NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]) + state[i] = c.fieldAPI.Add(state[i], roundConstant).(F) + } + } + return state +} + +func (c *PoseidonChip) ConstantLayerExtension(state PoseidonStateExtension, roundCounter *int) PoseidonStateExtension { for i := 0; i < 12; i++ { - if i < WIDTH { - roundConstant := NewFieldElement(ALL_ROUND_CONSTANTS[i+WIDTH*(*roundCounter)]) - state[i] = c.field.Add(state[i], roundConstant).(F) + if i < SPONGE_WIDTH { + roundConstant := c.qeAPI.FieldToQE(NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])) + state[i] = c.qeAPI.AddExtension(state[i], roundConstant) } } return state } -func (c *PoseidonChip) sBoxLayer(state PoseidonState) PoseidonState { +func (c *PoseidonChip) SBoxMonomial(x F) F { + x2 := c.fieldAPI.Mul(x, x) + x4 := c.fieldAPI.Mul(x2, x2) + x3 := c.fieldAPI.Mul(x2, x) + return c.fieldAPI.Mul(x3, x4).(F) +} + +func (c *PoseidonChip) SBoxMonomialExtension(x QuadraticExtension) QuadraticExtension { + x2 := c.qeAPI.MulExtension(x, x) + x4 := c.qeAPI.MulExtension(x2, x2) + x3 := c.qeAPI.MulExtension(x2, x) + return c.qeAPI.MulExtension(x3, x4) +} + +func (c *PoseidonChip) SBoxLayer(state PoseidonState) PoseidonState { for i := 0; i < 12; i++ { - if i < WIDTH { - state[i] = c.sBoxMonomial(state[i]) + if i < SPONGE_WIDTH { + state[i] = c.SBoxMonomial(state[i]) } } return state } -func (c *PoseidonChip) sBoxMonomial(x F) F { - x2 := c.field.Mul(x, x) - x4 := c.field.Mul(x2, x2) - x3 := c.field.Mul(x2, x) - return c.field.Mul(x3, x4).(F) +func (c *PoseidonChip) SBoxLayerExtension(state PoseidonStateExtension) PoseidonStateExtension { + for i := 0; i < 12; i++ { + if i < SPONGE_WIDTH { + state[i] = c.SBoxMonomialExtension(state[i]) + } + } + return state } -func (c *PoseidonChip) mdsRowShf(r int, v [WIDTH]frontend.Variable) frontend.Variable { +func (c *PoseidonChip) MdsRowShf(r int, v [SPONGE_WIDTH]frontend.Variable) frontend.Variable { res := frontend.Variable(0) for i := 0; i < 12; i++ { - if i < WIDTH { - res1 := c.api.Mul(v[(i+r)%WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i])) + if i < SPONGE_WIDTH { + res1 := c.api.Mul(v[(i+r)%SPONGE_WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i])) res = c.api.Add(res, res1) } } @@ -134,38 +162,76 @@ func (c *PoseidonChip) mdsRowShf(r int, v [WIDTH]frontend.Variable) frontend.Var return res } -func (c *PoseidonChip) mdsLayer(state_ PoseidonState) PoseidonState { +func (c *PoseidonChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]QuadraticExtension) QuadraticExtension { + res := c.qeAPI.FieldToQE(NewFieldElement(0)) + + for i := 0; i < 12; i++ { + if i < SPONGE_WIDTH { + matrixVal := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_CIRC[i])) + res1 := c.qeAPI.MulExtension(v[(i+r)%SPONGE_WIDTH], matrixVal) + res = c.qeAPI.AddExtension(res, res1) + } + } + + matrixVal := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_DIAG[r])) + res = c.qeAPI.AddExtension(res, c.qeAPI.MulExtension(v[r], matrixVal)) + return res +} + +func (c *PoseidonChip) MdsLayer(state_ PoseidonState) PoseidonState { var result PoseidonState - for i := 0; i < WIDTH; i++ { + for i := 0; i < SPONGE_WIDTH; i++ { result[i] = NewFieldElement(0) } - var state [WIDTH]frontend.Variable - for i := 0; i < WIDTH; i++ { - state[i] = c.api.FromBinary(c.field.ToBinary(state_[i])...) + var state [SPONGE_WIDTH]frontend.Variable + for i := 0; i < SPONGE_WIDTH; i++ { + state[i] = c.api.FromBinary(c.fieldAPI.ToBinary(state_[i])...) } for r := 0; r < 12; r++ { - if r < WIDTH { - sum := c.mdsRowShf(r, state) + if r < SPONGE_WIDTH { + sum := c.MdsRowShf(r, state) bits := c.api.ToBinary(sum) - result[r] = c.field.FromBinary(bits).(F) + result[r] = c.fieldAPI.FromBinary(bits).(F) + } + } + + return result +} + +func (c *PoseidonChip) MdsLayerExtension(state_ PoseidonStateExtension) PoseidonStateExtension { + var result PoseidonStateExtension + + for r := 0; r < 12; r++ { + if r < SPONGE_WIDTH { + sum := c.MdsRowShfExtension(r, state_) + result[r] = sum } } return result } -func (c *PoseidonChip) partialFirstConstantLayer(state PoseidonState) PoseidonState { +func (c *PoseidonChip) PartialFirstConstantLayer(state PoseidonState) PoseidonState { for i := 0; i < 12; i++ { - if i < WIDTH { - state[i] = c.field.Add(state[i], NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(F) + if i < SPONGE_WIDTH { + state[i] = c.fieldAPI.Add(state[i], NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(F) } } return state } -func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState { +func (c *PoseidonChip) PartialFirstConstantLayerExtension(state PoseidonStateExtension) PoseidonStateExtension { + for i := 0; i < 12; i++ { + if i < SPONGE_WIDTH { + state[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]))) + } + } + return state +} + +func (c *PoseidonChip) MdsPartialLayerInit(state PoseidonState) PoseidonState { var result PoseidonState for i := 0; i < 12; i++ { result[i] = NewFieldElement(0) @@ -174,11 +240,11 @@ func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState { result[0] = state[0] for r := 1; r < 12; r++ { - if r < WIDTH { + if r < SPONGE_WIDTH { for d := 1; d < 12; d++ { - if d < WIDTH { + if d < SPONGE_WIDTH { t := NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]) - result[d] = c.field.Add(result[d], c.field.Mul(state[r], t)).(F) + result[d] = c.fieldAPI.Add(result[d], c.fieldAPI.Mul(state[r], t)).(F) } } } @@ -187,32 +253,77 @@ func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState { return result } -func (c *PoseidonChip) mdsPartialLayerFast(state PoseidonState, r int) PoseidonState { +func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension) PoseidonStateExtension { + var result PoseidonStateExtension + for i := 0; i < 12; i++ { + result[i] = c.qeAPI.FieldToQE(NewFieldElement(0)) + } + + result[0] = state[0] + + for r := 1; r < 12; r++ { + if r < SPONGE_WIDTH { + for d := 1; d < 12; d++ { + if d < SPONGE_WIDTH { + t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])) + result[d] = c.qeAPI.AddExtension(result[d], c.qeAPI.MulExtension(state[r], t)) + } + } + } + } + + return result +} + +func (c *PoseidonChip) MdsPartialLayerFast(state PoseidonState, r int) PoseidonState { dSum := frontend.Variable(0) for i := 1; i < 12; i++ { - if i < WIDTH { + if i < SPONGE_WIDTH { t := frontend.Variable(FAST_PARTIAL_ROUND_W_HATS[r][i-1]) - si := c.api.FromBinary(c.field.ToBinary(state[i])...) + si := c.api.FromBinary(c.fieldAPI.ToBinary(state[i])...) dSum = c.api.Add(dSum, c.api.Mul(si, t)) } } - s0 := c.api.FromBinary(c.field.ToBinary(state[0])...) + s0 := c.api.FromBinary(c.fieldAPI.ToBinary(state[0])...) mds0to0 := frontend.Variable(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]) dSum = c.api.Add(dSum, c.api.Mul(s0, mds0to0)) - d := c.field.FromBinary(c.api.ToBinary(dSum)) + d := c.fieldAPI.FromBinary(c.api.ToBinary(dSum)) var result PoseidonState - for i := 0; i < WIDTH; i++ { + for i := 0; i < SPONGE_WIDTH; i++ { result[i] = NewFieldElement(0) } result[0] = d.(F) for i := 1; i < 12; i++ { - if i < WIDTH { + if i < SPONGE_WIDTH { t := NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1]) - result[i] = c.field.Add(state[i], c.field.Mul(state[0], t)).(F) + result[i] = c.fieldAPI.Add(state[i], c.fieldAPI.Mul(state[0], t)).(F) + } + } + + return result +} + +func (c *PoseidonChip) MdsPartialLayerFastExtension(state PoseidonStateExtension, r int) PoseidonStateExtension { + s0 := state[0] + mds0to0 := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0])) + d := c.qeAPI.AddExtension(s0, mds0to0) + for i := 1; i < 12; i++ { + if i < SPONGE_WIDTH { + t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_W_HATS[r][i-1])) + d = c.qeAPI.AddExtension(d, c.qeAPI.MulExtension(state[i], t)) + } + } + + var result PoseidonStateExtension + result[0] = d + for i := 1; i < 12; i++ { + if i < SPONGE_WIDTH { + t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1])) + result[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.MulExtension(state[0], t)) } } diff --git a/poseidon/public_inputs_hash_test.go b/poseidon/public_inputs_hash_test.go index 7148464..d681ede 100644 --- a/poseidon/public_inputs_hash_test.go +++ b/poseidon/public_inputs_hash_test.go @@ -18,22 +18,22 @@ type TestPublicInputsHashCircuit struct { } func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error { - field := NewFieldAPI(api) + fieldAPI := NewFieldAPI(api) // BN254 -> Binary(64) -> F var input [3]F for i := 0; i < 3; i++ { - input[i] = field.FromBinary(api.ToBinary(circuit.In[i], 64)).(F) + input[i] = fieldAPI.FromBinary(api.ToBinary(circuit.In[i], 64)).(F) } - poseidonChip := &PoseidonChip{api: api, field: field} + poseidonChip := &PoseidonChip{api: api, fieldAPI: fieldAPI} output := poseidonChip.HashNoPad(input[:]) // Check that output is correct for i := 0; i < 4; i++ { - field.AssertIsEqual( + fieldAPI.AssertIsEqual( output[i], - field.FromBinary(api.ToBinary(circuit.Out[i])).(F), + fieldAPI.FromBinary(api.ToBinary(circuit.Out[i])).(F), ) }