Browse Source

Poseidon extension function versions, and finished PoseidonGate

main
Nicholas Ward 2 years ago
parent
commit
f1dc02d30f
9 changed files with 219 additions and 268 deletions
  1. +1
    -1
      benchmark.go
  2. +15
    -13
      plonky2_verifier/challenger_test.go
  3. +1
    -1
      plonky2_verifier/fri_test.go
  4. +4
    -4
      plonky2_verifier/plonk_test.go
  5. +16
    -14
      plonky2_verifier/poseidon_gate.go
  6. +0
    -164
      plonky2_verifier/quadratic_extension.go
  7. +8
    -8
      plonky2_verifier/quadratic_extension_test.go
  8. +169
    -58
      poseidon/poseidon.go
  9. +5
    -5
      poseidon/public_inputs_hash_test.go

+ 1
- 1
benchmark.go

@ -28,7 +28,7 @@ func (circuit *BenchmarkPlonky2VerifierCircuit) Define(api frontend.API) error {
fieldAPI := NewFieldAPI(api)
qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits)
hashAPI := NewHashAPI(fieldAPI)
poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI)
poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI)
friChip := NewFriChip(api, fieldAPI, qeAPI, hashAPI, poseidonChip, &commonCircuitData.FriParams)
plonkChip := NewPlonkChip(api, qeAPI, commonCircuitData)
circuit.verifierChip = NewVerifierChip(api, fieldAPI, qeAPI, poseidonChip, plonkChip, friChip)

+ 15
- 13
plonky2_verifier/challenger_test.go

