You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

306 lines
6.8 KiB

  1. package kzg
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "fmt"
  6. "math/big"
  7. "strconv"
  8. bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare"
  9. )
  // Q and R are the two prime moduli of the BN254 (alt_bn128) pairing curve
  // implemented by the bn256 package. SetString cannot fail on these
  // well-formed decimal literals, so its ok result is deliberately discarded.
  var (
  	// Q is the base-field modulus: the field over which curve point
  	// coordinates are defined.
  	Q, _ = new(big.Int).SetString(
  		"21888242871839275222246405745257275088696311157297823662689037894645226208583", 10)

  	// R is the scalar-field modulus. The f* field helpers and the polynomial
  	// routines in this file reduce their results modulo R.
  	R, _ = new(big.Int).SetString(
  		"21888242871839275222246405745257275088548364400416034343698204186575808495617", 10)
  )
  16. func randBigInt() (*big.Int, error) {
  17. maxbits := R.BitLen()
  18. b := make([]byte, (maxbits/8)-1)
  19. _, err := rand.Read(b)
  20. if err != nil {
  21. return nil, err
  22. }
  23. r := new(big.Int).SetBytes(b)
  24. rq := new(big.Int).Mod(r, R)
  25. return rq, nil
  26. }
  // arrayOfZeroes returns a slice of n zero-valued *big.Int elements. Each
  // element is a distinct allocation, so callers can mutate one without
  // affecting the others. It panics if n is negative.
  func arrayOfZeroes(n int) []*big.Int {
  	zeros := make([]*big.Int, n)
  	for idx := range zeros {
  		zeros[idx] = big.NewInt(0)
  	}
  	return zeros
  }
  34. //nolint:deadcode,unused
  35. func arrayOfZeroesG1(n int) []*bn256.G1 {
  36. r := make([]*bn256.G1, n)
  37. for i := 0; i < n; i++ {
  38. r[i] = new(bn256.G1).ScalarBaseMult(big.NewInt(0))
  39. }
  40. return r[:]
  41. }
  42. //nolint:deadcode,unused
  43. func arrayOfZeroesG2(n int) []*bn256.G2 {
  44. r := make([]*bn256.G2, n)
  45. for i := 0; i < n; i++ {
  46. r[i] = new(bn256.G2).ScalarBaseMult(big.NewInt(0))
  47. }
  48. return r[:]
  49. }
  // compareBigIntArray reports whether a and b have the same length and hold
  // numerically equal values at every index.
  //
  // Elements are compared with Cmp rather than ==: two distinct *big.Int
  // pointers holding the same number must count as equal. A nil element only
  // matches another nil element.
  func compareBigIntArray(a, b []*big.Int) bool {
  	if len(a) != len(b) {
  		return false
  	}
  	for i := range a {
  		x, y := a[i], b[i]
  		if x == nil || y == nil {
  			if x != y {
  				return false
  			}
  			continue
  		}
  		if x.Cmp(y) != 0 {
  			return false
  		}
  	}
  	return true
  }
  // checkArrayOfZeroes reports whether every element of a is numerically
  // zero. An empty slice counts as all zeroes; a nil element does not.
  //
  // The check is by value (Sign), not by pointer identity, so zeroes held in
  // any allocation are recognised.
  //
  //nolint:deadcode,unused
  func checkArrayOfZeroes(a []*big.Int) bool {
  	for _, v := range a {
  		if v == nil || v.Sign() != 0 {
  			return false
  		}
  	}
  	return true
  }
  66. func fAdd(a, b *big.Int) *big.Int {
  67. ab := new(big.Int).Add(a, b)
  68. return ab.Mod(ab, R)
  69. }
  70. func fSub(a, b *big.Int) *big.Int {
  71. ab := new(big.Int).Sub(a, b)
  72. return new(big.Int).Mod(ab, R)
  73. }
  74. func fMul(a, b *big.Int) *big.Int {
  75. ab := new(big.Int).Mul(a, b)
  76. return ab.Mod(ab, R)
  77. }
  78. func fDiv(a, b *big.Int) *big.Int {
  79. ab := new(big.Int).Mul(a, new(big.Int).ModInverse(b, R))
  80. return new(big.Int).Mod(ab, R)
  81. }
  82. func fNeg(a *big.Int) *big.Int {
  83. return new(big.Int).Mod(new(big.Int).Neg(a), R)
  84. }
  85. //nolint:deadcode,unused
  86. func fInv(a *big.Int) *big.Int {
  87. return new(big.Int).ModInverse(a, R)
  88. }
  89. func fExp(base *big.Int, e *big.Int) *big.Int {
  90. res := big.NewInt(1)
  91. rem := new(big.Int).Set(e)
  92. exp := base
  93. for !bytes.Equal(rem.Bytes(), big.NewInt(int64(0)).Bytes()) {
  94. // if BigIsOdd(rem) {
  95. if rem.Bit(0) == 1 { // .Bit(0) returns 1 when is odd
  96. res = fMul(res, exp)
  97. }
  98. exp = fMul(exp, exp)
  99. rem.Rsh(rem, 1)
  100. }
  101. return res
  102. }
  // max returns whichever of a and b is larger. On Go 1.21+ this
  // package-level helper shadows the built-in max inside this package.
  func max(a, b int) int {
  	if b > a {
  		return b
  	}
  	return a
  }
  109. func polynomialAdd(a, b []*big.Int) []*big.Int {
  110. r := arrayOfZeroes(max(len(a), len(b)))
  111. for i := 0; i < len(a); i++ {
  112. r[i] = fAdd(r[i], a[i])
  113. }
  114. for i := 0; i < len(b); i++ {
  115. r[i] = fAdd(r[i], b[i])
  116. }
  117. return r
  118. }
  119. func polynomialSub(a, b []*big.Int) []*big.Int {
  120. r := arrayOfZeroes(max(len(a), len(b)))
  121. for i := 0; i < len(a); i++ {
  122. r[i] = fAdd(r[i], a[i])
  123. }
  124. for i := 0; i < len(b); i++ {
  125. r[i] = fSub(r[i], b[i])
  126. }
  127. return r
  128. }
  129. func polynomialMul(a, b []*big.Int) []*big.Int {
  130. r := arrayOfZeroes(len(a) + len(b) - 1)
  131. for i := 0; i < len(a); i++ {
  132. for j := 0; j < len(b); j++ {
  133. r[i+j] = fAdd(r[i+j], fMul(a[i], b[j]))
  134. }
  135. }
  136. return r
  137. }
  138. func polynomialDiv(a, b []*big.Int) ([]*big.Int, []*big.Int) {
  139. // https://en.wikipedia.org/wiki/Division_algorithm
  140. r := arrayOfZeroes(len(a) - len(b) + 1)
  141. rem := a
  142. for len(rem) >= len(b) {
  143. l := fDiv(rem[len(rem)-1], b[len(b)-1])
  144. pos := len(rem) - len(b)
  145. r[pos] = l
  146. aux := arrayOfZeroes(pos)
  147. aux1 := append(aux, l)
  148. aux2 := polynomialSub(rem, polynomialMul(b, aux1))
  149. rem = aux2[:len(aux2)-1]
  150. }
  151. return r, rem
  152. }
  153. func polynomialMulByConstant(a []*big.Int, c *big.Int) []*big.Int {
  154. for i := 0; i < len(a); i++ {
  155. a[i] = fMul(a[i], c)
  156. }
  157. return a
  158. }
  159. func polynomialDivByConstant(a []*big.Int, c *big.Int) []*big.Int {
  160. for i := 0; i < len(a); i++ {
  161. a[i] = fDiv(a[i], c)
  162. }
  163. return a
  164. }
  165. // polynomialEval evaluates the polinomial over the Finite Field at the given value x
  166. func polynomialEval(p []*big.Int, x *big.Int) *big.Int {
  167. r := big.NewInt(int64(0))
  168. for i := 0; i < len(p); i++ {
  169. xi := fExp(x, big.NewInt(int64(i)))
  170. elem := fMul(p[i], xi)
  171. r = fAdd(r, elem)
  172. }
  173. return r
  174. }
  175. // newPolZeroAt generates a new polynomial that has value zero at the given value
  176. func newPolZeroAt(pointPos, totalPoints int, height *big.Int) []*big.Int {
  177. fac := 1
  178. for i := 1; i < totalPoints+1; i++ {
  179. if i != pointPos {
  180. fac = fac * (pointPos - i)
  181. }
  182. }
  183. facBig := big.NewInt(int64(fac))
  184. hf := fDiv(height, facBig)
  185. r := []*big.Int{hf}
  186. for i := 1; i < totalPoints+1; i++ {
  187. if i != pointPos {
  188. ineg := big.NewInt(int64(-i))
  189. b1 := big.NewInt(int64(1))
  190. r = polynomialMul(r, []*big.Int{ineg, b1})
  191. }
  192. }
  193. return r
  194. }
  195. // zeroPolynomial returns the zero polynomial:
  196. // z(x) = (x - z_0) (x - z_1) ... (x - z_{k-1})
  197. func zeroPolynomial(zs []*big.Int) []*big.Int {
  198. z := []*big.Int{fNeg(zs[0]), big.NewInt(1)} // (x - z0)
  199. for i := 1; i < len(zs); i++ {
  200. z = polynomialMul(z, []*big.Int{fNeg(zs[i]), big.NewInt(1)}) // (x - zi)
  201. }
  202. return z
  203. }
  // sNums maps each ASCII decimal digit to its Unicode superscript form, used
  // to render exponents such as x².
  var sNums = map[string]string{
  	"0": "⁰", "1": "¹", "2": "²", "3": "³", "4": "⁴",
  	"5": "⁵", "6": "⁶", "7": "⁷", "8": "⁸", "9": "⁹",
  }

  // intToSNum renders n in decimal using superscript digits (12 -> "¹²").
  // Characters without an entry in sNums, such as the minus sign of a
  // negative n, are dropped from the output.
  func intToSNum(n int) string {
  	out := ""
  	for _, ch := range strconv.Itoa(n) {
  		out += sNums[string(ch)]
  	}
  	return out
  }
  224. // PolynomialToString converts a polynomial represented by a *big.Int array,
  225. // into its string human readable representation
  226. func PolynomialToString(p []*big.Int) string {
  227. s := ""
  228. for i := len(p) - 1; i >= 1; i-- {
  229. if bytes.Equal(p[i].Bytes(), big.NewInt(1).Bytes()) {
  230. s += fmt.Sprintf("x%s + ", intToSNum(i))
  231. } else if !bytes.Equal(p[i].Bytes(), big.NewInt(0).Bytes()) {
  232. s += fmt.Sprintf("%sx%s + ", p[i], intToSNum(i))
  233. }
  234. }
  235. s += p[0].String()
  236. return s
  237. }
  238. // LagrangeInterpolation implements the Lagrange interpolation:
  239. // https://en.wikipedia.org/wiki/Lagrange_polynomial
  240. func LagrangeInterpolation(x, y []*big.Int) ([]*big.Int, error) {
  241. // p(x) will be the interpoled polynomial
  242. // var p []*big.Int
  243. if len(x) != len(y) {
  244. return nil, fmt.Errorf("len(x)!=len(y): %d, %d", len(x), len(y))
  245. }
  246. p := arrayOfZeroes(len(x))
  247. k := len(x)
  248. for j := 0; j < k; j++ {
  249. // jPol is the Lagrange basis polynomial for each point
  250. var jPol []*big.Int
  251. for m := 0; m < k; m++ {
  252. // if x[m] == x[j] {
  253. if m == j {
  254. continue
  255. }
  256. // numerator & denominator of the current iteration
  257. num := []*big.Int{fNeg(x[m]), big.NewInt(1)} // (x^1 - x_m)
  258. den := fSub(x[j], x[m]) // x_j-x_m
  259. mPol := polynomialDivByConstant(num, den)
  260. if len(jPol) == 0 {
  261. // first j iteration
  262. jPol = mPol
  263. continue
  264. }
  265. jPol = polynomialMul(jPol, mPol)
  266. }
  267. p = polynomialAdd(p, polynomialMulByConstant(jPol, y[j]))
  268. }
  269. return p, nil
  270. }
  // TODO: add a helper that trims trailing zero coefficients (the highest-degree terms) from a polynomial.