From af2d0cf6c156f4ed41906449773fe34cb5a199ba Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 1 Feb 2021 19:10:03 +0100 Subject: [PATCH] WIP --- blindsecp256k1.go | 149 +++++++++++++++++++++++++++++++++++++++++ blindsecp256k1_test.go | 47 +++++++++++++ 2 files changed, 196 insertions(+) diff --git a/blindsecp256k1.go b/blindsecp256k1.go index add3d78..5b5b782 100644 --- a/blindsecp256k1.go +++ b/blindsecp256k1.go @@ -12,6 +12,7 @@ package blindsecp256k1 import ( "bytes" "crypto/rand" + "fmt" "math/big" "github.com/btcsuite/btcd/btcec" @@ -27,6 +28,13 @@ var ( // N represents the order of G of secp256k1 N *big.Int = btcec.S256().N + + // B (from y^2 = x^3 + B) + B *big.Int = btcec.S256().B + + // Q = (P+1)/4 + Q = new(big.Int).Div(new(big.Int).Add(btcec.S256().P, + big.NewInt(1)), big.NewInt(4)) ) // Point represents a point on the secp256k1 curve @@ -53,6 +61,147 @@ func (p *Point) Mul(scalar *big.Int) *Point { } } +func (p *Point) Compress() [33]byte { + xBytes := p.X.Bytes() + sign := byte(0) + if isOdd(p.Y) { + sign = byte(1) + } + var b [33]byte + copy(b[32-len(xBytes):32], xBytes) + b[32] = sign + return b +} + +func isOdd(b *big.Int) bool { + return b.Bit(0) != 0 +} + +func squareMul(r, x *big.Int, bit bool) *big.Int { + // r = new(big.Int).Mul(r, r) // r^2 + r = new(big.Int).Exp(r, big.NewInt(2), N) + if bit { + r = new(big.Int).Mul(r, x) + } + return new(big.Int).Mod(r, N) +} + +// https://en.wikipedia.org/wiki/Exponentiation_by_squaring +func sqrtQ(x *big.Int) *big.Int { + // xBytes := x.Bytes() + qBytes := Q.Bytes() + r := big.NewInt(1) + // fmt.Println(hex.EncodeToString(qBytes)) + for _, b := range qBytes { + // fmt.Printf("%d, %x %d\n", i, b, r) + // fmt.Printf("%x %s\n", b, r.String()) + switch b { + // Most common case, where all 8 bits are set. + case 0xff: + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + // fmt.Printf("%x %s\n", b, r.String()) + + // First byte of Q (0x3f), where all but the top two bits are + // set. Note that this case only applies six operations, since + // the highest bit of Q resides in bit six of the first byte. We + // ignore the first two bits, since squaring for these bits will + // result in an invalid result. We forgo squaring f before the + // first multiply, since 1^2 = 1. + case 0x3f: + r = new(big.Int).Mul(r, x) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + + // Byte 28 of Q (0xbf), where only bit 7 is unset. + case 0xbf: + r = squareMul(r, x, true) + r = squareMul(r, x, false) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + + // Byte 31 of Q (0x0c), where only bits 3 and 4 are set. + default: + r = squareMul(r, x, false) + r = squareMul(r, x, false) + r = squareMul(r, x, false) + r = squareMul(r, x, false) + r = squareMul(r, x, true) + r = squareMul(r, x, true) + r = squareMul(r, x, false) + r = squareMul(r, x, false) + } + } + return r +} + +// https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294 +// func (p *Point) Decompress(b [33]byte) error { +func Decompress(b [33]byte) (*Point, error) { + fmt.Println(b) + x := new(big.Int).SetBytes(b[:32]) + fmt.Println(x) + var sign bool + if b[32] == byte(1) { + sign = true + } + + // y2 = x3+ ax2 + b (where A==0, B==7) + + // compute x^3 + B mod p + x3 := new(big.Int).Mul(x, x) + x3 = new(big.Int).Mul(x3, x) + // x3 := new(big.Int).Exp(x, big.NewInt(3), N) + x3 = new(big.Int).Add(x3, B) + x3 = new(big.Int).Mod(x3, N) + + // sqrt mod p of x^3 + B + fmt.Println("x3", x3) + y := new(big.Int).ModSqrt(x3, N) + // y := sqrtQ(x3) + if y == nil { + return nil, fmt.Errorf("not sqrt mod of x^3") + } + fmt.Println("y", y) + fmt.Println("y", new(big.Int).Sub(N, y)) + fmt.Println("y", new(big.Int).Mod(new(big.Int).Neg(y), N)) + if sign != isOdd(y) { + y = new(big.Int).Sub(N, y) + // TODO check if needed Mod + } + + // check that y is a square root of x^3 + B + y2 := new(big.Int).Mul(y, y) + y2 = new(big.Int).Mod(y2, N) + if !bytes.Equal(y2.Bytes(), x3.Bytes()) { + return nil, fmt.Errorf("invalid square root") + } + + if sign != isOdd(y) { + return nil, fmt.Errorf("sign does not match oddness") + } + + p := &Point{X: x, Y: y} + // p = &Point{} + // p.X = x + // p.Y = y + // fmt.Println("I", p.X, p.Y) + return p, nil +} + // WIP func newRand() *big.Int { var b [32]byte diff --git a/blindsecp256k1_test.go b/blindsecp256k1_test.go index 6175d36..d3dcd65 100644 --- a/blindsecp256k1_test.go +++ b/blindsecp256k1_test.go @@ -30,3 +30,50 @@ func TestFlow(t *testing.T) { verified := Verify(msg, sig, signerPubK) assert.True(t, verified) } + +// func TestPointCompressDecompress(t *testing.T) { +// // x := big.NewInt(25) +// // f := big.NewInt(1) +// // fmt.Println("f", f) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f) +// // f = squareMul(f, x, true) +// // require.Equal(t, "21684043449710088680149056017398834228515625", f.String()) +// // fmt.Println("f", f, x) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f, x) +// // require.Equal(t, "72482250313621475425650965409810619910529643899145444686122770647178269858429", f.String()) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f, x) +// // f = squareMul(f, x, true) +// // fmt.Println("f", f, x) +// +// // sqrtQ +// // r := sqrtQ(big.NewInt(25)) +// // assert.Equal(t, "115792089237316195423570985008687907853269984665640564039457584007908834671658", r.String()) +// fmt.Println(N) +// +// // +// // p := G.Mul(big.NewInt(1234)) +// p := G +// // p := &Point{ +// // X: big.NewInt(3), +// // Y: big.NewInt(3), +// // } +// fmt.Println("eX", p.X) +// fmt.Println("eY", p.Y) +// b := p.Compress() +// // fmt.Println("hex", hex.EncodeToString(b[:])) +// +// // var p2 *Point +// // err := p2.Decompress(b) +// p2, err := Decompress(b) +// require.Nil(t, err) +// assert.Equal(t, p, p2) +// }