@ -20,38 +20,40 @@ type TestChallengerCircuit struct {
}
func (circuit *TestChallengerCircuit) Define(api frontend.API) error {
field := field.NewFieldAPI(api)
poseidonChip := NewPoseidonChip(api, field)
challengerChip := NewChallengerChip(api, field, poseidonChip)
fieldAPI := field.NewFieldAPI(api)
degreeBits := 3
qeAPI := NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits))
poseidonChip := NewPoseidonChip(api, fieldAPI, qeAPI)
challengerChip := NewChallengerChip(api, fieldAPI, poseidonChip)
var circuitDigest [4]F
for i := 0; i < len(circuitDigest); i++ {
circuitDigest[i] = field.FromBinary(api.ToBinary(circuit.CircuitDigest[i], 64)).(F)
circuitDigest[i] = fieldAPI.FromBinary(api.ToBinary(circuit.CircuitDigest[i], 64)).(F)
}
var publicInputs [3]F
for i := 0; i < len(publicInputs); i++ {
publicInputs[i] = field.FromBinary(api.ToBinary(circuit.PublicInputs[i], 64)).(F)
publicInputs[i] = fieldAPI.FromBinary(api.ToBinary(circuit.PublicInputs[i], 64)).(F)
}
var wiresCap [16][4]F
for i := 0; i < len(wiresCap); i++ {
for j := 0; j < len(wiresCap[0]); j++ {
wiresCap[i][j] = field.FromBinary(api.ToBinary(circuit.WiresCap[i][j], 64)).(F)
wiresCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.WiresCap[i][j], 64)).(F)
}
}
var plonkZsPartialProductsCap [16][4]F
for i := 0; i < len(plonkZsPartialProductsCap); i++ {
for j := 0; j < len(plonkZsPartialProductsCap[0]); j++ {
plonkZsPartialProductsCap[i][j] = field.FromBinary(api.ToBinary(circuit.PlonkZsPartialProductsCap[i][j], 64)).(F)
plonkZsPartialProductsCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.PlonkZsPartialProductsCap[i][j], 64)).(F)
}
}
var quotientPolysCap [16][4]F
for i := 0; i < len(quotientPolysCap); i++ {
for j := 0; j < len(quotientPolysCap[0]); j++ {
quotientPolysCap[i][j] = field.FromBinary(api.ToBinary(circuit.QuotientPolysCap[i][j], 64)).(F)
quotientPolysCap[i][j] = fieldAPI.FromBinary(api.ToBinary(circuit.QuotientPolysCap[i][j], 64)).(F)
}
}
@ -72,7 +74,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error {
}
for i := 0; i < 4; i++ {
field.AssertIsEqual(publicInputHash[i], expectedPublicInputHash[i])
fieldAPI.AssertIsEqual(publicInputHash[i], expectedPublicInputHash[i])
}
expectedPlonkBetas := [2]F{
@ -86,8 +88,8 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error {
}
for i := 0; i < 2; i++ {
field.AssertIsEqual(plonkBetas[i], expectedPlonkBetas[i])
field.AssertIsEqual(plonkGammas[i], expectedPlonkGammas[i])
fieldAPI.AssertIsEqual(plonkBetas[i], expectedPlonkBetas[i])
fieldAPI.AssertIsEqual(plonkGammas[i], expectedPlonkGammas[i])
}
challengerChip.ObserveCap(plonkZsPartialProductsCap[:])
@ -99,7 +101,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error {
}
for i := 0; i < 2; i++ {
field.AssertIsEqual(plonkAlphas[i], expectedPlonkAlphas[i])
fieldAPI.AssertIsEqual(plonkAlphas[i], expectedPlonkAlphas[i])
}
challengerChip.ObserveCap(quotientPolysCap[:])
@ -111,7 +113,7 @@ func (circuit *TestChallengerCircuit) Define(api frontend.API) error {
}
for i := 0; i < 2; i++ {
field.AssertIsEqual(plonkZeta[i], expectedPlonkZeta[i])
fieldAPI.AssertIsEqual(plonkZeta[i], expectedPlonkZeta[i])
}
return nil

+ 1
- 1
plonky2_verifier/fri_test.go

@ -29,7 +29,7 @@ func (circuit *TestFriCircuit) Define(api frontend.API) error {
fieldAPI := NewFieldAPI(api)
qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits)
hashAPI := NewHashAPI(fieldAPI)
poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI)
poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI)
friChip := NewFriChip(api, fieldAPI, qeAPI, hashAPI, poseidonChip, &commonCircuitData.FriParams)
friChallenges := FriChallenges{

+ 4
- 4
plonky2_verifier/plonk_test.go

@ -23,8 +23,8 @@ func (circuit *TestPlonkCircuit) Define(api frontend.API) error {
proofWithPis := DeserializeProofWithPublicInputs(circuit.proofWithPIsFilename)
commonCircuitData := DeserializeCommonCircuitData(circuit.commonCircuitDataFilename)
field := NewFieldAPI(api)
qe := NewQuadraticExtensionAPI(field, commonCircuitData.DegreeBits)
fieldAPI := NewFieldAPI(api)
qeAPI := NewQuadraticExtensionAPI(fieldAPI, commonCircuitData.DegreeBits)
proofChallenges := ProofChallenges{
PlonkBetas: circuit.plonkBetas,
@ -33,9 +33,9 @@ func (circuit *TestPlonkCircuit) Define(api frontend.API) error {
PlonkZeta: circuit.plonkZeta,
}
plonkChip := NewPlonkChip(api, qe, commonCircuitData)
plonkChip := NewPlonkChip(api, qeAPI, commonCircuitData)
poseidonChip := poseidon.NewPoseidonChip(api, field)
poseidonChip := poseidon.NewPoseidonChip(api, fieldAPI, qeAPI)
publicInputsHash := poseidonChip.HashNoPad(proofWithPis.PublicInputs)
plonkChip.Verify(proofChallenges, proofWithPis.Proof.Openings, publicInputsHash)

+ 16
- 14
plonky2_verifier/poseidon_gate.go

@ -74,6 +74,8 @@ type PoseidonGate struct {
func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []QuadraticExtension {
constraints := []QuadraticExtension{}
poseidonChip := poseidon.NewPoseidonChip(pc.api, NewFieldAPI(pc.api), pc.qeAPI)
// Assert that `swap` is binary.
swap := vars.localWires[p.WireSwap()]
notSwap := pc.qeAPI.SubExtension(pc.qeAPI.FieldToQE(ONE_F), swap)
@ -90,7 +92,7 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad
}
// Compute the possibly-swapped input layer.
state := make([]QuadraticExtension, poseidon.SPONGE_WIDTH)
var state [poseidon.SPONGE_WIDTH]QuadraticExtension
for i := uint64(0); i < 4; i++ {
deltaI := vars.localWires[p.WireDelta(i)]
inputLhs := vars.localWires[p.WireInput(i)]
@ -106,7 +108,7 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad
// First set of full rounds.
for r := uint64(0); r < poseidon.HALF_N_FULL_ROUNDS; r++ {
// TODO: constantLayerField(state, round_ctr)
state = poseidonChip.ConstantLayerExtension(state, &round_ctr)
if r != 0 {
for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ {
sBoxIn := vars.localWires[p.WireFullSBox0(r, i)]
@ -114,37 +116,37 @@ func (p *PoseidonGate) EvalUnfiltered(pc *PlonkChip, vars EvaluationVars) []Quad
state[i] = sBoxIn
}
}
// TODO: sboxLayerField(state)
// TODO: state = mdsLayerField(state)
state = poseidonChip.SBoxLayerExtension(state)
state = poseidonChip.MdsLayerExtension(state)
round_ctr++
}
// Partial rounds.
// TODO: partialFirstConstantLayer(state)
// TODO: state = mdsParitalLayerInit(state)
state = poseidonChip.PartialFirstConstantLayerExtension(state)
state = poseidonChip.MdsPartialLayerInitExtension(state)
for r := uint64(0); r < poseidon.N_PARTIAL_ROUNDS-1; r++ {
sBoxIn := vars.localWires[p.WirePartialSBox(r)]
constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn))
// TODO: state[0] = sBoxMonomial(sBoxIn)
// TODO: state[0] += NewFieldElement(FAST_PARTIAL_ROUND_CONSTANTS[r])
// TODO: state = mdsParitalLayerFastField(state, r)
state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn)
state[0] = pc.qeAPI.AddExtension(state[0], pc.qeAPI.FieldToQE(NewFieldElement(poseidon.FAST_PARTIAL_ROUND_CONSTANTS[r])))
state = poseidonChip.MdsPartialLayerFastExtension(state, int(r))
}
sBoxIn := vars.localWires[p.WirePartialSBox(poseidon.N_PARTIAL_ROUNDS-1)]
constraints = append(constraints, pc.qeAPI.SubExtension(state[0], sBoxIn))
// TODO: state[0] = sBoxMonomial(sBoxIn)
// TODO: state = mdsPartialLayerLastField(state, poseidon.N_PARTIAL_ROUNDS - 1)
state[0] = poseidonChip.SBoxMonomialExtension(sBoxIn)
state = poseidonChip.MdsPartialLayerFastExtension(state, poseidon.N_PARTIAL_ROUNDS-1)
round_ctr += poseidon.N_PARTIAL_ROUNDS
// Second set of full rounds.
for r := uint64(0); r < poseidon.HALF_N_FULL_ROUNDS; r++ {
// TODO: constantLayerField(state, round_ctr)
poseidonChip.ConstantLayerExtension(state, &round_ctr)
for i := uint64(0); i < poseidon.SPONGE_WIDTH; i++ {
sBoxIn := vars.localWires[p.WireFullSBox1(r, i)]
constraints = append(constraints, pc.qeAPI.SubExtension(state[i], sBoxIn))
state[i] = sBoxIn
}
// TODO: sboxLayerField(state)
// TODO: state = mdsLayerField(state)
state = poseidonChip.MdsLayerExtension(state)
state = poseidonChip.SBoxLayerExtension(state)
round_ctr++
}

+ 0
- 164
plonky2_verifier/quadratic_extension.go

@ -1,164 +0,0 @@
package plonky2_verifier
import (
"fmt"
. "gnark-plonky2-verifier/field"
"math/bits"
"github.com/consensys/gnark/frontend"
)
type QuadraticExtensionAPI struct {
fieldAPI frontend.API
W F
DTH_ROOT F
ONE_QE QuadraticExtension
ZERO_QE QuadraticExtension
}
func NewQuadraticExtensionAPI(fieldAPI frontend.API, degreeBits uint64) *QuadraticExtensionAPI {
// TODO: Should degreeBits be verified that it fits within the field and that degree is within uint64?
return &QuadraticExtensionAPI{
fieldAPI: fieldAPI,
W: NewFieldElement(7),
DTH_ROOT: NewFieldElement(18446744069414584320),
ONE_QE: QuadraticExtension{ONE_F, ZERO_F},
ZERO_QE: QuadraticExtension{ZERO_F, ZERO_F},
}
}
func (c *QuadraticExtensionAPI) SquareExtension(a QuadraticExtension) QuadraticExtension {
return c.MulExtension(a, a)
}
func (c *QuadraticExtensionAPI) MulExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension {
c_0 := c.fieldAPI.Add(c.fieldAPI.Mul(a[0], b[0]).(F), c.fieldAPI.Mul(c.W, a[1], b[1])).(F)
c_1 := c.fieldAPI.Add(c.fieldAPI.Mul(a[0], b[1]).(F), c.fieldAPI.Mul(a[1], b[0])).(F)
return QuadraticExtension{c_0, c_1}
}
func (c *QuadraticExtensionAPI) AddExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension {
c_0 := c.fieldAPI.Add(a[0], b[0]).(F)
c_1 := c.fieldAPI.Add(a[1], b[1]).(F)
return QuadraticExtension{c_0, c_1}
}
func (c *QuadraticExtensionAPI) SubExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension {
c_0 := c.fieldAPI.Sub(a[0], b[0]).(F)
c_1 := c.fieldAPI.Sub(a[1], b[1]).(F)
return QuadraticExtension{c_0, c_1}
}
func (c *QuadraticExtensionAPI) DivExtension(a QuadraticExtension, b QuadraticExtension) QuadraticExtension {
return c.MulExtension(a, c.InverseExtension(b))
}
func (c *QuadraticExtensionAPI) IsZero(a QuadraticExtension) frontend.Variable {
return c.fieldAPI.Mul(c.fieldAPI.IsZero(a[0]), c.fieldAPI.IsZero(a[1]))
}
// TODO: Instead of calculating the inverse within the circuit, can witness the
// inverse and assert that a_inverse * a = 1. Should reduce # of constraints.
func (c *QuadraticExtensionAPI) InverseExtension(a QuadraticExtension) QuadraticExtension {
// First assert that a doesn't have 0 value coefficients
a0_is_zero := c.fieldAPI.IsZero(a[0])
a1_is_zero := c.fieldAPI.IsZero(a[1])
// assert that a0_is_zero OR a1_is_zero == false
c.fieldAPI.AssertIsEqual(c.fieldAPI.Mul(a0_is_zero, a1_is_zero).(F), ZERO_F)
a_pow_r_minus_1 := QuadraticExtension{a[0], c.fieldAPI.Mul(a[1], c.DTH_ROOT).(F)}
a_pow_r := c.MulExtension(a_pow_r_minus_1, a)
return c.ScalarMulExtension(a_pow_r_minus_1, c.fieldAPI.Inverse(a_pow_r[0]).(F))
}
func (c *QuadraticExtensionAPI) ScalarMulExtension(a QuadraticExtension, scalar F) QuadraticExtension {
return QuadraticExtension{c.fieldAPI.Mul(a[0], scalar).(F), c.fieldAPI.Mul(a[1], scalar).(F)}
}
func (c *QuadraticExtensionAPI) FieldToQE(a F) QuadraticExtension {
return QuadraticExtension{a, ZERO_F}
}
// / Exponentiate `base` to the power of a known `exponent`.
func (c *QuadraticExtensionAPI) ExpU64Extension(a QuadraticExtension, exponent uint64) QuadraticExtension {
switch exponent {
case 0:
return c.ONE_QE
case 1:
return a
case 2:
return c.SquareExtension(a)
default:
}
current := a
product := c.ONE_QE
for i := 0; i < bits.Len64(exponent); i++ {
if i != 0 {
current = c.SquareExtension(current)
}
if (exponent >> i & 1) != 0 {
product = c.MulExtension(product, current)
}
}
return product
}
func (c *QuadraticExtensionAPI) ReduceWithPowers(terms []QuadraticExtension, scalar QuadraticExtension) QuadraticExtension {
sum := c.ZERO_QE
for i := len(terms) - 1; i >= 0; i-- {
sum = c.AddExtension(
c.MulExtension(
sum,
scalar,
),
terms[i],
)
}
return sum
}
func (c *QuadraticExtensionAPI) Select(b0 frontend.Variable, qe0, qe1 QuadraticExtension) QuadraticExtension {
var retQE QuadraticExtension
for i := 0; i < 2; i++ {
retQE[i] = c.fieldAPI.Select(b0, qe0[i], qe1[i]).(F)
}
return retQE
}
func (c *QuadraticExtensionAPI) Lookup2(b0 frontend.Variable, b1 frontend.Variable, qe0, qe1, qe2, qe3 QuadraticExtension) QuadraticExtension {
var retQE QuadraticExtension
for i := 0; i < 2; i++ {
retQE[i] = c.fieldAPI.Lookup2(b0, b1, qe0[i], qe1[i], qe2[i], qe3[i]).(F)
}
return retQE
}
func (c *QuadraticExtensionAPI) AssertIsEqual(a, b QuadraticExtension) {
for i := 0; i < 2; i++ {
c.fieldAPI.AssertIsEqual(a[0], b[0])
}
}
func (c *QuadraticExtensionAPI) Println(a QuadraticExtension) {
fmt.Print("Degree 0 coefficient")
c.fieldAPI.Println(a[0])
fmt.Print("Degree 1 coefficient")
c.fieldAPI.Println(a[1])
}

+ 8
- 8
plonky2_verifier/quadratic_extension_test.go

@ -21,14 +21,14 @@ type TestQuadraticExtensionMulCircuit struct {
}
func (c *TestQuadraticExtensionMulCircuit) Define(api frontend.API) error {
field := field.NewFieldAPI(api)
fieldAPI := field.NewFieldAPI(api)
degreeBits := 3
c.qeAPI = NewQuadraticExtensionAPI(field, uint64(degreeBits))
c.qeAPI = NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits))
actualRes := c.qeAPI.MulExtension(c.operand1, c.operand2)
field.AssertIsEqual(actualRes[0], c.expectedResult[0])
field.AssertIsEqual(actualRes[1], c.expectedResult[1])
fieldAPI.AssertIsEqual(actualRes[0], c.expectedResult[0])
fieldAPI.AssertIsEqual(actualRes[1], c.expectedResult[1])
return nil
}
@ -55,14 +55,14 @@ type TestQuadraticExtensionDivCircuit struct {
}
func (c *TestQuadraticExtensionDivCircuit) Define(api frontend.API) error {
field := field.NewFieldAPI(api)
fieldAPI := field.NewFieldAPI(api)
degreeBits := 3
c.qeAPI = NewQuadraticExtensionAPI(field, uint64(degreeBits))
c.qeAPI = NewQuadraticExtensionAPI(fieldAPI, uint64(degreeBits))
actualRes := c.qeAPI.DivExtension(c.operand1, c.operand2)
field.AssertIsEqual(actualRes[0], c.expectedResult[0])
field.AssertIsEqual(actualRes[1], c.expectedResult[1])
fieldAPI.AssertIsEqual(actualRes[0], c.expectedResult[0])
fieldAPI.AssertIsEqual(actualRes[1], c.expectedResult[1])
return nil
}

