diff --git a/README.md b/README.md index fdbb26c..69923ee 100644 --- a/README.md +++ b/README.md @@ -35,3 +35,10 @@ https://en.wikipedia.org/wiki/Elliptic-curve_cryptography - [x] define elliptic curve - [x] get point at X - [x] Add two points on the elliptic curve +- [x] Multiply a point n times on the elliptic curve + + +To run all tests: +``` +go test ./... -v +``` diff --git a/ecc/coord.go b/ecc/coord.go index 4ac06a0..c963171 100644 --- a/ecc/coord.go +++ b/ecc/coord.go @@ -1,6 +1,9 @@ package ecc -import "math/big" +import ( + "bytes" + "math/big" +) var ( bigZero = big.NewInt(int64(0)) @@ -13,10 +16,10 @@ type Point struct { } func (c1 *Point) Equal(c2 Point) bool { - if c1.X.Int64() != c2.X.Int64() { + if !bytes.Equal(c1.X.Bytes(), c2.X.Bytes()) { return false } - if c1.Y.Int64() != c2.Y.Int64() { + if !bytes.Equal(c1.Y.Bytes(), c2.Y.Bytes()) { return false } return true diff --git a/ecc/ecc.go b/ecc/ecc.go index e219b67..4e62e75 100644 --- a/ecc/ecc.go +++ b/ecc/ecc.go @@ -1,6 +1,7 @@ package ecc import ( + "bytes" "errors" "math/big" ) @@ -11,10 +12,7 @@ type EC struct { Q *big.Int } -/* - (y^2 = x^3 + Ax + B ) mod Q - Q: prime number -*/ +// NewEC (y^2 = x^3 + ax + b) mod q, where q is a prime number func NewEC(a, b, q int) (ec EC) { ec.A = big.NewInt(int64(a)) ec.B = big.NewInt(int64(b)) @@ -52,15 +50,41 @@ func (ec *EC) Neg(p Point) Point { // Add adds two points p1 and p2 and gets q func (ec *EC) Add(p1, p2 Point) (Point, error) { if p1.Equal(zeroPoint) { - return p2, errors.New("p1==(0, 0)") + return p2, nil } if p2.Equal(zeroPoint) { - return p1, errors.New("p1==(0, 0)") + return p1, nil } - // slope - numerator := new(big.Int).Sub(p1.Y, p2.Y) - denominator := new(big.Int).Sub(p1.X, p2.X) - s := new(big.Int).Div(numerator, denominator) + + var numerator, denominator, sRaw, s *big.Int + if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) && (!bytes.Equal(p1.Y.Bytes(), p2.Y.Bytes()) || bytes.Equal(p1.Y.Bytes(), bigZero.Bytes())) { + return zeroPoint, nil + } else if bytes.Equal(p1.X.Bytes(), p2.X.Bytes()) { + // use tangent as slope + // x^2 + x2 := new(big.Int).Mul(p1.X, p1.X) + // 3 * x^2 + x23 := new(big.Int).Mul(big.NewInt(int64(3)), x2) + // 3 * x^2 + a + numerator = new(big.Int).Add(x23, ec.A) + // 2 * y + denominator = new(big.Int).Mul(big.NewInt(int64(2)), p1.Y) + // (3 * x^2 + a) / (2 * y) mod ec.Q + denInv := new(big.Int).ModInverse(denominator, ec.Q) + sRaw = new(big.Int).Mul(numerator, denInv) + s = new(big.Int).Mod(sRaw, ec.Q) + } else { + // slope + // y0-y1 + numerator = new(big.Int).Sub(p1.Y, p2.Y) + // x0-x1 + denominator = new(big.Int).Sub(p1.X, p2.X) + // (y0-y1) / (x0-x1) mod ec.Q + denInv := new(big.Int).ModInverse(denominator, ec.Q) + sRaw = new(big.Int).Mul(numerator, denInv) + s = new(big.Int).Mod(sRaw, ec.Q) + } + // q: new point var q Point // s^2 @@ -77,7 +101,20 @@ func (ec *EC) Add(p1, p2 Point) (Point, error) { sXoX2 := new(big.Int).Mul(s, xoX2) // s(p1.X - q.X) - p1.Y sXoX2Y := new(big.Int).Sub(sXoX2, p1.Y) + // q.Y = (s(p1.X - q.X) - p1.Y) mod ec.Q q.Y = new(big.Int).Mod(sXoX2Y, ec.Q) return q, nil } + +// Mul multiplies a point n times on the elliptic curve +func (ec *EC) Mul(p Point, n int) (Point, error) { + var err error + for i := 0; i < n; i++ { + p, err = ec.Add(p, p) + if err != nil { + return zeroPoint, err + } + } + return p, nil +} diff --git a/ecc/ecc_test.go b/ecc/ecc_test.go index 1234cb8..ced18cf 100644 --- a/ecc/ecc_test.go +++ b/ecc/ecc_test.go @@ -1,7 +1,6 @@ package ecc import ( - "fmt" "math/big" "testing" ) @@ -32,29 +31,19 @@ func TestNeg(t *testing.T) { } func TestAdd(t *testing.T) { - fmt.Println("y^2 = x^3 + 7") - fmt.Print("ec: ") ec := NewEC(0, 7, 11) - fmt.Println(ec) p1, _, err := ec.At(big.NewInt(int64(7))) if err != nil { t.Errorf(err.Error()) } - fmt.Print("p1: ") - fmt.Println(p1) p2, _, err := ec.At(big.NewInt(int64(6))) if err != nil { t.Errorf(err.Error()) } - fmt.Print("p2: ") - fmt.Println(p2) - q, err := ec.Add(p1, p2) if err != nil { t.Errorf(err.Error()) } - fmt.Print("q: ") - fmt.Println(q) if !q.Equal(Point{big.NewInt(int64(2)), big.NewInt(int64(9))}) { t.Errorf("q!=(2, 9)") } @@ -69,3 +58,73 @@ func TestAdd(t *testing.T) { } } + +func TestAddSamePoint(t *testing.T) { + ec := NewEC(0, 7, 11) + p1, p1_, err := ec.At(big.NewInt(int64(4))) + if err != nil { + t.Errorf(err.Error()) + } + + q, err := ec.Add(p1, p1) + if err != nil { + t.Errorf(err.Error()) + } + if !q.Equal(Point{big.NewInt(int64(6)), big.NewInt(int64(6))}) { + t.Errorf("q!=(6, 6)") + } + + q_, err := ec.Add(p1_, p1_) + if err != nil { + t.Errorf(err.Error()) + } + if !q_.Equal(Point{big.NewInt(int64(6)), big.NewInt(int64(5))}) { + t.Errorf("q_!=(6, 5)") + } + +} + +func TestMulEqualSelfAdd(t *testing.T) { + ec := NewEC(0, 7, 11) + p1, _, err := ec.At(big.NewInt(int64(4))) + if err != nil { + t.Errorf(err.Error()) + } + p1p1, err := ec.Add(p1, p1) + if err != nil { + t.Errorf(err.Error()) + } + q, err := ec.Mul(p1, 1) + if err != nil { + t.Errorf(err.Error()) + } + if !q.Equal(p1p1) { + t.Errorf("q!=p1*p1") + } +} +func TestMul(t *testing.T) { + ec := NewEC(0, 7, 29) + p1 := Point{big.NewInt(int64(4)), big.NewInt(int64(19))} + q3, err := ec.Mul(p1, 3) + if err != nil { + t.Errorf(err.Error()) + } + if !q3.Equal(Point{big.NewInt(int64(19)), big.NewInt(int64(15))}) { + t.Errorf("q3!=(19, 15)") + } + q7, err := ec.Mul(p1, 7) + if err != nil { + t.Errorf(err.Error()) + } + if !q7.Equal(Point{big.NewInt(int64(19)), big.NewInt(int64(15))}) { + t.Errorf("q7!=(19, 15)") + } + + q8, err := ec.Mul(p1, 8) + if err != nil { + t.Errorf(err.Error()) + } + if !q8.Equal(Point{big.NewInt(int64(4)), big.NewInt(int64(19))}) { + t.Errorf("q8!=(4, 19)") + } +}