Browse Source

ECC fixed Add, added Mul

master
arnaucode 5 years ago
parent
commit
a440bab76a
4 changed files with 130 additions and 24 deletions
  1. +7
    -0
      README.md
  2. +6
    -3
      ecc/coord.go
  3. +47
    -10
      ecc/ecc.go
  4. +70
    -11
      ecc/ecc_test.go

+ 7
- 0
README.md

@ -35,3 +35,10 @@ https://en.wikipedia.org/wiki/Elliptic-curve_cryptography
- [x] define elliptic curve - [x] define elliptic curve
- [x] get point at X - [x] get point at X
- [x] Add two points on the elliptic curve - [x] Add two points on the elliptic curve
- [x] Multiply a point n times on the elliptic curve
To run all tests:
```
go test ./... -v
```

+ 6
- 3
ecc/coord.go

@ -1,6 +1,9 @@
package ecc package ecc
import "math/big"
import (
"bytes"
"math/big"
)
var ( var (
bigZero = big.NewInt(int64(0)) bigZero = big.NewInt(int64(0))
@ -13,10 +16,10 @@ type Point struct {
} }
func (c1 *Point) Equal(c2 Point) bool { func (c1 *Point) Equal(c2 Point) bool {
if c1.X.Int64() != c2.X.Int64() {
if !bytes.Equal(c1.X.Bytes(), c2.X.Bytes()) {
return false return false
} }
if c1.Y.Int64() != c2.Y.Int64() {
if !bytes.Equal(c1.Y.Bytes(), c2.Y.Bytes()) {
return false return false
} }
return true return true

+ 47
- 10
ecc/ecc.go

@ -1,6 +1,7 @@
package ecc package ecc
import ( import (
"bytes"
"errors" "errors"
"math/big" "math/big"
) )
@ -11,10 +12,7 @@ type EC struct {
Q *big.Int Q *big.Int
} }
/*
(y^2 = x^3 + Ax + B ) mod Q
Q: prime number
*/
// NewEC (y^2 = x^3 + ax + b) mod q, where q is a prime number
func NewEC(a, b, q int) (ec EC) { func NewEC(a, b, q int) (ec EC) {
ec.A = big.NewInt(int64(a)) ec.A = big.NewInt(int64(a))
ec.B = big.NewInt(int64(b)) ec.B = big.NewInt(int64(b))
@ -52,15 +50,41 @@ func (ec *EC) Neg(p Point) Point {
// Add adds two points p1 and p2 and gets q // Add adds two points p1 and p2 and gets q
func (ec *EC) Add(p1, p2 Point) (Point, error) { func (ec *EC) Add(p1, p2 Point) (Point, error) {
if p1.Equal(zeroPoint) { if p1.Equal(zeroPoint) {
return p2, errors.New("p1==(0, 0)")
return p2, nil
} }
if p2.Equal(zeroPoint) { if p2.Equal(zeroPoint) {
return p1, errors.New("p1==(0, 0)")
return p1, nil
} }
// slope
numerator := new(big.Int).Sub(p1.Y, p2.Y)
denominator := new(big.Int).Sub(p1.X, p2.X)
s := new(big.Int).Div(numerator, denominator)
var numerator, denominator, sRaw, s *big.Int
if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) && (!bytes.Equal(p1.Y.Bytes(), p2.Y.Bytes()) || bytes.Equal(p1.Y.Bytes(), bigZero.Bytes())) {
return zeroPoint, nil
} else if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) {
// use tangent as slope
// x^2
x2 := new(big.Int).Mul(p1.X, p1.X)
// 3 * x^2
x23 := new(big.Int).Mul(big.NewInt(int64(3)), x2)
// 3 * x^2 + a
numerator = new(big.Int).Add(x23, ec.A)
// 2 * y
denominator = new(big.Int).Mul(big.NewInt(int64(2)), p1.Y)
// (3 * x^2 + a) / (2 * y) mod ec.Q
denInv := new(big.Int).ModInverse(denominator, ec.Q)
sRaw = new(big.Int).Mul(numerator, denInv)
s = new(big.Int).Mod(sRaw, ec.Q)
} else {
// slope
// y0-y1
numerator = new(big.Int).Sub(p1.Y, p2.Y)
// x0-x1
denominator = new(big.Int).Sub(p1.X, p2.X)
// (y0-y1) / (x0-x1) mod ec.Q
denInv := new(big.Int).ModInverse(denominator, ec.Q)
sRaw = new(big.Int).Mul(numerator, denInv)
s = new(big.Int).Mod(sRaw, ec.Q)
}
// q: new point // q: new point
var q Point var q Point
// s^2 // s^2
@ -77,7 +101,20 @@ func (ec *EC) Add(p1, p2 Point) (Point, error) {
sXoX2 := new(big.Int).Mul(s, xoX2) sXoX2 := new(big.Int).Mul(s, xoX2)
// s(p1.X - q.X) - p1.Y // s(p1.X - q.X) - p1.Y
sXoX2Y := new(big.Int).Sub(sXoX2, p1.Y) sXoX2Y := new(big.Int).Sub(sXoX2, p1.Y)
// q.Y = (s(p1.X - q.X) - p1.Y) mod ec.Q
q.Y = new(big.Int).Mod(sXoX2Y, ec.Q) q.Y = new(big.Int).Mod(sXoX2Y, ec.Q)
return q, nil return q, nil
} }
// Mul multiplies a point n times on the elliptic curve
func (ec *EC) Mul(p Point, n int) (Point, error) {
var err error
for i := 0; i < n; i++ {
p, err = ec.Add(p, p)
if err != nil {
return zeroPoint, err
}
}
return p, nil
}

+ 70
- 11
ecc/ecc_test.go

@ -1,7 +1,6 @@
package ecc package ecc
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
@ -32,29 +31,19 @@ func TestNeg(t *testing.T) {
} }
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
fmt.Println("y^2 = x^3 + 7")
fmt.Print("ec: ")
ec := NewEC(0, 7, 11) ec := NewEC(0, 7, 11)
fmt.Println(ec)
p1, _, err := ec.At(big.NewInt(int64(7))) p1, _, err := ec.At(big.NewInt(int64(7)))
if err != nil { if err != nil {
t.Errorf(err.Error()) t.Errorf(err.Error())
} }
fmt.Print("p1: ")
fmt.Println(p1)
p2, _, err := ec.At(big.NewInt(int64(6))) p2, _, err := ec.At(big.NewInt(int64(6)))
if err != nil { if err != nil {
t.Errorf(err.Error()) t.Errorf(err.Error())
} }
fmt.Print("p2: ")
fmt.Println(p2)
q, err := ec.Add(p1, p2) q, err := ec.Add(p1, p2)
if err != nil { if err != nil {
t.Errorf(err.Error()) t.Errorf(err.Error())
} }
fmt.Print("q: ")
fmt.Println(q)
if !q.Equal(Point{big.NewInt(int64(2)), big.NewInt(int64(9))}) { if !q.Equal(Point{big.NewInt(int64(2)), big.NewInt(int64(9))}) {
t.Errorf("q!=(2, 9)") t.Errorf("q!=(2, 9)")
} }
@ -69,3 +58,73 @@ func TestAdd(t *testing.T) {
} }
} }
func TestAddSamePoint(t *testing.T) {
ec := NewEC(0, 7, 11)
p1, p1_, err := ec.At(big.NewInt(int64(4)))
if err != nil {
t.Errorf(err.Error())
}
q, err := ec.Add(p1, p1)
if err != nil {
t.Errorf(err.Error())
}
if !q.Equal(Point{big.NewInt(int64(6)), big.NewInt(int64(6))}) {
t.Errorf("q!=(6, 6)")
}
q_, err := ec.Add(p1_, p1_)
if err != nil {
t.Errorf(err.Error())
}
if !q_.Equal(Point{big.NewInt(int64(6)), big.NewInt(int64(5))}) {
t.Errorf("q_!=(6, 5)")
}
}
func TestMulEqualSelfAdd(t *testing.T) {
ec := NewEC(0, 7, 11)
p1, _, err := ec.At(big.NewInt(int64(4)))
if err != nil {
t.Errorf(err.Error())
}
p1p1, err := ec.Add(p1, p1)
if err != nil {
t.Errorf(err.Error())
}
q, err := ec.Mul(p1, 1)
if err != nil {
t.Errorf(err.Error())
}
if !q.Equal(p1p1) {
t.Errorf("q!=p1*p1")
}
}
func TestMul(t *testing.T) {
ec := NewEC(0, 7, 29)
p1 := Point{big.NewInt(int64(4)), big.NewInt(int64(19))}
q3, err := ec.Mul(p1, 3)
if err != nil {
t.Errorf(err.Error())
}
if !q3.Equal(Point{big.NewInt(int64(19)), big.NewInt(int64(15))}) {
t.Errorf("q3!=(19, 15)")
}
q7, err := ec.Mul(p1, 7)
if err != nil {
t.Errorf(err.Error())
}
if !q7.Equal(Point{big.NewInt(int64(19)), big.NewInt(int64(15))}) {
t.Errorf("q7!=(19, 15)")
}
q8, err := ec.Mul(p1, 8)
if err != nil {
t.Errorf(err.Error())
}
if !q8.Equal(Point{big.NewInt(int64(4)), big.NewInt(int64(19))}) {
t.Errorf("q8!=(4, 19)")
}
}

Loading…
Cancel
Save