You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

148 lines
3.5 KiB

  1. package ecc
  2. import (
  3. "bytes"
  4. "errors"
  5. "math/big"
  6. )
  7. // EC is the data structure for the elliptic curve parameters
  8. type EC struct {
  9. A *big.Int
  10. B *big.Int
  11. Q *big.Int
  12. }
  13. // NewEC (y^2 = x^3 + ax + b) mod q, where q is a prime number
  14. func NewEC(a, b, q int) (ec EC) {
  15. ec.A = big.NewInt(int64(a))
  16. ec.B = big.NewInt(int64(b))
  17. ec.Q = big.NewInt(int64(q))
  18. return ec
  19. }
  20. // At gets a point x in the curve
  21. func (ec *EC) At(x *big.Int) (Point, Point, error) {
  22. if x.Cmp(ec.Q) > 0 {
  23. return Point{}, Point{}, errors.New("x<ec.Q")
  24. }
  25. // y^2 = (x^3 + ax + b) mod q
  26. // y = sqrt (x^3 + ax + b) mod q
  27. // x^3
  28. x3 := new(big.Int).Exp(x, big.NewInt(int64(3)), nil)
  29. // a^x
  30. aX := new(big.Int).Mul(ec.A, x)
  31. // x^3 + a^x
  32. x3aX := new(big.Int).Add(x3, aX)
  33. // x^3 + a^x + b
  34. x3aXb := new(big.Int).Add(x3aX, ec.B)
  35. // y = sqrt (x^3 + ax + b) mod q
  36. y := new(big.Int).ModSqrt(x3aXb, ec.Q)
  37. return Point{x, y}, Point{x, new(big.Int).Sub(ec.Q, y)}, nil
  38. }
  39. // TODO add valid checker point function Valid()
  40. // Neg returns the inverse of the P point on the elliptic curve
  41. func (ec *EC) Neg(p Point) Point {
  42. // TODO get error when point not found on the ec
  43. return Point{p.X, new(big.Int).Sub(ec.Q, p.Y)}
  44. }
  45. // Order returns smallest n where nG = O (point at zero)
  46. func (ec *EC) Order(g Point) (int, error) {
  47. for i := 1; i < int(ec.Q.Int64())+1; i++ {
  48. mPoint, err := ec.Mul(g, i)
  49. if err != nil {
  50. return i, err
  51. }
  52. if mPoint.Equal(zeroPoint) {
  53. return i, nil
  54. }
  55. }
  56. return -1, errors.New("invalid order")
  57. }
  58. // Add adds two points p1 and p2 and gets q, returns the negate of q
  59. func (ec *EC) Add(p1, p2 Point) (Point, error) {
  60. if p1.Equal(zeroPoint) {
  61. return p2, nil
  62. }
  63. if p2.Equal(zeroPoint) {
  64. return p1, nil
  65. }
  66. var numerator, denominator, sRaw, s *big.Int
  67. if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) && (!bytes.Equal(p1.Y.Bytes(), p2.Y.Bytes()) || bytes.Equal(p1.Y.Bytes(), bigZero.Bytes())) {
  68. return zeroPoint, nil
  69. } else if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) {
  70. // use tangent as slope
  71. // x^2
  72. x2 := new(big.Int).Mul(p1.X, p1.X)
  73. // 3 * x^2
  74. x23 := new(big.Int).Mul(big.NewInt(int64(3)), x2)
  75. // 3 * x^2 + a
  76. numerator = new(big.Int).Add(x23, ec.A)
  77. // 2 * y
  78. denominator = new(big.Int).Mul(big.NewInt(int64(2)), p1.Y)
  79. // s = (3 * x^2 + a) / (2 * y) mod ec.Q
  80. denInv := new(big.Int).ModInverse(denominator, ec.Q)
  81. sRaw = new(big.Int).Mul(numerator, denInv)
  82. s = new(big.Int).Mod(sRaw, ec.Q)
  83. } else {
  84. // slope
  85. // y0-y1
  86. numerator = new(big.Int).Sub(p1.Y, p2.Y)
  87. // x0-x1
  88. denominator = new(big.Int).Sub(p1.X, p2.X)
  89. // s = (y0-y1) / (x0-x1) mod ec.Q
  90. denInv := new(big.Int).ModInverse(denominator, ec.Q)
  91. sRaw = new(big.Int).Mul(numerator, denInv)
  92. s = new(big.Int).Mod(sRaw, ec.Q)
  93. }
  94. // q: new point
  95. var q Point
  96. // s^2
  97. s2 := new(big.Int).Exp(s, big.NewInt(int64(2)), nil)
  98. // s^2 - p1.X
  99. x2Xo := new(big.Int).Sub(s2, p1.X)
  100. // s^2 - p1.X - p2.X
  101. x2XoX2 := new(big.Int).Sub(x2Xo, p2.X)
  102. q.X = new(big.Int).Mod(x2XoX2, ec.Q)
  103. // p1.X - q.X
  104. xoX2 := new(big.Int).Sub(p1.X, q.X)
  105. // s(p1.X - q.X)
  106. sXoX2 := new(big.Int).Mul(s, xoX2)
  107. // s(p1.X - q.X) - p1.Y
  108. sXoX2Y := new(big.Int).Sub(sXoX2, p1.Y)
  109. // q.Y = (s(p1.X - q.X) - p1.Y) mod ec.Q
  110. q.Y = new(big.Int).Mod(sXoX2Y, ec.Q)
  111. // negate q
  112. // q = ec.Neg(q)
  113. return q, nil
  114. }
  115. // Mul multiplies a point n times on the elliptic curve
  116. func (ec *EC) Mul(p Point, n int) (Point, error) {
  117. var err error
  118. p2 := p
  119. r := zeroPoint
  120. for 0 < n {
  121. if n&1 == 1 {
  122. r, err = ec.Add(r, p2)
  123. if err != nil {
  124. return p, err
  125. }
  126. }
  127. n = n >> 1
  128. p2, err = ec.Add(p2, p2)
  129. if err != nil {
  130. return p, err
  131. }
  132. }
  133. return r, nil
  134. }