|
|
package poseidon
import ( "errors" "fmt" "math/big"
"github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" )
const (
	// NROUNDSF is the total number of full rounds of the Poseidon
	// permutation (constant from the Poseidon paper). Hash applies half of
	// them before the partial rounds and half after.
	NROUNDSF = 8

	// spongeChunkSize is the number of message bytes that HashBytesX packs
	// into a single sponge input.
	spongeChunkSize = 31

	// spongeInputs is the frame size used by HashBytes and SpongeHash.
	spongeInputs = 16
)

// NROUNDSP holds the number of partial rounds of the Poseidon permutation
// (constants from the Poseidon paper). Hash reads entry t-2, where t is the
// number of inputs plus one, so entry 0 belongs to a single-input hash and
// the slice length bounds the number of inputs Hash accepts.
var NROUNDSP = []int{
	56, 57, 56, 60,
	60, 63, 64, 63,
	60, 66, 60, 65,
	70, 60, 64, 68,
}
// zero returns a freshly created field element obtained from ff.NewElement.
// The rest of this file uses it as the starting value for accumulators and
// for the first slot of the hash state.
func zero() *ff.Element {
	elem := ff.NewElement()
	return elem
}
// big5 is the exponent applied by exp5.
var big5 = big.NewInt(5)

// exp5 raises a to the fifth power in place (x^5 in the field).
// See https://eprint.iacr.org/2019/458.pdf page 8.
func exp5(a *ff.Element) {
	// Exp takes its base by value, so hand it an explicit copy of a.
	base := *a
	a.Exp(base, big5)
}
// exp5state applies exp5 to every element of state, modifying each element
// in place.
func exp5state(state []*ff.Element) {
	for _, elem := range state {
		exp5(elem)
	}
}
// ark computes the Add-Round Key step from the paper
// https://eprint.iacr.org/2019/458.pdf: for every index i of state it adds
// c[it+i] to state[i] in place. it is the offset into c of the first
// constant to use.
func ark(state, c []*ff.Element, it int) {
	for i, elem := range state {
		elem.Add(elem, c[it+i])
	}
}
// mix returns [[matrix]] * [vector]
func mix(state []*ff.Element, t int, m [][]*ff.Element) []*ff.Element { mul := zero() newState := make([]*ff.Element, t) for i := 0; i < t; i++ { newState[i] = zero() } for i := 0; i < len(state); i++ { newState[i].SetUint64(0) for j := 0; j < len(state); j++ { mul.Mul(m[j][i], state[j]) newState[i].Add(newState[i], mul) } } return newState }
// Hash computes the Poseidon hash for the given inputs
func Hash(inpBI []*big.Int) (*big.Int, error) { t := len(inpBI) + 1 if len(inpBI) == 0 || len(inpBI) > len(NROUNDSP) { return nil, fmt.Errorf("invalid inputs length %d, max %d", len(inpBI), len(NROUNDSP)) } if !utils.CheckBigIntArrayInField(inpBI) { return nil, errors.New("inputs values not inside Finite Field") } inp := utils.BigIntArrayToElementArray(inpBI)
nRoundsF := NROUNDSF nRoundsP := NROUNDSP[t-2] C := c.c[t-2] S := c.s[t-2] M := c.m[t-2] P := c.p[t-2]
state := make([]*ff.Element, t) state[0] = zero() copy(state[1:], inp)
ark(state, C, 0)
for i := 0; i < nRoundsF/2-1; i++ { exp5state(state) ark(state, C, (i+1)*t) state = mix(state, t, M) } exp5state(state) ark(state, C, (nRoundsF/2)*t) state = mix(state, t, P)
mul := zero() for i := 0; i < nRoundsP; i++ { exp5(state[0]) state[0].Add(state[0], C[(nRoundsF/2+1)*t+i])
mul.SetZero() newState0 := zero() for j := 0; j < len(state); j++ { mul.Mul(S[(t*2-1)*i+j], state[j]) newState0.Add(newState0, mul) }
for k := 1; k < t; k++ { mul.SetZero() state[k] = state[k].Add(state[k], mul.Mul(state[0], S[(t*2-1)*i+t+k-1])) } state[0] = newState0 }
for i := 0; i < nRoundsF/2-1; i++ { exp5state(state) ark(state, C, (nRoundsF/2+1)*t+nRoundsP+i*t) state = mix(state, t, M) } exp5state(state) state = mix(state, t, M)
rE := state[0] r := big.NewInt(0) rE.ToBigIntRegular(r) return r, nil }
// HashBytes returns a sponge hash of a msg byte slice split into blocks of
// 31 bytes. It is HashBytesX called with the default frame size
// spongeInputs.
func HashBytes(msg []byte) (*big.Int, error) {
	digest, err := HashBytesX(msg, spongeInputs)
	return digest, err
}
// HashBytesX returns a sponge hash of a msg byte slice split into blocks of 31 bytes
func HashBytesX(msg []byte, frameSize int) (*big.Int, error) { if frameSize < 2 || frameSize > 16 { return nil, errors.New("incorrect frame size") }
// not used inputs default to zero
inputs := make([]*big.Int, frameSize) for j := 0; j < frameSize; j++ { inputs[j] = new(big.Int) } dirty := false var hash *big.Int var err error
k := 0 for i := 0; i < len(msg)/spongeChunkSize; i++ { dirty = true inputs[k].SetBytes(msg[spongeChunkSize*i : spongeChunkSize*(i+1)]) if k == frameSize-1 { hash, err = Hash(inputs) dirty = false if err != nil { return nil, err } inputs = make([]*big.Int, frameSize) inputs[0] = hash for j := 1; j < frameSize; j++ { inputs[j] = new(big.Int) } k = 1 } else { k++ } }
if len(msg)%spongeChunkSize != 0 { // the last chunk of the message is less than 31 bytes
// zero padding it, so that 0xdeadbeaf becomes
// 0xdeadbeaf000000000000000000000000000000000000000000000000000000
var buf [spongeChunkSize]byte copy(buf[:], msg[(len(msg)/spongeChunkSize)*spongeChunkSize:]) inputs[k] = new(big.Int).SetBytes(buf[:]) dirty = true }
if dirty { // we haven't hashed something in the main sponge loop and need to do hash here
hash, err = Hash(inputs) if err != nil { return nil, err } }
return hash, nil }
// SpongeHash returns a sponge hash of inputs using Poseidon with the default
// frame size spongeInputs (16 inputs). It is SpongeHashX called with that
// frame size.
func SpongeHash(inputs []*big.Int) (*big.Int, error) {
	digest, err := SpongeHashX(inputs, spongeInputs)
	return digest, err
}
// SpongeHashX returns a sponge hash of inputs using Poseidon with configurable frame size
func SpongeHashX(inputs []*big.Int, frameSize int) (*big.Int, error) { if frameSize < 2 || frameSize > 16 { return nil, errors.New("incorrect frame size") }
// not used frame default to zero
frame := make([]*big.Int, frameSize) for j := 0; j < frameSize; j++ { frame[j] = new(big.Int) } dirty := false var hash *big.Int var err error
k := 0 for i := 0; i < len(inputs); i++ { dirty = true frame[k] = inputs[i] if k == frameSize-1 { hash, err = Hash(frame) dirty = false if err != nil { return nil, err } frame = make([]*big.Int, frameSize) frame[0] = hash for j := 1; j < frameSize; j++ { frame[j] = new(big.Int) } k = 1 } else { k++ } }
if dirty { // we haven't hashed something in the main sponge loop and need to do hash here
hash, err = Hash(frame) if err != nil { return nil, err } }
return hash, nil }
|