mirror of
https://github.com/arnaucube/gnark-plonky2-verifier.git
synced 2026-01-12 00:51:33 +01:00
goldilocks and poseidon
This commit is contained in:
1131
poseidon/constants.go
Normal file
1131
poseidon/constants.go
Normal file
File diff suppressed because it is too large
Load Diff
191
poseidon/poseidon.go
Normal file
191
poseidon/poseidon.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package poseidon
|
||||
|
||||
import (
|
||||
. "gnark-ed25519/goldilocks"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
)
|
||||
|
||||
/* Note: This package assumes usage of the BN254 curve in various places. */
|
||||
|
||||
const HALF_N_FULL_ROUNDS = 4
|
||||
const N_FULL_ROUNDS_TOTAL = 2 * HALF_N_FULL_ROUNDS
|
||||
const N_PARTIAL_ROUNDS = 22
|
||||
const N_ROUNDS = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS
|
||||
const MAX_WIDTH = 12
|
||||
const WIDTH = 12
|
||||
const SPONGE_WIDTH = 12
|
||||
const SPONGE_RATE = 8
|
||||
|
||||
type PoseidonState = [WIDTH]GoldilocksElement
|
||||
type PoseidonChip struct {
|
||||
api frontend.API
|
||||
field frontend.API
|
||||
}
|
||||
|
||||
func Poseidon(api frontend.API, field frontend.API, input PoseidonState) PoseidonState {
|
||||
chip := &PoseidonChip{api: api, field: field}
|
||||
return chip.Poseidon(input)
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) Poseidon(input PoseidonState) PoseidonState {
|
||||
state := input
|
||||
roundCounter := 0
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
state = c.partialRounds(state, &roundCounter)
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) fullRounds(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
for i := 0; i < HALF_N_FULL_ROUNDS; i++ {
|
||||
state = c.constantLayer(state, roundCounter)
|
||||
state = c.sBoxLayer(state)
|
||||
state = c.mdsLayer(state)
|
||||
if *roundCounter >= 26 && i == 3 {
|
||||
break
|
||||
}
|
||||
*roundCounter += 1
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
state = c.partialFirstConstantLayer(state)
|
||||
state = c.mdsPartialLayerInit(state)
|
||||
|
||||
for i := 0; i < N_PARTIAL_ROUNDS; i++ {
|
||||
state[0] = c.sBoxMonomial(state[0])
|
||||
state[0] = c.field.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(GoldilocksElement)
|
||||
state = c.mdsPartialLayerFast(state, i)
|
||||
}
|
||||
|
||||
*roundCounter += N_PARTIAL_ROUNDS
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) constantLayer(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
roundConstant := NewGoldilocksElement(ALL_ROUND_CONSTANTS[i+WIDTH*(*roundCounter)])
|
||||
state[i] = c.field.Add(state[i], roundConstant).(GoldilocksElement)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) sBoxLayer(state PoseidonState) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
state[i] = c.sBoxMonomial(state[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) sBoxMonomial(x GoldilocksElement) GoldilocksElement {
|
||||
x2 := c.field.Mul(x, x)
|
||||
x4 := c.field.Mul(x2, x2)
|
||||
x3 := c.field.Mul(x2, x)
|
||||
return c.field.Mul(x3, x4).(GoldilocksElement)
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsRowShf(r int, v [WIDTH]frontend.Variable) frontend.Variable {
|
||||
res := frontend.Variable(0)
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
res1 := c.api.Mul(v[(i+r)%WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i]))
|
||||
res = c.api.Add(res, res1)
|
||||
}
|
||||
}
|
||||
|
||||
res = c.api.Add(res, c.api.Mul(v[r], MDS_MATRIX_DIAG[r]))
|
||||
return res
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsLayer(state_ PoseidonState) PoseidonState {
|
||||
var result PoseidonState
|
||||
for i := 0; i < WIDTH; i++ {
|
||||
result[i] = NewGoldilocksElement(0)
|
||||
}
|
||||
|
||||
var state [WIDTH]frontend.Variable
|
||||
for i := 0; i < WIDTH; i++ {
|
||||
state[i] = c.api.FromBinary(c.field.ToBinary(state_[i])...)
|
||||
}
|
||||
|
||||
for r := 0; r < 12; r++ {
|
||||
if r < WIDTH {
|
||||
sum := c.mdsRowShf(r, state)
|
||||
bits := c.api.ToBinary(sum)
|
||||
result[r] = c.field.FromBinary(bits).(GoldilocksElement)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) partialFirstConstantLayer(state PoseidonState) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
state[i] = c.field.Add(state[i], NewGoldilocksElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(GoldilocksElement)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState {
|
||||
var result PoseidonState
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = NewGoldilocksElement(0)
|
||||
}
|
||||
|
||||
result[0] = state[0]
|
||||
|
||||
for r := 1; r < 12; r++ {
|
||||
if r < WIDTH {
|
||||
for d := 1; d < 12; d++ {
|
||||
if d < WIDTH {
|
||||
t := NewGoldilocksElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
|
||||
result[d] = c.field.Add(result[d], c.field.Mul(state[r], t)).(GoldilocksElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsPartialLayerFast(state PoseidonState, r int) PoseidonState {
|
||||
dSum := frontend.Variable(0)
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
t := frontend.Variable(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
|
||||
si := c.api.FromBinary(c.field.ToBinary(state[i])...)
|
||||
dSum = c.api.Add(dSum, c.api.Mul(si, t))
|
||||
}
|
||||
}
|
||||
|
||||
s0 := c.api.FromBinary(c.field.ToBinary(state[0])...)
|
||||
mds0to0 := frontend.Variable(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0])
|
||||
dSum = c.api.Add(dSum, c.api.Mul(s0, mds0to0))
|
||||
d := c.field.FromBinary(c.api.ToBinary(dSum))
|
||||
|
||||
var result PoseidonState
|
||||
for i := 0; i < WIDTH; i++ {
|
||||
result[i] = NewGoldilocksElement(0)
|
||||
}
|
||||
|
||||
result[0] = d.(GoldilocksElement)
|
||||
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < WIDTH {
|
||||
t := NewGoldilocksElement(FAST_PARTIAL_ROUND_VS[r][i-1])
|
||||
result[i] = c.field.Add(state[i], c.field.Mul(state[0], t)).(GoldilocksElement)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
145
poseidon/poseidon_test.go
Normal file
145
poseidon/poseidon_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package poseidon
|
||||
|
||||
import (
|
||||
. "gnark-ed25519/goldilocks"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark/backend/groth16"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/frontend/cs/r1cs"
|
||||
"github.com/consensys/gnark/test"
|
||||
)
|
||||
|
||||
var testCurve = ecc.BN254
|
||||
|
||||
type TestPoseidonCircuit struct {
|
||||
In [12]frontend.Variable
|
||||
Out [12]frontend.Variable
|
||||
}
|
||||
|
||||
func (circuit *TestPoseidonCircuit) Define(api frontend.API) error {
|
||||
goldilocksApi := NewGoldilocksAPI(api)
|
||||
|
||||
// BN254 -> Binary(64) -> GoldilocksElement
|
||||
var input PoseidonState
|
||||
for i := 0; i < 12; i++ {
|
||||
input[i] = goldilocksApi.FromBinary(api.ToBinary(circuit.In[i], 64)).(GoldilocksElement)
|
||||
}
|
||||
|
||||
output := Poseidon(api, goldilocksApi, input)
|
||||
|
||||
// Check that output is correct
|
||||
for i := 0; i < 12; i++ {
|
||||
goldilocksApi.AssertIsEqual(
|
||||
output[i],
|
||||
goldilocksApi.FromBinary(api.ToBinary(circuit.Out[i])).(GoldilocksElement),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPoseidonWitness(t *testing.T) {
|
||||
assert := test.NewAssert(t)
|
||||
|
||||
testCase := func(inBigInt [12]big.Int, outBigInt [12]big.Int) {
|
||||
var in [12]frontend.Variable
|
||||
var out [12]frontend.Variable
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
in[i] = inBigInt[i]
|
||||
out[i] = outBigInt[i]
|
||||
}
|
||||
|
||||
circuit := TestPoseidonCircuit{In: in, Out: out}
|
||||
witness := TestPoseidonCircuit{In: in, Out: out}
|
||||
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField())
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
inStr := [12]string{"0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0"}
|
||||
outStr := [12]string{
|
||||
"4330397376401421145", "14124799381142128323", "8742572140681234676",
|
||||
"14345658006221440202", "15524073338516903644", "5091405722150716653",
|
||||
"15002163819607624508", "2047012902665707362", "16106391063450633726",
|
||||
"4680844749859802542", "15019775476387350140", "1698615465718385111",
|
||||
}
|
||||
|
||||
var inBigInt [12]big.Int
|
||||
var outBigInt [12]big.Int
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
inTmp := new(big.Int)
|
||||
inTmp, _ = inTmp.SetString(inStr[i], 10)
|
||||
inBigInt[i] = *inTmp
|
||||
|
||||
outTmp := new(big.Int)
|
||||
outTmp, _ = outTmp.SetString(outStr[i], 10)
|
||||
outBigInt[i] = *outTmp
|
||||
}
|
||||
|
||||
testCase(inBigInt, outBigInt)
|
||||
}
|
||||
|
||||
func TestPoseidonProof(t *testing.T) {
|
||||
inStr := [12]string{"0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0"}
|
||||
outStr := [12]string{
|
||||
"4330397376401421145", "14124799381142128323", "8742572140681234676",
|
||||
"14345658006221440202", "15524073338516903644", "5091405722150716653",
|
||||
"15002163819607624508", "2047012902665707362", "16106391063450633726",
|
||||
"4680844749859802542", "15019775476387350140", "1698615465718385111",
|
||||
}
|
||||
|
||||
var in [12]frontend.Variable
|
||||
var out [12]frontend.Variable
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
inTmp := new(big.Int)
|
||||
inTmp, _ = inTmp.SetString(inStr[i], 10)
|
||||
in[i] = *inTmp
|
||||
|
||||
outTmp := new(big.Int)
|
||||
outTmp, _ = outTmp.SetString(outStr[i], 10)
|
||||
out[i] = *outTmp
|
||||
}
|
||||
|
||||
circuit := TestPoseidonCircuit{In: in, Out: out}
|
||||
assignment := TestPoseidonCircuit{In: in, Out: out}
|
||||
|
||||
r1cs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
witness, err := frontend.NewWitness(&assignment, testCurve.ScalarField())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pk, vk, err := groth16.Setup(r1cs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = test.IsSolved(&circuit, &assignment, testCurve.ScalarField())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
proof, err := groth16.Prove(r1cs, pk, witness)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
publicWitness, err := witness.Public()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = groth16.Verify(proof, vk, publicWitness)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user