You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

204 lines
5.1 KiB

package poseidon
// This is a customized implementation of the Poseidon hash function inside the BN254 field.
// This implementation is based on the following implementation:
//
// https://github.com/iden3/go-iden3-crypto/blob/master/poseidon/poseidon.go
//
// The input and output are modified to ingest Goldilocks field elements.
import (
"math/big"
"github.com/consensys/gnark/frontend"
gl "github.com/succinctlabs/gnark-plonky2-verifier/goldilocks"
)
const BN254_FULL_ROUNDS int = 8
const BN254_PARTIAL_ROUNDS int = 56
const BN254_SPONGE_WIDTH int = 4
const BN254_SPONGE_RATE int = 3
type BN254Chip struct {
api frontend.API `gnark:"-"`
gl gl.Chip `gnark:"-"`
}
type BN254State = [BN254_SPONGE_WIDTH]frontend.Variable
type BN254HashOut = frontend.Variable
func NewBN254Chip(api frontend.API) *BN254Chip {
return &BN254Chip{api: api, gl: *gl.NewChip(api)}
}
func (c *BN254Chip) Poseidon(state BN254State) BN254State {
state = c.ark(state, 0)
state = c.fullRounds(state, true)
state = c.partialRounds(state)
state = c.fullRounds(state, false)
return state
}
func (c *BN254Chip) HashNoPad(input []gl.Variable) BN254HashOut {
state := BN254State{
frontend.Variable(0),
frontend.Variable(0),
frontend.Variable(0),
frontend.Variable(0),
}
for i := 0; i < len(input); i += BN254_SPONGE_RATE * 3 {
endI := c.min(len(input), i+BN254_SPONGE_RATE*3)
rateChunk := input[i:endI]
for j, stateIdx := 0, 0; j < len(rateChunk); j, stateIdx = j+3, stateIdx+1 {
endJ := c.min(len(rateChunk), j+3)
bn254Chunk := rateChunk[j:endJ]
bits := []frontend.Variable{}
for k := 0; k < len(bn254Chunk); k++ {
bn254Chunk[k] = c.gl.Reduce(bn254Chunk[k])
bits = append(bits, c.api.ToBinary(bn254Chunk[k].Limb, 64)...)
}
state[stateIdx+1] = c.api.FromBinary(bits...)
}
state = c.Poseidon(state)
}
return BN254HashOut(state[0])
}
func (c *BN254Chip) HashOrNoop(input []gl.Variable) BN254HashOut {
if len(input) <= 3 {
returnVal := frontend.Variable(0)
alpha := new(big.Int).SetInt64(1 << 32)
for i, inputElement := range input {
returnVal = c.api.Add(returnVal, c.api.Mul(inputElement, alpha.Exp(alpha, big.NewInt(int64(i)), nil)))
}
return BN254HashOut(returnVal)
} else {
return c.HashNoPad(input)
}
}
func (c *BN254Chip) TwoToOne(left BN254HashOut, right BN254HashOut) BN254HashOut {
var inputs BN254State
inputs[0] = frontend.Variable(0)
inputs[1] = frontend.Variable(0)
inputs[2] = left
inputs[3] = right
state := c.Poseidon(inputs)
return state[0]
}
func (c *BN254Chip) ToVec(hash BN254HashOut) []gl.Variable {
bits := c.api.ToBinary(hash)
returnElements := []gl.Variable{}
// Split into 7 byte chunks, since 8 byte chunks can result in collisions
chunkSize := 56
for i := 0; i < len(bits); i += chunkSize {
maxIdx := c.min(len(bits), i+chunkSize)
bitChunk := bits[i:maxIdx]
returnElements = append(returnElements, gl.NewVariable(c.api.FromBinary(bitChunk...)))
}
return returnElements
}
func (c *BN254Chip) min(x, y int) int {
if x < y {
return x
}
return y
}
func (c *BN254Chip) fullRounds(state BN254State, isFirst bool) BN254State {
for i := 0; i < BN254_FULL_ROUNDS/2-1; i++ {
state = c.exp5state(state)
if isFirst {
state = c.ark(state, (i+1)*BN254_SPONGE_WIDTH)
} else {
state = c.ark(state, (BN254_FULL_ROUNDS/2+1)*BN254_SPONGE_WIDTH+BN254_PARTIAL_ROUNDS+i*BN254_SPONGE_WIDTH)
}
state = c.mix(state, mMatrix)
}
state = c.exp5state(state)
if isFirst {
state = c.ark(state, (BN254_FULL_ROUNDS/2)*BN254_SPONGE_WIDTH)
state = c.mix(state, pMatrix)
} else {
state = c.mix(state, mMatrix)
}
return state
}
func (c *BN254Chip) partialRounds(state BN254State) BN254State {
for i := 0; i < BN254_PARTIAL_ROUNDS; i++ {
state[0] = c.exp5(state[0])
state[0] = c.api.Add(state[0], cConstants[(BN254_FULL_ROUNDS/2+1)*BN254_SPONGE_WIDTH+i])
var mul frontend.Variable
newState0 := frontend.Variable(0)
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
mul = c.api.Mul(sConstants[(BN254_SPONGE_WIDTH*2-1)*i+j], state[j])
newState0 = c.api.Add(newState0, mul)
}
for k := 1; k < BN254_SPONGE_WIDTH; k++ {
mul = c.api.Mul(state[0], sConstants[(BN254_SPONGE_WIDTH*2-1)*i+BN254_SPONGE_WIDTH+k-1])
state[k] = c.api.Add(state[k], mul)
}
state[0] = newState0
}
return state
}
func (c *BN254Chip) ark(state BN254State, it int) BN254State {
var result BN254State
for i := 0; i < len(state); i++ {
result[i] = c.api.Add(state[i], cConstants[it+i])
}
return result
}
func (c *BN254Chip) exp5(x frontend.Variable) frontend.Variable {
x2 := c.api.Mul(x, x)
x4 := c.api.Mul(x2, x2)
return c.api.Mul(x4, x)
}
func (c *BN254Chip) exp5state(state BN254State) BN254State {
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
state[i] = c.exp5(state[i])
}
return state
}
func (c *BN254Chip) mix(state_ BN254State, constantMatrix [][]*big.Int) BN254State {
var mul frontend.Variable
var result BN254State
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
result[i] = frontend.Variable(0)
}
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
mul = c.api.Mul(constantMatrix[j][i], state_[j])
result[i] = c.api.Add(result[i], mul)
}
}
return result
}