176 lines
4.0 KiB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. package poseidon
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/big"
  6. "github.com/iden3/go-iden3-crypto/ff"
  7. "github.com/iden3/go-iden3-crypto/utils"
  8. )
  9. const NROUNDSF = 8 //nolint:golint
  10. var NROUNDSP = []int{56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68} //nolint:golint
  11. const spongeChunkSize = 31
  12. const spongeInputs = 16
  13. func zero() *ff.Element {
  14. return ff.NewElement()
  15. }
  16. // exp5 performs x^5 mod p
  17. // https://eprint.iacr.org/2019/458.pdf page 8
  18. func exp5(a *ff.Element) {
  19. a.Exp(*a, big.NewInt(5)) //nolint:gomnd
  20. }
  21. // exp5state perform exp5 for whole state
  22. func exp5state(state []*ff.Element) {
  23. for i := 0; i < len(state); i++ {
  24. exp5(state[i])
  25. }
  26. }
  27. // ark computes Add-Round Key, from the paper https://eprint.iacr.org/2019/458.pdf
  28. func ark(state []*ff.Element, c []*ff.Element, it int) {
  29. for i := 0; i < len(state); i++ {
  30. state[i].Add(state[i], c[it+i])
  31. }
  32. }
  33. // mix returns [[matrix]] * [vector]
  34. func mix(state []*ff.Element, t int, m [][]*ff.Element) []*ff.Element {
  35. mul := zero()
  36. newState := make([]*ff.Element, t)
  37. for i := 0; i < t; i++ {
  38. newState[i] = zero()
  39. }
  40. for i := 0; i < len(state); i++ {
  41. newState[i].SetUint64(0)
  42. for j := 0; j < len(state); j++ {
  43. mul.Mul(m[j][i], state[j])
  44. newState[i].Add(newState[i], mul)
  45. }
  46. }
  47. return newState
  48. }
  49. // Hash computes the Poseidon hash for the given inputs
  50. func Hash(inpBI []*big.Int) (*big.Int, error) {
  51. t := len(inpBI) + 1
  52. if len(inpBI) == 0 || len(inpBI) > len(NROUNDSP) {
  53. return nil, fmt.Errorf("invalid inputs length %d, max %d", len(inpBI), len(NROUNDSP)) //nolint:gomnd,lll
  54. }
  55. if !utils.CheckBigIntArrayInField(inpBI[:]) {
  56. return nil, errors.New("inputs values not inside Finite Field")
  57. }
  58. inp := utils.BigIntArrayToElementArray(inpBI[:])
  59. nRoundsF := NROUNDSF
  60. nRoundsP := NROUNDSP[t-2]
  61. C := c.c[t-2]
  62. S := c.s[t-2]
  63. M := c.m[t-2]
  64. P := c.p[t-2]
  65. state := make([]*ff.Element, t)
  66. state[0] = zero()
  67. copy(state[1:], inp[:])
  68. ark(state, C, 0)
  69. for i := 0; i < nRoundsF/2-1; i++ {
  70. exp5state(state)
  71. ark(state, C, (i+1)*t)
  72. state = mix(state, t, M)
  73. }
  74. exp5state(state)
  75. ark(state, C, (nRoundsF/2)*t)
  76. state = mix(state, t, P)
  77. for i := 0; i < nRoundsP; i++ {
  78. exp5(state[0])
  79. state[0].Add(state[0], C[(nRoundsF/2+1)*t+i])
  80. mul := zero()
  81. newState0 := zero()
  82. for j := 0; j < len(state); j++ {
  83. mul.Mul(S[(t*2-1)*i+j], state[j])
  84. newState0.Add(newState0, mul)
  85. }
  86. for k := 1; k < t; k++ {
  87. mul = zero()
  88. state[k] = state[k].Add(state[k], mul.Mul(state[0], S[(t*2-1)*i+t+k-1]))
  89. }
  90. state[0] = newState0
  91. }
  92. for i := 0; i < nRoundsF/2-1; i++ {
  93. exp5state(state)
  94. ark(state, C, (nRoundsF/2+1)*t+nRoundsP+i*t)
  95. state = mix(state, t, M)
  96. }
  97. exp5state(state)
  98. state = mix(state, t, M)
  99. rE := state[0]
  100. r := big.NewInt(0)
  101. rE.ToBigIntRegular(r)
  102. return r, nil
  103. }
  104. // HashBytes returns a sponge hash of a msg byte slice split into blocks of 31 bytes
  105. func HashBytes(msg []byte) (*big.Int, error) {
  106. // not used inputs default to zero
  107. inputs := make([]*big.Int, spongeInputs)
  108. for j := 0; j < spongeInputs; j++ {
  109. inputs[j] = new(big.Int)
  110. }
  111. dirty := false
  112. var hash *big.Int
  113. var err error
  114. k := 0
  115. for i := 0; i < len(msg)/spongeChunkSize; i++ {
  116. dirty = true
  117. inputs[k].SetBytes(msg[spongeChunkSize*i : spongeChunkSize*(i+1)])
  118. if k == spongeInputs-1 {
  119. hash, err = Hash(inputs)
  120. dirty = false
  121. if err != nil {
  122. return nil, err
  123. }
  124. inputs = make([]*big.Int, spongeInputs)
  125. inputs[0] = hash
  126. for j := 1; j < spongeInputs; j++ {
  127. inputs[j] = new(big.Int)
  128. }
  129. k = 1
  130. } else {
  131. k++
  132. }
  133. }
  134. if len(msg)%spongeChunkSize != 0 {
  135. // the last chunk of the message is less than 31 bytes
  136. // zero padding it, so that 0xdeadbeaf becomes
  137. // 0xdeadbeaf000000000000000000000000000000000000000000000000000000
  138. var buf [spongeChunkSize]byte
  139. copy(buf[:], msg[(len(msg)/spongeChunkSize)*spongeChunkSize:])
  140. inputs[k] = new(big.Int).SetBytes(buf[:])
  141. dirty = true
  142. }
  143. if dirty {
  144. // we haven't hashed something in the main sponge loop and need to do hash here
  145. hash, err = Hash(inputs)
  146. if err != nil {
  147. return nil, err
  148. }
  149. }
  150. return hash, nil
  151. }