You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

242 lines
5.4 KiB

package poseidon
import (
"errors"
"fmt"
"math/big"
"github.com/iden3/go-iden3-crypto/ff"
"github.com/iden3/go-iden3-crypto/utils"
)
// NROUNDSF is the total number of rounds in which Hash applies the S-box to
// every state element (constant from the Poseidon paper). Hash runs half of
// them before and half of them after the NROUNDSP rounds.
const NROUNDSF = 8
// NROUNDSP lists the number of rounds in which Hash applies the S-box to
// state[0] only (constants from the Poseidon paper). Hash reads entry t-2 for a
// state of t elements, so entry 0 serves a single input and the length of this
// slice is the maximum number of inputs Hash accepts.
var NROUNDSP = []int{56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68}
// spongeChunkSize is the number of message bytes HashBytesX packs into one
// hash input.
const spongeChunkSize = 31
// spongeInputs is the frame size that HashBytes and SpongeHash pass to
// HashBytesX and SpongeHashX.
const spongeInputs = 16
// zero returns a newly allocated field element obtained from ff.NewElement.
func zero() *ff.Element {
	fresh := ff.NewElement()
	return fresh
}
// big5 is the exponent of the Poseidon S-box.
var big5 = big.NewInt(5)

// exp5 replaces a, in place, with a raised to the fifth power.
// See https://eprint.iacr.org/2019/458.pdf page 8.
func exp5(a *ff.Element) {
	// Exp takes its base by value, so hand it a copy of the current element.
	base := *a
	a.Exp(base, big5)
}
// exp5state applies exp5 to every element of state, in order, modifying each
// element in place.
func exp5state(state []*ff.Element) {
	for _, elem := range state {
		exp5(elem)
	}
}
// ark performs the Add-Round-Key step from the paper
// https://eprint.iacr.org/2019/458.pdf: element i of state is increased, in
// place, by the constant c[it+i]. The caller must ensure that c holds at least
// it+len(state) entries.
func ark(state, c []*ff.Element, it int) {
	for pos, elem := range state {
		key := c[it+pos]
		elem.Add(elem, key)
	}
}
// mix multiplies state, taken as a row vector, by the matrix m and returns the
// product in a newly allocated slice of t elements:
//
//	out[i] = sum over j of m[j][i] * state[j]
//
// Only the first len(state) entries of the result are computed; any further
// entries (when t > len(state)) are left as returned by zero(). The function
// writes only to elements it allocates itself.
func mix(state []*ff.Element, t int, m [][]*ff.Element) []*ff.Element {
	// Scratch element reused for every product.
	term := zero()

	out := make([]*ff.Element, t)
	for idx := range out {
		out[idx] = zero()
	}

	width := len(state)
	for col := 0; col < width; col++ {
		acc := out[col]
		// Start the column sum from zero.
		acc.SetUint64(0)
		for row := 0; row < width; row++ {
			term.Mul(m[row][col], state[row])
			acc.Add(acc, term)
		}
	}
	return out
}
// Hash computes the Poseidon hash for the given inputs.
//
// It accepts between 1 and len(NROUNDSP) inputs. Any other count, or a set of
// inputs rejected by utils.CheckBigIntArrayInField, yields a nil result and a
// non-nil error.
//
// The permutation runs on a state of t = len(inpBI)+1 elements: state[0] is a
// fresh element from zero() and the converted inputs fill state[1:]. The
// parameter sets C, S, M and P for that width are read from the package-level
// value c at index t-2. The returned value is state[0] of the final state,
// converted back to a *big.Int.
func Hash(inpBI []*big.Int) (*big.Int, error) {
t := len(inpBI) + 1
if len(inpBI) == 0 || len(inpBI) > len(NROUNDSP) {
return nil, fmt.Errorf("invalid inputs length %d, max %d", len(inpBI), len(NROUNDSP))
}
if !utils.CheckBigIntArrayInField(inpBI) {
return nil, errors.New("inputs values not inside Finite Field")
}
inp := utils.BigIntArrayToElementArray(inpBI)
nRoundsF := NROUNDSF
nRoundsP := NROUNDSP[t-2]
// Parameter sets for a state of t elements.
C := c.c[t-2]
S := c.s[t-2]
M := c.m[t-2]
P := c.p[t-2]
state := make([]*ff.Element, t)
state[0] = zero()
copy(state[1:], inp)
// Add the first t constants of C before any S-box is applied.
ark(state, C, 0)
// First half of the whole-state rounds, except its last one: S-box on every
// element, next block of t constants, then mixing with M.
for i := 0; i < nRoundsF/2-1; i++ {
exp5state(state)
ark(state, C, (i+1)*t)
state = mix(state, t, M)
}
// Last round of the first half: same shape, but mixed with P instead of M.
exp5state(state)
ark(state, C, (nRoundsF/2)*t)
state = mix(state, t, P)
mul := zero()
// nRoundsP rounds in which the S-box is applied to state[0] only. Each round
// consumes one constant from C and 2t-1 consecutive entries of S.
for i := 0; i < nRoundsP; i++ {
exp5(state[0])
state[0].Add(state[0], C[(nRoundsF/2+1)*t+i])
mul.SetZero()
newState0 := zero()
// The first t entries of this round's slice of S are combined with the whole
// state to form the next state[0].
for j := 0; j < len(state); j++ {
mul.Mul(S[(t*2-1)*i+j], state[j])
newState0.Add(newState0, mul)
}
// The remaining t-1 entries scale the current state[0] into state[1:]. This
// must run before state[0] is replaced below, because it reads the old value.
for k := 1; k < t; k++ {
mul.SetZero()
state[k] = state[k].Add(state[k], mul.Mul(state[0], S[(t*2-1)*i+t+k-1]))
}
state[0] = newState0
}
// Second half of the whole-state rounds, except its last one. The constants
// continue in C right after the nRoundsP single-element constants.
for i := 0; i < nRoundsF/2-1; i++ {
exp5state(state)
ark(state, C, (nRoundsF/2+1)*t+nRoundsP+i*t)
state = mix(state, t, M)
}
// Final round: S-box and mixing only, no constants are added.
exp5state(state)
state = mix(state, t, M)
// The hash is the first state element, exported through ToBigIntRegular.
rE := state[0]
r := big.NewInt(0)
rE.ToBigIntRegular(r)
return r, nil
}
// HashBytes returns a sponge hash of a msg byte slice split into blocks of 31 bytes
func HashBytes(msg []byte) (*big.Int, error) {
return HashBytesX(msg, spongeInputs)
}
// HashBytesX returns the sponge hash of msg, split into chunks of
// spongeChunkSize (31) bytes, using frames of frameSize inputs.
//
// frameSize must be between 2 and 16 inclusive, otherwise an error is
// returned. Each chunk is stored as one frame input. Whenever a frame is full
// it is passed to Hash, and the result becomes input 0 of the next frame, whose
// remaining inputs start at zero. A trailing chunk shorter than 31 bytes is
// zero-padded on the right. A partially filled frame left at the end is hashed
// with its unused inputs at zero.
//
// For an empty msg no hash is computed and the function returns (nil, nil).
func HashBytesX(msg []byte, frameSize int) (*big.Int, error) {
	if frameSize < 2 || frameSize > 16 {
		return nil, errors.New("incorrect frame size")
	}

	// Inputs that are never written keep the value zero.
	frame := make([]*big.Int, frameSize)
	for idx := range frame {
		frame[idx] = new(big.Int)
	}

	var (
		digest  *big.Int
		err     error
		pending bool // frame holds data that has not been hashed yet
		slot    int  // index of the next frame input to fill
	)

	fullChunks := len(msg) / spongeChunkSize
	for n := 0; n < fullChunks; n++ {
		begin := n * spongeChunkSize
		frame[slot].SetBytes(msg[begin : begin+spongeChunkSize])

		if slot != frameSize-1 {
			pending = true
			slot++
			continue
		}

		// The frame is full: absorb it and chain the result into a new frame.
		digest, err = Hash(frame)
		if err != nil {
			return nil, err
		}
		pending = false

		frame = make([]*big.Int, frameSize)
		frame[0] = digest
		for idx := 1; idx < frameSize; idx++ {
			frame[idx] = new(big.Int)
		}
		slot = 1
	}

	if tail := msg[fullChunks*spongeChunkSize:]; len(tail) != 0 {
		// The last chunk of the message is shorter than 31 bytes; pad it with
		// zeros on the right, so that 0xdeadbeaf becomes
		// 0xdeadbeaf000000000000000000000000000000000000000000000000000000
		var padded [spongeChunkSize]byte
		copy(padded[:], tail)
		frame[slot] = new(big.Int).SetBytes(padded[:])
		pending = true
	}

	if !pending {
		return digest, nil
	}

	// Data was added after the last call to Hash, so the frame must be
	// hashed once more here.
	digest, err = Hash(frame)
	if err != nil {
		return nil, err
	}
	return digest, nil
}
// SpongeHash returns a sponge hash of inputs (using Poseidon with frame size of 16 inputs)
func SpongeHash(inputs []*big.Int) (*big.Int, error) {
return SpongeHashX(inputs, spongeInputs)
}
// SpongeHashX returns the sponge hash of inputs using Poseidon with frames of
// frameSize inputs.
//
// frameSize must be between 2 and 16 inclusive, otherwise an error is
// returned. The inputs are placed into the frame one by one. Whenever the
// frame is full it is passed to Hash, and the result becomes input 0 of the
// next frame, whose remaining inputs start at zero. A partially filled frame
// left at the end is hashed with its unused inputs at zero.
//
// The frame stores the caller's *big.Int pointers directly rather than copies.
// For an empty inputs slice no hash is computed and the function returns
// (nil, nil).
func SpongeHashX(inputs []*big.Int, frameSize int) (*big.Int, error) {
	if frameSize < 2 || frameSize > 16 {
		return nil, errors.New("incorrect frame size")
	}

	// Frame inputs that are never written keep the value zero.
	buf := make([]*big.Int, frameSize)
	for idx := range buf {
		buf[idx] = new(big.Int)
	}

	var (
		digest  *big.Int
		err     error
		pending bool // buf holds data that has not been hashed yet
		slot    int  // index of the next frame input to fill
	)

	for _, in := range inputs {
		buf[slot] = in

		if slot != frameSize-1 {
			pending = true
			slot++
			continue
		}

		// The frame is full: absorb it and chain the result into a new frame.
		digest, err = Hash(buf)
		if err != nil {
			return nil, err
		}
		pending = false

		buf = make([]*big.Int, frameSize)
		buf[0] = digest
		for idx := 1; idx < frameSize; idx++ {
			buf[idx] = new(big.Int)
		}
		slot = 1
	}

	if !pending {
		return digest, nil
	}

	// Inputs were added after the last call to Hash, so the frame must be
	// hashed once more here.
	digest, err = Hash(buf)
	if err != nil {
		return nil, err
	}
	return digest, nil
}