+ 169
- 58
poseidon/poseidon.go

@ -11,33 +11,35 @@ const N_FULL_ROUNDS_TOTAL = 2 * HALF_N_FULL_ROUNDS
const N_PARTIAL_ROUNDS = 22
const N_ROUNDS = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS
const MAX_WIDTH = 12
const WIDTH = 12
const SPONGE_WIDTH = 12
const SPONGE_RATE = 8
type PoseidonState = [WIDTH]F
type PoseidonState = [SPONGE_WIDTH]F
type PoseidonStateExtension = [SPONGE_WIDTH]QuadraticExtension
type PoseidonChip struct {
api frontend.API `gnark:"-"`
field frontend.API `gnark:"-"`
api frontend.API `gnark:"-"`
fieldAPI frontend.API `gnark:"-"`
qeAPI *QuadraticExtensionAPI `gnark:"-"`
}
func NewPoseidonChip(api frontend.API, field frontend.API) *PoseidonChip {
return &PoseidonChip{api: api, field: field}
func NewPoseidonChip(api frontend.API, field frontend.API, qeAPI *QuadraticExtensionAPI) *PoseidonChip {
return &PoseidonChip{api: api, fieldAPI: field}
}
func (c *PoseidonChip) Poseidon(input PoseidonState) PoseidonState {
state := input
roundCounter := 0
state = c.fullRounds(state, &roundCounter)
state = c.partialRounds(state, &roundCounter)
state = c.fullRounds(state, &roundCounter)
state = c.FullRounds(state, &roundCounter)
state = c.PartialRounds(state, &roundCounter)
state = c.FullRounds(state, &roundCounter)
return state
}
func (c *PoseidonChip) HashNToMNoPad(input []F, nbOutputs int) []F {
var state PoseidonState
for i := 0; i < WIDTH; i++ {
for i := 0; i < SPONGE_WIDTH; i++ {
state[i] = ZERO_F
}
@ -69,24 +71,24 @@ func (c *PoseidonChip) HashNoPad(input []F) Hash {
return hash
}
func (c *PoseidonChip) fullRounds(state PoseidonState, roundCounter *int) PoseidonState {
func (c *PoseidonChip) FullRounds(state PoseidonState, roundCounter *int) PoseidonState {
for i := 0; i < HALF_N_FULL_ROUNDS; i++ {
state = c.constantLayer(state, roundCounter)
state = c.sBoxLayer(state)
state = c.mdsLayer(state)
state = c.ConstantLayer(state, roundCounter)
state = c.SBoxLayer(state)
state = c.MdsLayer(state)
*roundCounter += 1
}
return state
}
func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) PoseidonState {
state = c.partialFirstConstantLayer(state)
state = c.mdsPartialLayerInit(state)
func (c *PoseidonChip) PartialRounds(state PoseidonState, roundCounter *int) PoseidonState {
state = c.PartialFirstConstantLayer(state)
state = c.MdsPartialLayerInit(state)
for i := 0; i < N_PARTIAL_ROUNDS; i++ {
state[0] = c.sBoxMonomial(state[0])
state[0] = c.field.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(F)
state = c.mdsPartialLayerFast(state, i)
state[0] = c.SBoxMonomial(state[0])
state[0] = c.fieldAPI.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(F)
state = c.MdsPartialLayerFast(state, i)
}
*roundCounter += N_PARTIAL_ROUNDS
@ -94,38 +96,64 @@ func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) Pos
return state
}
func (c *PoseidonChip) constantLayer(state PoseidonState, roundCounter *int) PoseidonState {
func (c *PoseidonChip) ConstantLayer(state PoseidonState, roundCounter *int) PoseidonState {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
roundConstant := NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])
state[i] = c.fieldAPI.Add(state[i], roundConstant).(F)
}
}
return state
}
func (c *PoseidonChip) ConstantLayerExtension(state PoseidonStateExtension, roundCounter *int) PoseidonStateExtension {
for i := 0; i < 12; i++ {
if i < WIDTH {
roundConstant := NewFieldElement(ALL_ROUND_CONSTANTS[i+WIDTH*(*roundCounter)])
state[i] = c.field.Add(state[i], roundConstant).(F)
if i < SPONGE_WIDTH {
roundConstant := c.qeAPI.FieldToQE(NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]))
state[i] = c.qeAPI.AddExtension(state[i], roundConstant)
}
}
return state
}
func (c *PoseidonChip) sBoxLayer(state PoseidonState) PoseidonState {
func (c *PoseidonChip) SBoxMonomial(x F) F {
x2 := c.fieldAPI.Mul(x, x)
x4 := c.fieldAPI.Mul(x2, x2)
x3 := c.fieldAPI.Mul(x2, x)
return c.fieldAPI.Mul(x3, x4).(F)
}
func (c *PoseidonChip) SBoxMonomialExtension(x QuadraticExtension) QuadraticExtension {
x2 := c.qeAPI.MulExtension(x, x)
x4 := c.qeAPI.MulExtension(x2, x2)
x3 := c.qeAPI.MulExtension(x2, x)
return c.qeAPI.MulExtension(x3, x4)
}
func (c *PoseidonChip) SBoxLayer(state PoseidonState) PoseidonState {
for i := 0; i < 12; i++ {
if i < WIDTH {
state[i] = c.sBoxMonomial(state[i])
if i < SPONGE_WIDTH {
state[i] = c.SBoxMonomial(state[i])
}
}
return state
}
func (c *PoseidonChip) sBoxMonomial(x F) F {
x2 := c.field.Mul(x, x)
x4 := c.field.Mul(x2, x2)
x3 := c.field.Mul(x2, x)
return c.field.Mul(x3, x4).(F)
func (c *PoseidonChip) SBoxLayerExtension(state PoseidonStateExtension) PoseidonStateExtension {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
state[i] = c.SBoxMonomialExtension(state[i])
}
}
return state
}
func (c *PoseidonChip) mdsRowShf(r int, v [WIDTH]frontend.Variable) frontend.Variable {
func (c *PoseidonChip) MdsRowShf(r int, v [SPONGE_WIDTH]frontend.Variable) frontend.Variable {
res := frontend.Variable(0)
for i := 0; i < 12; i++ {
if i < WIDTH {
res1 := c.api.Mul(v[(i+r)%WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i]))
if i < SPONGE_WIDTH {
res1 := c.api.Mul(v[(i+r)%SPONGE_WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i]))
res = c.api.Add(res, res1)
}
}
@ -134,38 +162,76 @@ func (c *PoseidonChip) mdsRowShf(r int, v [WIDTH]frontend.Variable) frontend.Var
return res
}
func (c *PoseidonChip) mdsLayer(state_ PoseidonState) PoseidonState {
func (c *PoseidonChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]QuadraticExtension) QuadraticExtension {
res := c.qeAPI.FieldToQE(NewFieldElement(0))
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
matrixVal := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_CIRC[i]))
res1 := c.qeAPI.MulExtension(v[(i+r)%SPONGE_WIDTH], matrixVal)
res = c.qeAPI.AddExtension(res, res1)
}
}
matrixVal := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_DIAG[r]))
res = c.qeAPI.AddExtension(res, c.qeAPI.MulExtension(v[r], matrixVal))
return res
}
func (c *PoseidonChip) MdsLayer(state_ PoseidonState) PoseidonState {
var result PoseidonState
for i := 0; i < WIDTH; i++ {
for i := 0; i < SPONGE_WIDTH; i++ {
result[i] = NewFieldElement(0)
}
var state [WIDTH]frontend.Variable
for i := 0; i < WIDTH; i++ {
state[i] = c.api.FromBinary(c.field.ToBinary(state_[i])...)
var state [SPONGE_WIDTH]frontend.Variable
for i := 0; i < SPONGE_WIDTH; i++ {
state[i] = c.api.FromBinary(c.fieldAPI.ToBinary(state_[i])...)
}
for r := 0; r < 12; r++ {
if r < WIDTH {
sum := c.mdsRowShf(r, state)
if r < SPONGE_WIDTH {
sum := c.MdsRowShf(r, state)
bits := c.api.ToBinary(sum)
result[r] = c.field.FromBinary(bits).(F)
result[r] = c.fieldAPI.FromBinary(bits).(F)
}
}
return result
}
func (c *PoseidonChip) MdsLayerExtension(state_ PoseidonStateExtension) PoseidonStateExtension {
var result PoseidonStateExtension
for r := 0; r < 12; r++ {
if r < SPONGE_WIDTH {
sum := c.MdsRowShfExtension(r, state_)
result[r] = sum
}
}
return result
}
func (c *PoseidonChip) partialFirstConstantLayer(state PoseidonState) PoseidonState {
func (c *PoseidonChip) PartialFirstConstantLayer(state PoseidonState) PoseidonState {
for i := 0; i < 12; i++ {
if i < WIDTH {
state[i] = c.field.Add(state[i], NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(F)
if i < SPONGE_WIDTH {
state[i] = c.fieldAPI.Add(state[i], NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(F)
}
}
return state
}
func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState {
func (c *PoseidonChip) PartialFirstConstantLayerExtension(state PoseidonStateExtension) PoseidonStateExtension {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
state[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])))
}
}
return state
}
func (c *PoseidonChip) MdsPartialLayerInit(state PoseidonState) PoseidonState {
var result PoseidonState
for i := 0; i < 12; i++ {
result[i] = NewFieldElement(0)
@ -174,11 +240,11 @@ func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState {
result[0] = state[0]
for r := 1; r < 12; r++ {
if r < WIDTH {
if r < SPONGE_WIDTH {
for d := 1; d < 12; d++ {
if d < WIDTH {
if d < SPONGE_WIDTH {
t := NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
result[d] = c.field.Add(result[d], c.field.Mul(state[r], t)).(F)
result[d] = c.fieldAPI.Add(result[d], c.fieldAPI.Mul(state[r], t)).(F)
}
}
}
@ -187,32 +253,77 @@ func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState {
return result
}
func (c *PoseidonChip) mdsPartialLayerFast(state PoseidonState, r int) PoseidonState {
func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension) PoseidonStateExtension {
var result PoseidonStateExtension
for i := 0; i < 12; i++ {
result[i] = c.qeAPI.FieldToQE(NewFieldElement(0))
}
result[0] = state[0]
for r := 1; r < 12; r++ {
if r < SPONGE_WIDTH {
for d := 1; d < 12; d++ {
if d < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]))
result[d] = c.qeAPI.AddExtension(result[d], c.qeAPI.MulExtension(state[r], t))
}
}
}
}
return result
}
func (c *PoseidonChip) MdsPartialLayerFast(state PoseidonState, r int) PoseidonState {
dSum := frontend.Variable(0)
for i := 1; i < 12; i++ {
if i < WIDTH {
if i < SPONGE_WIDTH {
t := frontend.Variable(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
si := c.api.FromBinary(c.field.ToBinary(state[i])...)
si := c.api.FromBinary(c.fieldAPI.ToBinary(state[i])...)
dSum = c.api.Add(dSum, c.api.Mul(si, t))
}
}
s0 := c.api.FromBinary(c.field.ToBinary(state[0])...)
s0 := c.api.FromBinary(c.fieldAPI.ToBinary(state[0])...)
mds0to0 := frontend.Variable(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0])
dSum = c.api.Add(dSum, c.api.Mul(s0, mds0to0))
d := c.field.FromBinary(c.api.ToBinary(dSum))
d := c.fieldAPI.FromBinary(c.api.ToBinary(dSum))
var result PoseidonState
for i := 0; i < WIDTH; i++ {
for i := 0; i < SPONGE_WIDTH; i++ {
result[i] = NewFieldElement(0)
}
result[0] = d.(F)
for i := 1; i < 12; i++ {
if i < WIDTH {
if i < SPONGE_WIDTH {
t := NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1])
result[i] = c.field.Add(state[i], c.field.Mul(state[0], t)).(F)
result[i] = c.fieldAPI.Add(state[i], c.fieldAPI.Mul(state[0], t)).(F)
}
}
return result
}
func (c *PoseidonChip) MdsPartialLayerFastExtension(state PoseidonStateExtension, r int) PoseidonStateExtension {
s0 := state[0]
mds0to0 := c.qeAPI.FieldToQE(NewFieldElement(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]))
d := c.qeAPI.AddExtension(s0, mds0to0)
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_W_HATS[r][i-1]))
d = c.qeAPI.AddExtension(d, c.qeAPI.MulExtension(state[i], t))
}
}
var result PoseidonStateExtension
result[0] = d
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1]))
result[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.MulExtension(state[0], t))
}
}

+ 5
- 5
poseidon/public_inputs_hash_test.go

@ -18,22 +18,22 @@ type TestPublicInputsHashCircuit struct {
}
func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error {
field := NewFieldAPI(api)
fieldAPI := NewFieldAPI(api)
// BN254 -> Binary(64) -> F
var input [3]F
for i := 0; i < 3; i++ {
input[i] = field.FromBinary(api.ToBinary(circuit.In[i], 64)).(F)
input[i] = fieldAPI.FromBinary(api.ToBinary(circuit.In[i], 64)).(F)
}
poseidonChip := &PoseidonChip{api: api, field: field}
poseidonChip := &PoseidonChip{api: api, fieldAPI: fieldAPI}
output := poseidonChip.HashNoPad(input[:])
// Check that output is correct
for i := 0; i < 4; i++ {
field.AssertIsEqual(
fieldAPI.AssertIsEqual(
output[i],
field.FromBinary(api.ToBinary(circuit.Out[i])).(F),
fieldAPI.FromBinary(api.ToBinary(circuit.Out[i])).(F),
)
}

Loading…
Cancel
Save