From 8b985dba31c07a4962c0953d812144f59785bf99 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Sat, 13 Mar 2021 09:09:05 +0100 Subject: [PATCH] Add Point Compression & Decompression methods --- blindsecp256k1.go | 74 ++++++++++++++++++++++++++++++++++++++++-- blindsecp256k1_test.go | 46 ++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/blindsecp256k1.go b/blindsecp256k1.go index 053b3da..05c6342 100644 --- a/blindsecp256k1.go +++ b/blindsecp256k1.go @@ -27,6 +27,18 @@ import ( // ) var ( + zero *big.Int = big.NewInt(0) + + // B (from y^2 = x^3 + B) + B *big.Int = btcec.S256().B + + // P represents the secp256k1 finite field + P *big.Int = btcec.S256().P + + // Q = (P+1)/4 + Q = new(big.Int).Div(new(big.Int).Add(P, + big.NewInt(1)), big.NewInt(4)) // nolint:gomnd + // G represents the base point of secp256k1 G *Point = &Point{ X: btcec.S256().Gx, @@ -35,8 +47,6 @@ var ( // N represents the order of G of secp256k1 N *big.Int = btcec.S256().N - - zero *big.Int = big.NewInt(0) ) // Point represents a point on the secp256k1 curve @@ -76,6 +86,66 @@ func (p *Point) isValid() error { return nil } +// Compress packs a Point to a byte array of 33 bytes +func (p *Point) Compress() [33]byte { + xBytes := p.X.Bytes() + odd := byte(0) + if isOdd(p.Y) { + odd = byte(1) + } + var b [33]byte + copy(b[32-len(xBytes):32], xBytes) + b[32] = odd + return b +} + +func isOdd(b *big.Int) bool { + return b.Bit(0) != 0 +} + +// DecompressPoint unpacks a Point from the given byte array of 33 bytes +// https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294 +func DecompressPoint(b [33]byte) (*Point, error) { + x := new(big.Int).SetBytes(b[:32]) + var odd bool + if b[32] == byte(1) { + odd = true + } + + // secp256k1: y2 = x3+ ax2 + b (where A==0, B==7) + + // compute x^3 + B mod p + x3 := new(big.Int).Mul(x, x) + x3 = new(big.Int).Mul(x3, x) + // x3 := new(big.Int).Exp(x, big.NewInt(3), nil) + x3 = new(big.Int).Add(x3, B) + x3 = new(big.Int).Mod(x3, P) + + // sqrt mod p of x^3 + B + y := new(big.Int).ModSqrt(x3, P) + if y == nil { + return nil, fmt.Errorf("not sqrt mod of x^3") + } + if odd != isOdd(y) { + y = new(big.Int).Sub(P, y) + // TODO if needed Mod + } + + // check that y is a square root of x^3 + B + y2 := new(big.Int).Mul(y, y) + y2 = new(big.Int).Mod(y2, P) + if !bytes.Equal(y2.Bytes(), x3.Bytes()) { + return nil, fmt.Errorf("invalid square root") + } + + if odd != isOdd(y) { + return nil, fmt.Errorf("odd does not match oddness") + } + + p := &Point{X: x, Y: y} + return p, nil +} + // WIP func newRand() *big.Int { var b [32]byte diff --git a/blindsecp256k1_test.go b/blindsecp256k1_test.go index a6ff47f..44cba52 100644 --- a/blindsecp256k1_test.go +++ b/blindsecp256k1_test.go @@ -1,6 +1,7 @@ package blindsecp256k1 import ( + "encoding/hex" "math/big" "testing" @@ -84,3 +85,48 @@ func TestHashMOddBytes(t *testing.T) { // _, err = sk.BlindSign(mBlinded, k) // assert.Equal(t, "mBlinded too small", err.Error()) // } + +func TestPointCompressDecompress(t *testing.T) { + p := G + b := p.Compress() + assert.Equal(t, + "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f8179800", + hex.EncodeToString(b[:])) + p2, err := DecompressPoint(b) + require.Nil(t, err) + assert.Equal(t, p, p2) + + for i := 2; i < 1000; i++ { + p := G.Mul(big.NewInt(int64(i))) + b := p.Compress() + assert.Equal(t, 33, len(b)) + + p2, err := DecompressPoint(b) + require.Nil(t, err) + assert.Equal(t, p, p2) + } +} + +func BenchmarkCompressDecompress(b *testing.B) { + const n = 256 + var points [n]*Point + var compPoints [n][33]byte + + for i := 0; i < n; i++ { + points[i] = G.Mul(big.NewInt(int64(i))) + } + for i := 0; i < n; i++ { + compPoints[i] = points[i].Compress() + } + + b.Run("Compress", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = points[i%n].Compress() + } + }) + b.Run("DecompressPoint", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = DecompressPoint(compPoints[i%n]) + } + }) +}