mirror of
https://github.com/arnaucube/gnark-plonky2-verifier.git
synced 2026-01-12 00:51:33 +01:00
Use optimized goldilocks in codebase (#26)
* gl * stage 1 optimizations * working optimized poseidon * Fix posedion tests * in progress gate type refactor * working gates * working e2e * hm' * hm2 * debug saga continues * more debugging cry * more debug * it finally works * optimizations * more optimizations * new changes * more optimizations * more cleanup * some refactoring * new files * flattening of packages * working commit * more refactor * more flattening * more flattening * more more refactor * more optimizations * more optimizations * more optimizations * plonk benchmark * plonk * fix r1cs * resolve kevin's comments * Update goldilocks/base.go Co-authored-by: Kevin Jue <kjue235@gmail.com> * Update goldilocks/base.go Co-authored-by: Kevin Jue <kjue235@gmail.com> * Update goldilocks/base.go Co-authored-by: Kevin Jue <kjue235@gmail.com> * Update goldilocks/quadratic_extension.go Co-authored-by: Kevin Jue <kjue235@gmail.com> * fix: resolve kevin's confusion --------- Co-authored-by: Kevin Jue <kjue235@gmail.com>
This commit is contained in:
204
poseidon/bn254.go
Normal file
204
poseidon/bn254.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package poseidon
|
||||
|
||||
// This is a customized implementation of the Poseidon hash function inside the BN254 field.
|
||||
// This implementation is based on the following implementation:
|
||||
//
|
||||
// https://github.com/iden3/go-iden3-crypto/blob/master/poseidon/poseidon.go
|
||||
//
|
||||
// The input and output are modified to ingest Goldilocks field elements.
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
|
||||
)
|
||||
|
||||
const BN254_FULL_ROUNDS int = 8
|
||||
const BN254_PARTIAL_ROUNDS int = 56
|
||||
const BN254_SPONGE_WIDTH int = 4
|
||||
const BN254_SPONGE_RATE int = 3
|
||||
|
||||
type BN254Chip struct {
|
||||
api frontend.API `gnark:"-"`
|
||||
gl gl.Chip `gnark:"-"`
|
||||
}
|
||||
|
||||
type BN254State = [BN254_SPONGE_WIDTH]frontend.Variable
|
||||
type BN254HashOut = frontend.Variable
|
||||
|
||||
func NewBN254Chip(api frontend.API) *BN254Chip {
|
||||
return &BN254Chip{api: api, gl: *gl.NewChip(api)}
|
||||
}
|
||||
|
||||
func (c *BN254Chip) Poseidon(state BN254State) BN254State {
|
||||
state = c.ark(state, 0)
|
||||
state = c.fullRounds(state, true)
|
||||
state = c.partialRounds(state)
|
||||
state = c.fullRounds(state, false)
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *BN254Chip) HashNoPad(input []gl.Variable) BN254HashOut {
|
||||
state := BN254State{
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
}
|
||||
|
||||
for i := 0; i < len(input); i += BN254_SPONGE_RATE * 3 {
|
||||
endI := c.min(len(input), i+BN254_SPONGE_RATE*3)
|
||||
rateChunk := input[i:endI]
|
||||
for j, stateIdx := 0, 0; j < len(rateChunk); j, stateIdx = j+3, stateIdx+1 {
|
||||
endJ := c.min(len(rateChunk), j+3)
|
||||
bn254Chunk := rateChunk[j:endJ]
|
||||
|
||||
bits := []frontend.Variable{}
|
||||
for k := 0; k < len(bn254Chunk); k++ {
|
||||
bn254Chunk[k] = c.gl.Reduce(bn254Chunk[k])
|
||||
bits = append(bits, c.api.ToBinary(bn254Chunk[k].Limb, 64)...)
|
||||
}
|
||||
|
||||
state[stateIdx+1] = c.api.FromBinary(bits...)
|
||||
}
|
||||
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
|
||||
return BN254HashOut(state[0])
|
||||
}
|
||||
|
||||
func (c *BN254Chip) HashOrNoop(input []gl.Variable) BN254HashOut {
|
||||
if len(input) <= 3 {
|
||||
returnVal := frontend.Variable(0)
|
||||
|
||||
alpha := new(big.Int).SetInt64(1 << 32)
|
||||
for i, inputElement := range input {
|
||||
returnVal = c.api.Add(returnVal, c.api.Mul(inputElement, alpha.Exp(alpha, big.NewInt(int64(i)), nil)))
|
||||
}
|
||||
|
||||
return BN254HashOut(returnVal)
|
||||
} else {
|
||||
return c.HashNoPad(input)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BN254Chip) TwoToOne(left BN254HashOut, right BN254HashOut) BN254HashOut {
|
||||
var inputs BN254State
|
||||
inputs[0] = frontend.Variable(0)
|
||||
inputs[1] = frontend.Variable(0)
|
||||
inputs[2] = left
|
||||
inputs[3] = right
|
||||
state := c.Poseidon(inputs)
|
||||
return state[0]
|
||||
}
|
||||
|
||||
func (c *BN254Chip) ToVec(hash BN254HashOut) []gl.Variable {
|
||||
bits := c.api.ToBinary(hash)
|
||||
|
||||
returnElements := []gl.Variable{}
|
||||
|
||||
// Split into 7 byte chunks, since 8 byte chunks can result in collisions
|
||||
chunkSize := 56
|
||||
for i := 0; i < len(bits); i += chunkSize {
|
||||
maxIdx := c.min(len(bits), i+chunkSize)
|
||||
bitChunk := bits[i:maxIdx]
|
||||
returnElements = append(returnElements, gl.NewVariable(c.api.FromBinary(bitChunk...)))
|
||||
}
|
||||
|
||||
return returnElements
|
||||
}
|
||||
|
||||
func (c *BN254Chip) min(x, y int) int {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
|
||||
return y
|
||||
}
|
||||
|
||||
func (c *BN254Chip) fullRounds(state BN254State, isFirst bool) BN254State {
|
||||
for i := 0; i < BN254_FULL_ROUNDS/2-1; i++ {
|
||||
state = c.exp5state(state)
|
||||
if isFirst {
|
||||
state = c.ark(state, (i+1)*BN254_SPONGE_WIDTH)
|
||||
} else {
|
||||
state = c.ark(state, (BN254_FULL_ROUNDS/2+1)*BN254_SPONGE_WIDTH+BN254_PARTIAL_ROUNDS+i*BN254_SPONGE_WIDTH)
|
||||
}
|
||||
state = c.mix(state, mMatrix)
|
||||
}
|
||||
|
||||
state = c.exp5state(state)
|
||||
if isFirst {
|
||||
state = c.ark(state, (BN254_FULL_ROUNDS/2)*BN254_SPONGE_WIDTH)
|
||||
state = c.mix(state, pMatrix)
|
||||
} else {
|
||||
state = c.mix(state, mMatrix)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *BN254Chip) partialRounds(state BN254State) BN254State {
|
||||
for i := 0; i < BN254_PARTIAL_ROUNDS; i++ {
|
||||
state[0] = c.exp5(state[0])
|
||||
state[0] = c.api.Add(state[0], cConstants[(BN254_FULL_ROUNDS/2+1)*BN254_SPONGE_WIDTH+i])
|
||||
|
||||
var mul frontend.Variable
|
||||
newState0 := frontend.Variable(0)
|
||||
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
|
||||
mul = c.api.Mul(sConstants[(BN254_SPONGE_WIDTH*2-1)*i+j], state[j])
|
||||
newState0 = c.api.Add(newState0, mul)
|
||||
}
|
||||
|
||||
for k := 1; k < BN254_SPONGE_WIDTH; k++ {
|
||||
mul = c.api.Mul(state[0], sConstants[(BN254_SPONGE_WIDTH*2-1)*i+BN254_SPONGE_WIDTH+k-1])
|
||||
state[k] = c.api.Add(state[k], mul)
|
||||
}
|
||||
state[0] = newState0
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *BN254Chip) ark(state BN254State, it int) BN254State {
|
||||
var result BN254State
|
||||
|
||||
for i := 0; i < len(state); i++ {
|
||||
result[i] = c.api.Add(state[i], cConstants[it+i])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *BN254Chip) exp5(x frontend.Variable) frontend.Variable {
|
||||
x2 := c.api.Mul(x, x)
|
||||
x4 := c.api.Mul(x2, x2)
|
||||
return c.api.Mul(x4, x)
|
||||
}
|
||||
|
||||
func (c *BN254Chip) exp5state(state BN254State) BN254State {
|
||||
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
|
||||
state[i] = c.exp5(state[i])
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *BN254Chip) mix(state_ BN254State, constantMatrix [][]*big.Int) BN254State {
|
||||
var mul frontend.Variable
|
||||
var result BN254State
|
||||
|
||||
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
|
||||
result[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
|
||||
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
|
||||
mul = c.api.Mul(constantMatrix[j][i], state_[j])
|
||||
result[i] = c.api.Add(result[i], mul)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -3,23 +3,22 @@ package poseidon
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/test"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/field"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/utils"
|
||||
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
|
||||
)
|
||||
|
||||
type TestPoseidonBN128Circuit struct {
|
||||
In [spongeWidth]frontend.Variable
|
||||
Out [spongeWidth]frontend.Variable
|
||||
type TestPoseidonBN254Circuit struct {
|
||||
In [BN254_SPONGE_WIDTH]frontend.Variable
|
||||
Out [BN254_SPONGE_WIDTH]frontend.Variable
|
||||
}
|
||||
|
||||
func (circuit *TestPoseidonBN128Circuit) Define(api frontend.API) error {
|
||||
fieldAPI := field.NewFieldAPI(api)
|
||||
poseidonChip := NewPoseidonBN128Chip(api, fieldAPI)
|
||||
func (circuit *TestPoseidonBN254Circuit) Define(api frontend.API) error {
|
||||
poseidonChip := NewBN254Chip(api)
|
||||
output := poseidonChip.Poseidon(circuit.In)
|
||||
|
||||
for i := 0; i < spongeWidth; i++ {
|
||||
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
|
||||
api.AssertIsEqual(
|
||||
output[i],
|
||||
circuit.Out[i],
|
||||
@@ -29,13 +28,13 @@ func (circuit *TestPoseidonBN128Circuit) Define(api frontend.API) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPoseidonBN128(t *testing.T) {
|
||||
func TestPoseidonBN254(t *testing.T) {
|
||||
assert := test.NewAssert(t)
|
||||
|
||||
testCaseFn := func(in [spongeWidth]frontend.Variable, out [spongeWidth]frontend.Variable) {
|
||||
circuit := TestPoseidonBN128Circuit{In: in, Out: out}
|
||||
witness := TestPoseidonBN128Circuit{In: in, Out: out}
|
||||
err := test.IsSolved(&circuit, &witness, field.TEST_CURVE.ScalarField())
|
||||
testCaseFn := func(in [BN254_SPONGE_WIDTH]frontend.Variable, out [BN254_SPONGE_WIDTH]frontend.Variable) {
|
||||
circuit := TestPoseidonBN254Circuit{In: in, Out: out}
|
||||
witness := TestPoseidonBN254Circuit{In: in, Out: out}
|
||||
err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
@@ -89,10 +88,10 @@ func TestPoseidonBN128(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
var in [spongeWidth]frontend.Variable
|
||||
var out [spongeWidth]frontend.Variable
|
||||
copy(in[:], utils.StrArrayToFrontendVariableArray(testCase[0]))
|
||||
copy(out[:], utils.StrArrayToFrontendVariableArray(testCase[1]))
|
||||
var in [BN254_SPONGE_WIDTH]frontend.Variable
|
||||
var out [BN254_SPONGE_WIDTH]frontend.Variable
|
||||
copy(in[:], gl.StrArrayToFrontendVariableArray(testCase[0]))
|
||||
copy(out[:], gl.StrArrayToFrontendVariableArray(testCase[1]))
|
||||
testCaseFn(in, out)
|
||||
}
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func init() {
|
||||
cConstants[85], _ = new(big.Int).SetString("19007581091212404202795325684108744075320879284650517772195719617120941682734", 10)
|
||||
cConstants[86], _ = new(big.Int).SetString("8172766643075822491744127151779052248074930479661223662192995838879026989201", 10)
|
||||
cConstants[87], _ = new(big.Int).SetString("1885998770792872998306340529689960371653339961062025442813774917754800650781", 10)
|
||||
|
||||
|
||||
sConstants[0], _ = new(big.Int).SetString("16023668707004248971294664614290028914393192768609916554276071736843535714477", 10)
|
||||
sConstants[1], _ = new(big.Int).SetString("20198106103550706280267600199190750325504745188750640438654177959939538483777", 10)
|
||||
sConstants[2], _ = new(big.Int).SetString("20760367756622597472566835313508896628444391801225538453375145392828630013190", 10)
|
||||
357
poseidon/goldilocks.go
Normal file
357
poseidon/goldilocks.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package poseidon
|
||||
|
||||
import (
|
||||
"github.com/consensys/gnark/frontend"
|
||||
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
|
||||
)
|
||||
|
||||
const HALF_N_FULL_ROUNDS = 4
|
||||
const N_PARTIAL_ROUNDS = 22
|
||||
const MAX_WIDTH = 12
|
||||
const SPONGE_WIDTH = 12
|
||||
const SPONGE_RATE = 8
|
||||
|
||||
type GoldilocksState = [SPONGE_WIDTH]gl.Variable
|
||||
type GoldilocksStateExtension = [SPONGE_WIDTH]gl.QuadraticExtensionVariable
|
||||
type GoldilocksHashOut = [4]gl.Variable
|
||||
|
||||
type GoldilocksChip struct {
|
||||
api frontend.API `gnark:"-"`
|
||||
gl gl.Chip `gnark:"-"`
|
||||
}
|
||||
|
||||
func NewGoldilocksChip(api frontend.API) *GoldilocksChip {
|
||||
return &GoldilocksChip{api: api, gl: *gl.NewChip(api)}
|
||||
}
|
||||
|
||||
// The permutation function.
|
||||
// The input state MUST have all it's elements be within Goldilocks field (e.g. this function will not reduce the input elements).
|
||||
// The returned state's elements will all be within Goldilocks field.
|
||||
func (c *GoldilocksChip) Poseidon(input GoldilocksState) GoldilocksState {
|
||||
state := input
|
||||
roundCounter := 0
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
state = c.partialRounds(state, &roundCounter)
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
return state
|
||||
}
|
||||
|
||||
// The input elements MUST have all it's elements be within Goldilocks field.
|
||||
// The returned slice's elements will all be within Goldilocks field.
|
||||
func (c *GoldilocksChip) HashNToMNoPad(input []gl.Variable, nbOutputs int) []gl.Variable {
|
||||
var state GoldilocksState
|
||||
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
state[i] = gl.NewVariable(0)
|
||||
}
|
||||
|
||||
for i := 0; i < len(input); i += SPONGE_RATE {
|
||||
for j := 0; j < SPONGE_RATE; j++ {
|
||||
if i+j < len(input) {
|
||||
state[j] = input[i+j]
|
||||
}
|
||||
}
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
|
||||
var outputs []gl.Variable
|
||||
|
||||
for {
|
||||
for i := 0; i < SPONGE_RATE; i++ {
|
||||
outputs = append(outputs, state[i])
|
||||
if len(outputs) == nbOutputs {
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
}
|
||||
|
||||
// The input elements can be outside of the Goldilocks field.
|
||||
// The returned slice's elements will all be within Goldilocks field.
|
||||
func (c *GoldilocksChip) HashNoPad(input []gl.Variable) GoldilocksHashOut {
|
||||
var hash GoldilocksHashOut
|
||||
inputVars := []gl.Variable{}
|
||||
|
||||
for i := 0; i < len(input); i++ {
|
||||
inputVars = append(inputVars, c.gl.Reduce(input[i]))
|
||||
}
|
||||
|
||||
outputVars := c.HashNToMNoPad(inputVars, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
hash[i] = outputVars[i]
|
||||
}
|
||||
|
||||
return hash
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) ToVec(hash GoldilocksHashOut) []gl.Variable {
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) fullRounds(state GoldilocksState, roundCounter *int) GoldilocksState {
|
||||
for i := 0; i < HALF_N_FULL_ROUNDS; i++ {
|
||||
state = c.constantLayer(state, roundCounter)
|
||||
state = c.sBoxLayer(state)
|
||||
state = c.mdsLayer(state)
|
||||
*roundCounter += 1
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) partialRounds(state GoldilocksState, roundCounter *int) GoldilocksState {
|
||||
state = c.partialFirstConstantLayer(state)
|
||||
state = c.mdsPartialLayerInit(state)
|
||||
|
||||
for i := 0; i < N_PARTIAL_ROUNDS; i++ {
|
||||
state[0] = c.sBoxMonomial(state[0])
|
||||
state[0] = c.gl.Add(state[0], gl.NewVariable(FAST_PARTIAL_ROUND_CONSTANTS[i]))
|
||||
state = c.mdsPartialLayerFast(state, i)
|
||||
}
|
||||
|
||||
*roundCounter += N_PARTIAL_ROUNDS
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) constantLayer(state GoldilocksState, roundCounter *int) GoldilocksState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
roundConstant := ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]
|
||||
state[i] = c.gl.MulAdd(state[i], gl.NewVariable(1), gl.NewVariable(roundConstant))
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) ConstantLayerExtension(state GoldilocksStateExtension, roundCounter *int) GoldilocksStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
roundConstant := gl.NewVariable(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])
|
||||
roundConstantQE := gl.NewQuadraticExtensionVariable(roundConstant, gl.Zero())
|
||||
state[i] = c.gl.AddExtension(state[i], roundConstantQE)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) sBoxMonomial(x gl.Variable) gl.Variable {
|
||||
x2 := c.gl.MulNoReduce(x, x)
|
||||
x3 := c.gl.MulNoReduce(x, x2)
|
||||
x3 = c.gl.ReduceWithMaxBits(x3, 192)
|
||||
x6 := c.gl.MulNoReduce(x3, x3)
|
||||
x7 := c.gl.MulNoReduce(x, x6)
|
||||
return c.gl.ReduceWithMaxBits(x7, 192)
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) SBoxMonomialExtension(x gl.QuadraticExtensionVariable) gl.QuadraticExtensionVariable {
|
||||
x2 := c.gl.MulExtension(x, x)
|
||||
x4 := c.gl.MulExtension(x2, x2)
|
||||
x3 := c.gl.MulExtension(x, x2)
|
||||
return c.gl.MulExtension(x4, x3)
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) sBoxLayer(state GoldilocksState) GoldilocksState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.sBoxMonomial(state[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) SBoxLayerExtension(state GoldilocksStateExtension) GoldilocksStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.SBoxMonomialExtension(state[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) mdsRowShf(r int, v [SPONGE_WIDTH]gl.Variable) gl.Variable {
|
||||
res := gl.Zero()
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
res = c.gl.MulAddNoReduce(v[(i+r)%SPONGE_WIDTH], gl.NewVariable(MDS_MATRIX_CIRC_VARS[i]), res)
|
||||
}
|
||||
}
|
||||
|
||||
res = c.gl.MulAddNoReduce(v[r], gl.NewVariable(MDS_MATRIX_DIAG_VARS[r]), res)
|
||||
return c.gl.Reduce(res)
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]gl.QuadraticExtensionVariable) gl.QuadraticExtensionVariable {
|
||||
res := gl.ZeroExtension()
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
matrixVal := gl.NewVariable(MDS_MATRIX_CIRC[i])
|
||||
matrixValQE := gl.NewQuadraticExtensionVariable(matrixVal, gl.Zero())
|
||||
res1 := c.gl.MulExtension(v[(i+r)%SPONGE_WIDTH], matrixValQE)
|
||||
res = c.gl.AddExtension(res, res1)
|
||||
}
|
||||
}
|
||||
|
||||
matrixVal := gl.NewVariable(MDS_MATRIX_DIAG[r])
|
||||
matrixValQE := gl.NewQuadraticExtensionVariable(matrixVal, gl.Zero())
|
||||
res = c.gl.AddExtension(res, c.gl.MulExtension(v[r], matrixValQE))
|
||||
return res
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) mdsLayer(state_ GoldilocksState) GoldilocksState {
|
||||
var result GoldilocksState
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
result[i] = gl.NewVariable(0)
|
||||
}
|
||||
|
||||
for r := 0; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
result[r] = c.mdsRowShf(r, state_)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) MdsLayerExtension(state_ GoldilocksStateExtension) GoldilocksStateExtension {
|
||||
var result GoldilocksStateExtension
|
||||
|
||||
for r := 0; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
sum := c.MdsRowShfExtension(r, state_)
|
||||
result[r] = sum
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) partialFirstConstantLayer(state GoldilocksState) GoldilocksState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.gl.Add(state[i], gl.NewVariable(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]))
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) PartialFirstConstantLayerExtension(state GoldilocksStateExtension) GoldilocksStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
fastPartialRoundConstant := gl.NewVariable(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])
|
||||
fastPartialRoundConstantQE := gl.NewQuadraticExtensionVariable(fastPartialRoundConstant, gl.Zero())
|
||||
state[i] = c.gl.AddExtension(state[i], fastPartialRoundConstantQE)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) mdsPartialLayerInit(state GoldilocksState) GoldilocksState {
|
||||
var result GoldilocksState
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = gl.NewVariable(0)
|
||||
}
|
||||
|
||||
result[0] = state[0]
|
||||
|
||||
for r := 1; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
for d := 1; d < 12; d++ {
|
||||
if d < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]
|
||||
result[d] = c.gl.MulAddNoReduce(state[r], gl.NewVariable(t), result[d])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = c.gl.Reduce(result[i])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) MdsPartialLayerInitExtension(state GoldilocksStateExtension) GoldilocksStateExtension {
|
||||
var result GoldilocksStateExtension
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = gl.ZeroExtension()
|
||||
}
|
||||
|
||||
result[0] = state[0]
|
||||
|
||||
for r := 1; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
for d := 1; d < 12; d++ {
|
||||
if d < SPONGE_WIDTH {
|
||||
t := gl.NewVariable(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
|
||||
tQE := gl.NewQuadraticExtensionVariable(t, gl.Zero())
|
||||
result[d] = c.gl.AddExtension(result[d], c.gl.MulExtension(state[r], tQE))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) mdsPartialLayerFast(state GoldilocksState, r int) GoldilocksState {
|
||||
dSum := gl.Zero()
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_W_HATS_VARS[r][i-1]
|
||||
dSum = c.gl.MulAddNoReduce(state[i], gl.NewVariable(t), dSum)
|
||||
}
|
||||
}
|
||||
|
||||
d := c.gl.MulAddNoReduce(state[0], gl.NewVariable(MDS0TO0_VAR), dSum)
|
||||
d = c.gl.Reduce(d)
|
||||
|
||||
var result GoldilocksState
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
result[i] = gl.NewVariable(0)
|
||||
}
|
||||
|
||||
result[0] = d
|
||||
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_VS[r][i-1]
|
||||
result[i] = c.gl.MulAddNoReduce(state[0], gl.NewVariable(t), state[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(state); i++ {
|
||||
result[i] = c.gl.Reduce(result[i])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *GoldilocksChip) MdsPartialLayerFastExtension(state GoldilocksStateExtension, r int) GoldilocksStateExtension {
|
||||
s0 := state[0]
|
||||
mds0to0 := gl.NewVariable(MDS0TO0)
|
||||
mds0to0QE := gl.NewQuadraticExtensionVariable(mds0to0, gl.Zero())
|
||||
d := c.gl.MulExtension(s0, mds0to0QE)
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := gl.NewVariable(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
|
||||
tQE := gl.NewQuadraticExtensionVariable(t, gl.Zero())
|
||||
d = c.gl.AddExtension(d, c.gl.MulExtension(state[i], tQE))
|
||||
}
|
||||
}
|
||||
|
||||
var result GoldilocksStateExtension
|
||||
result[0] = d
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := gl.NewVariable(FAST_PARTIAL_ROUND_VS[r][i-1])
|
||||
tQE := gl.NewQuadraticExtensionVariable(t, gl.Zero())
|
||||
result[i] = c.gl.AddExtension(c.gl.MulExtension(state[0], tQE), state[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -3,12 +3,12 @@ package poseidon
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark/backend/groth16"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/frontend/cs/r1cs"
|
||||
"github.com/consensys/gnark/test"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/field"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/utils"
|
||||
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
|
||||
)
|
||||
|
||||
type TestPoseidonCircuit struct {
|
||||
@@ -17,19 +17,18 @@ type TestPoseidonCircuit struct {
|
||||
}
|
||||
|
||||
func (circuit *TestPoseidonCircuit) Define(api frontend.API) error {
|
||||
goldilocksApi := field.NewFieldAPI(api)
|
||||
qeAPI := field.NewQuadraticExtensionAPI(api, goldilocksApi)
|
||||
|
||||
var input PoseidonState
|
||||
var input GoldilocksState
|
||||
for i := 0; i < 12; i++ {
|
||||
input[i] = circuit.In[i]
|
||||
input[i] = gl.NewVariable(circuit.In[i])
|
||||
}
|
||||
|
||||
poseidonChip := NewPoseidonChip(api, goldilocksApi, qeAPI)
|
||||
poseidonChip := NewGoldilocksChip(api)
|
||||
output := poseidonChip.Poseidon(input)
|
||||
|
||||
glApi := gl.NewChip(api)
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
api.AssertIsEqual(output[i], circuit.Out[i])
|
||||
glApi.AssertIsEqual(output[i], gl.NewVariable(circuit.Out[i]))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -41,7 +40,7 @@ func TestPoseidonWitness(t *testing.T) {
|
||||
testCase := func(in [12]frontend.Variable, out [12]frontend.Variable) {
|
||||
circuit := TestPoseidonCircuit{In: in, Out: out}
|
||||
witness := TestPoseidonCircuit{In: in, Out: out}
|
||||
err := test.IsSolved(&circuit, &witness, field.TEST_CURVE.ScalarField())
|
||||
err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField())
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
@@ -54,8 +53,8 @@ func TestPoseidonWitness(t *testing.T) {
|
||||
}
|
||||
var in [12]frontend.Variable
|
||||
var out [12]frontend.Variable
|
||||
copy(in[:], utils.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], utils.StrArrayToFrontendVariableArray(outStr))
|
||||
copy(in[:], gl.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], gl.StrArrayToFrontendVariableArray(outStr))
|
||||
testCase(in, out)
|
||||
}
|
||||
|
||||
@@ -69,18 +68,18 @@ func TestPoseidonProof(t *testing.T) {
|
||||
}
|
||||
var in [12]frontend.Variable
|
||||
var out [12]frontend.Variable
|
||||
copy(in[:], utils.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], utils.StrArrayToFrontendVariableArray(outStr))
|
||||
copy(in[:], gl.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], gl.StrArrayToFrontendVariableArray(outStr))
|
||||
|
||||
circuit := TestPoseidonCircuit{In: in, Out: out}
|
||||
assignment := TestPoseidonCircuit{In: in, Out: out}
|
||||
|
||||
r1cs, err := frontend.Compile(field.TEST_CURVE.ScalarField(), r1cs.NewBuilder, &circuit)
|
||||
r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
witness, err := frontend.NewWitness(&assignment, field.TEST_CURVE.ScalarField())
|
||||
witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -90,7 +89,7 @@ func TestPoseidonProof(t *testing.T) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = test.IsSolved(&circuit, &assignment, field.TEST_CURVE.ScalarField())
|
||||
err = test.IsSolved(&circuit, &assignment, ecc.BN254.ScalarField())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -1,338 +0,0 @@
|
||||
package poseidon
|
||||
|
||||
import (
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/field"
|
||||
)
|
||||
|
||||
const HALF_N_FULL_ROUNDS = 4
|
||||
const N_PARTIAL_ROUNDS = 22
|
||||
const MAX_WIDTH = 12
|
||||
const SPONGE_WIDTH = 12
|
||||
const SPONGE_RATE = 8
|
||||
|
||||
type PoseidonState = [SPONGE_WIDTH]frontend.Variable
|
||||
type PoseidonStateExtension = [SPONGE_WIDTH]field.QuadraticExtension
|
||||
type PoseidonHashOut = [4]field.F
|
||||
|
||||
type PoseidonChip struct {
|
||||
api frontend.API `gnark:"-"`
|
||||
fieldAPI field.FieldAPI `gnark:"-"`
|
||||
qeAPI *field.QuadraticExtensionAPI `gnark:"-"`
|
||||
}
|
||||
|
||||
func NewPoseidonChip(api frontend.API, fieldAPI field.FieldAPI, qeAPI *field.QuadraticExtensionAPI) *PoseidonChip {
|
||||
return &PoseidonChip{api: api, fieldAPI: fieldAPI, qeAPI: qeAPI}
|
||||
}
|
||||
|
||||
// The permutation function.
|
||||
// The input state MUST have all it's elements be within Goldilocks field (e.g. this function will not reduce the input elements).
|
||||
// The returned state's elements will all be within Goldilocks field.
|
||||
func (c *PoseidonChip) Poseidon(input PoseidonState) PoseidonState {
|
||||
state := input
|
||||
roundCounter := 0
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
state = c.partialRounds(state, &roundCounter)
|
||||
state = c.fullRounds(state, &roundCounter)
|
||||
return state
|
||||
}
|
||||
|
||||
// The input elements MUST have all it's elements be within Goldilocks field.
|
||||
// The returned slice's elements will all be within Goldilocks field.
|
||||
func (c *PoseidonChip) HashNToMNoPad(input []frontend.Variable, nbOutputs int) []frontend.Variable {
|
||||
var state PoseidonState
|
||||
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
state[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
for i := 0; i < len(input); i += SPONGE_RATE {
|
||||
for j := 0; j < SPONGE_RATE; j++ {
|
||||
if i+j < len(input) {
|
||||
state[j] = input[i+j]
|
||||
}
|
||||
}
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
|
||||
var outputs []frontend.Variable
|
||||
|
||||
for {
|
||||
for i := 0; i < SPONGE_RATE; i++ {
|
||||
outputs = append(outputs, state[i])
|
||||
if len(outputs) == nbOutputs {
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
}
|
||||
|
||||
// The input elements can be outside of the Goldilocks field.
|
||||
// The returned slice's elements will all be within Goldilocks field.
|
||||
func (c *PoseidonChip) HashNoPad(input []field.F) PoseidonHashOut {
|
||||
var hash PoseidonHashOut
|
||||
inputVars := []frontend.Variable{}
|
||||
|
||||
for i := 0; i < len(input); i++ {
|
||||
inputVars = append(inputVars, c.fieldAPI.Reduce(input[i]).Limbs[0])
|
||||
}
|
||||
|
||||
outputVars := c.HashNToMNoPad(inputVars, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
hash[i] = c.fieldAPI.NewElement(outputVars[i])
|
||||
}
|
||||
|
||||
return hash
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) ToVec(hash PoseidonHashOut) []field.F {
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) fullRounds(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
for i := 0; i < HALF_N_FULL_ROUNDS; i++ {
|
||||
state = c.constantLayer(state, roundCounter)
|
||||
state = c.sBoxLayer(state)
|
||||
state = c.mdsLayer(state)
|
||||
*roundCounter += 1
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) partialRounds(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
state = c.partialFirstConstantLayer(state)
|
||||
state = c.mdsPartialLayerInit(state)
|
||||
|
||||
for i := 0; i < N_PARTIAL_ROUNDS; i++ {
|
||||
state[0] = c.sBoxMonomial(state[0])
|
||||
state[0] = field.GoldilocksMulAdd(c.api, frontend.Variable(1), state[0], FAST_PARTIAL_ROUND_CONSTANTS[i])
|
||||
state = c.mdsPartialLayerFast(state, i)
|
||||
}
|
||||
|
||||
*roundCounter += N_PARTIAL_ROUNDS
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) constantLayer(state PoseidonState, roundCounter *int) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
roundConstant := ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]
|
||||
state[i] = field.GoldilocksMulAdd(c.api, frontend.Variable(1), state[i], roundConstant)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) ConstantLayerExtension(state PoseidonStateExtension, roundCounter *int) PoseidonStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
roundConstant := c.qeAPI.VarToQE(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])
|
||||
state[i] = c.qeAPI.AddExtension(state[i], roundConstant)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) sBoxMonomial(x frontend.Variable) frontend.Variable {
|
||||
x2 := field.GoldilocksMulAdd(c.api, x, x, frontend.Variable(0))
|
||||
x4 := field.GoldilocksMulAdd(c.api, x2, x2, frontend.Variable(0))
|
||||
x6 := field.GoldilocksMulAdd(c.api, x4, x2, frontend.Variable(0))
|
||||
return field.GoldilocksMulAdd(c.api, x6, x, frontend.Variable(0))
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) SBoxMonomialExtension(x field.QuadraticExtension) field.QuadraticExtension {
|
||||
x2 := c.qeAPI.SquareExtension(x)
|
||||
x4 := c.qeAPI.SquareExtension(x2)
|
||||
x3 := c.qeAPI.MulExtension(x, x2)
|
||||
return c.qeAPI.MulExtension(x3, x4)
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) sBoxLayer(state PoseidonState) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.sBoxMonomial(state[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) SBoxLayerExtension(state PoseidonStateExtension) PoseidonStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.SBoxMonomialExtension(state[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsRowShf(r int, v [SPONGE_WIDTH]frontend.Variable) frontend.Variable {
|
||||
res := ZERO_VAR
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
res = field.GoldilocksMulAdd(c.api, v[(i+r)%SPONGE_WIDTH], MDS_MATRIX_CIRC_VARS[i], res)
|
||||
}
|
||||
}
|
||||
|
||||
res = field.GoldilocksMulAdd(c.api, v[r], MDS_MATRIX_DIAG_VARS[r], res)
|
||||
return res
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]field.QuadraticExtension) field.QuadraticExtension {
|
||||
res := c.qeAPI.FieldToQE(field.ZERO_F)
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
matrixVal := c.qeAPI.VarToQE(MDS_MATRIX_CIRC[i])
|
||||
res1 := c.qeAPI.MulExtension(v[(i+r)%SPONGE_WIDTH], matrixVal)
|
||||
res = c.qeAPI.AddExtension(res, res1)
|
||||
}
|
||||
}
|
||||
|
||||
matrixVal := c.qeAPI.VarToQE(MDS_MATRIX_DIAG[r])
|
||||
res = c.qeAPI.AddExtension(res, c.qeAPI.MulExtension(v[r], matrixVal))
|
||||
return res
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsLayer(state_ PoseidonState) PoseidonState {
|
||||
var result PoseidonState
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
result[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
for r := 0; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
result[r] = c.mdsRowShf(r, state_)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) MdsLayerExtension(state_ PoseidonStateExtension) PoseidonStateExtension {
|
||||
var result PoseidonStateExtension
|
||||
|
||||
for r := 0; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
sum := c.MdsRowShfExtension(r, state_)
|
||||
result[r] = sum
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) partialFirstConstantLayer(state PoseidonState) PoseidonState {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = field.GoldilocksMulAdd(c.api, frontend.Variable(1), state[i], FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) PartialFirstConstantLayerExtension(state PoseidonStateExtension) PoseidonStateExtension {
|
||||
for i := 0; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
state[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.VarToQE((FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])))
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsPartialLayerInit(state PoseidonState) PoseidonState {
|
||||
var result PoseidonState
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
result[0] = state[0]
|
||||
|
||||
for r := 1; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
for d := 1; d < 12; d++ {
|
||||
if d < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]
|
||||
result[d] = field.GoldilocksMulAdd(c.api, state[r], t, result[d])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension) PoseidonStateExtension {
|
||||
var result PoseidonStateExtension
|
||||
for i := 0; i < 12; i++ {
|
||||
result[i] = c.qeAPI.FieldToQE(field.ZERO_F)
|
||||
}
|
||||
|
||||
result[0] = state[0]
|
||||
|
||||
for r := 1; r < 12; r++ {
|
||||
if r < SPONGE_WIDTH {
|
||||
for d := 1; d < 12; d++ {
|
||||
if d < SPONGE_WIDTH {
|
||||
t := c.qeAPI.VarToQE(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
|
||||
result[d] = c.qeAPI.AddExtension(result[d], c.qeAPI.MulExtension(state[r], t))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) mdsPartialLayerFast(state PoseidonState, r int) PoseidonState {
|
||||
dSum := ZERO_VAR
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_W_HATS_VARS[r][i-1]
|
||||
dSum = field.GoldilocksMulAdd(c.api, state[i], t, dSum)
|
||||
}
|
||||
}
|
||||
|
||||
d := field.GoldilocksMulAdd(c.api, state[0], MDS0TO0_VAR, dSum)
|
||||
|
||||
var result PoseidonState
|
||||
for i := 0; i < SPONGE_WIDTH; i++ {
|
||||
result[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
result[0] = d
|
||||
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := FAST_PARTIAL_ROUND_VS[r][i-1]
|
||||
result[i] = field.GoldilocksMulAdd(c.api, state[0], t, state[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonChip) MdsPartialLayerFastExtension(state PoseidonStateExtension, r int) PoseidonStateExtension {
|
||||
s0 := state[0]
|
||||
mds0to0 := c.qeAPI.VarToQE(MDS0TO0)
|
||||
d := c.qeAPI.MulExtension(s0, mds0to0)
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := c.qeAPI.VarToQE(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
|
||||
d = c.qeAPI.AddExtension(d, c.qeAPI.MulExtension(state[i], t))
|
||||
}
|
||||
}
|
||||
|
||||
var result PoseidonStateExtension
|
||||
result[0] = d
|
||||
for i := 1; i < 12; i++ {
|
||||
if i < SPONGE_WIDTH {
|
||||
t := c.qeAPI.VarToQE(FAST_PARTIAL_ROUND_VS[r][i-1])
|
||||
result[i] = c.qeAPI.AddExtension(c.qeAPI.MulExtension(state[0], t), state[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package poseidon
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/field"
|
||||
)
|
||||
|
||||
const fullRounds = 8
|
||||
const partialRounds = 56
|
||||
const spongeWidth = 4
|
||||
const spongeRate = 3
|
||||
|
||||
type PoseidonBN128Chip struct {
|
||||
api frontend.API `gnark:"-"`
|
||||
fieldAPI field.FieldAPI `gnark:"-"`
|
||||
}
|
||||
|
||||
type PoseidonBN128State = [spongeWidth]frontend.Variable
|
||||
type PoseidonBN128HashOut = frontend.Variable
|
||||
|
||||
// This implementation is based on the following implementation:
|
||||
// https://github.com/iden3/go-iden3-crypto/blob/e5cf066b8be3da9a3df9544c65818df189fdbebe/poseidon/poseidon.go
|
||||
func NewPoseidonBN128Chip(api frontend.API, fieldAPI field.FieldAPI) *PoseidonBN128Chip {
|
||||
return &PoseidonBN128Chip{api: api, fieldAPI: fieldAPI}
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) Poseidon(state PoseidonBN128State) PoseidonBN128State {
|
||||
state = c.ark(state, 0)
|
||||
state = c.fullRounds(state, true)
|
||||
state = c.partialRounds(state)
|
||||
state = c.fullRounds(state, false)
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) HashNoPad(input []field.F) PoseidonBN128HashOut {
|
||||
state := PoseidonBN128State{
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
frontend.Variable(0),
|
||||
}
|
||||
|
||||
for i := 0; i < len(input); i += spongeRate * 3 {
|
||||
endI := c.min(len(input), i+spongeRate*3)
|
||||
rateChunk := input[i:endI]
|
||||
for j, stateIdx := 0, 0; j < len(rateChunk); j, stateIdx = j+3, stateIdx+1 {
|
||||
endJ := c.min(len(rateChunk), j+3)
|
||||
bn128Chunk := rateChunk[j:endJ]
|
||||
|
||||
bits := []frontend.Variable{}
|
||||
for k := 0; k < len(bn128Chunk); k++ {
|
||||
bn128Chunk[k] = c.fieldAPI.Reduce(bn128Chunk[k])
|
||||
bits = append(bits, c.fieldAPI.ToBits(bn128Chunk[k])...)
|
||||
}
|
||||
|
||||
state[stateIdx+1] = c.api.FromBinary(bits...)
|
||||
}
|
||||
|
||||
state = c.Poseidon(state)
|
||||
}
|
||||
|
||||
return PoseidonBN128HashOut(state[0])
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) HashOrNoop(input []field.F) PoseidonBN128HashOut {
|
||||
if len(input) <= 3 {
|
||||
returnVal := frontend.Variable(0)
|
||||
|
||||
alpha := new(big.Int).SetInt64(1 << 32)
|
||||
for i, inputElement := range input {
|
||||
returnVal = c.api.Add(returnVal, c.api.Mul(inputElement, alpha.Exp(alpha, big.NewInt(int64(i)), nil)))
|
||||
}
|
||||
|
||||
return PoseidonBN128HashOut(returnVal)
|
||||
} else {
|
||||
return c.HashNoPad(input)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) TwoToOne(left PoseidonBN128HashOut, right PoseidonBN128HashOut) PoseidonBN128HashOut {
|
||||
var inputs PoseidonBN128State
|
||||
inputs[0] = frontend.Variable(0)
|
||||
inputs[1] = frontend.Variable(0)
|
||||
inputs[2] = left
|
||||
inputs[3] = right
|
||||
state := c.Poseidon(inputs)
|
||||
return state[0]
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) ToVec(hash PoseidonBN128HashOut) []field.F {
|
||||
bits := c.api.ToBinary(hash)
|
||||
|
||||
returnElements := []field.F{}
|
||||
|
||||
// Split into 7 byte chunks, since 8 byte chunks can result in collisions
|
||||
chunkSize := 56
|
||||
for i := 0; i < len(bits); i += chunkSize {
|
||||
maxIdx := c.min(len(bits), i+chunkSize)
|
||||
bitChunk := bits[i:maxIdx]
|
||||
returnElements = append(returnElements, c.fieldAPI.FromBits(bitChunk...))
|
||||
}
|
||||
|
||||
return returnElements
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) min(x, y int) int {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
|
||||
return y
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) fullRounds(state PoseidonBN128State, isFirst bool) PoseidonBN128State {
|
||||
for i := 0; i < fullRounds/2-1; i++ {
|
||||
state = c.exp5state(state)
|
||||
if isFirst {
|
||||
state = c.ark(state, (i+1)*spongeWidth)
|
||||
} else {
|
||||
state = c.ark(state, (fullRounds/2+1)*spongeWidth+partialRounds+i*spongeWidth)
|
||||
}
|
||||
state = c.mix(state, mMatrix)
|
||||
}
|
||||
|
||||
state = c.exp5state(state)
|
||||
if isFirst {
|
||||
state = c.ark(state, (fullRounds/2)*spongeWidth)
|
||||
state = c.mix(state, pMatrix)
|
||||
} else {
|
||||
state = c.mix(state, mMatrix)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) partialRounds(state PoseidonBN128State) PoseidonBN128State {
|
||||
for i := 0; i < partialRounds; i++ {
|
||||
state[0] = c.exp5(state[0])
|
||||
state[0] = c.api.Add(state[0], cConstants[(fullRounds/2+1)*spongeWidth+i])
|
||||
|
||||
var mul frontend.Variable
|
||||
newState0 := frontend.Variable(0)
|
||||
for j := 0; j < spongeWidth; j++ {
|
||||
mul = c.api.Mul(sConstants[(spongeWidth*2-1)*i+j], state[j])
|
||||
newState0 = c.api.Add(newState0, mul)
|
||||
}
|
||||
|
||||
for k := 1; k < spongeWidth; k++ {
|
||||
mul = c.api.Mul(state[0], sConstants[(spongeWidth*2-1)*i+spongeWidth+k-1])
|
||||
state[k] = c.api.Add(state[k], mul)
|
||||
}
|
||||
state[0] = newState0
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) ark(state PoseidonBN128State, it int) PoseidonBN128State {
|
||||
var result PoseidonBN128State
|
||||
|
||||
for i := 0; i < len(state); i++ {
|
||||
result[i] = c.api.Add(state[i], cConstants[it+i])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) exp5(x frontend.Variable) frontend.Variable {
|
||||
x2 := c.api.Mul(x, x)
|
||||
x4 := c.api.Mul(x2, x2)
|
||||
return c.api.Mul(x4, x)
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) exp5state(state PoseidonBN128State) PoseidonBN128State {
|
||||
for i := 0; i < spongeWidth; i++ {
|
||||
state[i] = c.exp5(state[i])
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *PoseidonBN128Chip) mix(state_ PoseidonBN128State, constantMatrix [][]*big.Int) PoseidonBN128State {
|
||||
var mul frontend.Variable
|
||||
var result PoseidonBN128State
|
||||
|
||||
for i := 0; i < spongeWidth; i++ {
|
||||
result[i] = frontend.Variable(0)
|
||||
}
|
||||
|
||||
for i := 0; i < spongeWidth; i++ {
|
||||
for j := 0; j < spongeWidth; j++ {
|
||||
mul = c.api.Mul(constantMatrix[j][i], state_[j])
|
||||
result[i] = c.api.Add(result[i], mul)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -7,8 +7,7 @@ import (
|
||||
"github.com/consensys/gnark/backend"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/test"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/field"
|
||||
"github.com/succinctlabs/gnark-plonky2-verifier/utils"
|
||||
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
|
||||
)
|
||||
|
||||
var testCurve = ecc.BN254
|
||||
@@ -19,22 +18,22 @@ type TestPublicInputsHashCircuit struct {
|
||||
}
|
||||
|
||||
func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error {
|
||||
fieldAPI := field.NewFieldAPI(api)
|
||||
glAPI := gl.NewChip(api)
|
||||
|
||||
// BN254 -> Binary(64) -> F
|
||||
var input [3]field.F
|
||||
var input [3]gl.Variable
|
||||
for i := 0; i < 3; i++ {
|
||||
input[i] = fieldAPI.FromBits(api.ToBinary(circuit.In[i], 64)...)
|
||||
input[i] = gl.NewVariable(api.FromBinary(api.ToBinary(circuit.In[i], 64)...))
|
||||
}
|
||||
|
||||
poseidonChip := &PoseidonChip{api: api, fieldAPI: fieldAPI}
|
||||
poseidonChip := &GoldilocksChip{api: api, gl: *glAPI}
|
||||
output := poseidonChip.HashNoPad(input[:])
|
||||
|
||||
// Check that output is correct
|
||||
for i := 0; i < 4; i++ {
|
||||
fieldAPI.AssertIsEqual(
|
||||
glAPI.AssertIsEqual(
|
||||
output[i],
|
||||
fieldAPI.FromBits(api.ToBinary(circuit.Out[i])...),
|
||||
gl.NewVariable(api.FromBinary(api.ToBinary(circuit.Out[i])...)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -55,8 +54,8 @@ func TestPublicInputsHashWitness(t *testing.T) {
|
||||
outStr := []string{"8416658900775745054", "12574228347150446423", "9629056739760131473", "3119289788404190010"}
|
||||
var in [3]frontend.Variable
|
||||
var out [4]frontend.Variable
|
||||
copy(in[:], utils.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], utils.StrArrayToFrontendVariableArray(outStr))
|
||||
copy(in[:], gl.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], gl.StrArrayToFrontendVariableArray(outStr))
|
||||
testCase(in, out)
|
||||
}
|
||||
|
||||
@@ -67,8 +66,8 @@ func TestPublicInputsHashWitness2(t *testing.T) {
|
||||
outStr := []string{"8416658900775745054", "12574228347150446423", "9629056739760131473", "3119289788404190010"}
|
||||
var in [3]frontend.Variable
|
||||
var out [4]frontend.Variable
|
||||
copy(in[:], utils.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], utils.StrArrayToFrontendVariableArray(outStr))
|
||||
copy(in[:], gl.StrArrayToFrontendVariableArray(inStr))
|
||||
copy(out[:], gl.StrArrayToFrontendVariableArray(outStr))
|
||||
|
||||
circuit := TestPublicInputsHashCircuit{In: in, Out: out}
|
||||
witness := TestPublicInputsHashCircuit{In: in, Out: out}
|
||||
|
||||
Reference in New Issue
Block a user