From 8a260d66d3d8a14f6461fa7855297d5b975cf56f Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 9 Mar 2020 11:51:41 +0100 Subject: [PATCH] Add goff ff.Element to babyjubjub WIP, at this moment still does not bring much optimization --- babyjub/babyjub.go | 131 ++++++++++++++++++++++++---------------- babyjub/babyjub_test.go | 91 +++++++++++++++------------- babyjub/eddsa.go | 12 ++-- babyjub/eddsa_test.go | 4 +- ff/util.go | 8 +++ 5 files changed, 144 insertions(+), 102 deletions(-) diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 26fa214..444a63c 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -5,14 +5,15 @@ import ( "math/big" "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" ) // A is one of the babyjub constants. -var A *big.Int +var A *ff.Element // D is one of the babyjub constants. -var D *big.Int +var D *ff.Element // Order of the babyjub curve. var Order *big.Int @@ -27,29 +28,52 @@ var B8 *Point // init initializes global numbers and the subgroup base. func init() { - A = utils.NewIntFromString("168700") - D = utils.NewIntFromString("168696") + A = ff.NewElement().SetString("168700") + D = ff.NewElement().SetString("168696") Order = utils.NewIntFromString( "21888242871839275222246405745257275088614511777268538073601725287587578984328") SubOrder = new(big.Int).Rsh(Order, 3) B8 = NewPoint() - B8.X = utils.NewIntFromString( + B8.X = ff.NewElement().SetString( "5299619240641551281634865583518297030282874472190772894086521144482721001553") - B8.Y = utils.NewIntFromString( + B8.Y = ff.NewElement().SetString( "16950150798460657717958625567821834550301663161624707787222815936182638968203") } -// Point represents a point of the babyjub curve. -type Point struct { +// PointBI represents a point of the babyjub curve. +type PointBI struct { X *big.Int Y *big.Int } -// NewPoint creates a new Point. +type Point struct { + X *ff.Element + Y *ff.Element +} + +func PointBIToPoint(p *PointBI) *Point { + return &Point{ + X: ff.NewElement().SetBigInt(p.X), + Y: ff.NewElement().SetBigInt(p.Y), + } +} + +func PointToPointBI(p *Point) *PointBI { + return &PointBI{ + X: p.X.BigInt(), + Y: p.Y.BigInt(), + } +} + +// NewPoint creates a new PointBI. +func NewPointBI() *PointBI { + return &PointBI{X: big.NewInt(0), Y: big.NewInt(1)} +} + func NewPoint() *Point { - return &Point{X: big.NewInt(0), Y: big.NewInt(1)} + return &Point{X: ff.NewElement().SetZero(), Y: ff.NewElement().SetOne()} } // Set copies a Point c into the Point p @@ -59,44 +83,45 @@ func (p *Point) Set(c *Point) *Point { return p } +func (p *Point) Equal(q *Point) bool { + // return p.X.Cmp(q.X) == 0 && p.Y.Cmp(q.Y) == 0 + return p.X.Equal(q.X) && p.Y.Equal(q.Y) +} + // Add adds Point a and b into res func (res *Point) Add(a *Point, b *Point) *Point { // x = (a.x * b.y + b.x * a.y) * (1 + D * a.x * b.x * a.y * b.y)^-1 mod q - x1a := new(big.Int).Mul(a.X, b.Y) - x1b := new(big.Int).Mul(b.X, a.Y) + x1a := ff.NewElement().Mul(a.X, b.Y) + x1b := ff.NewElement().Mul(b.X, a.Y) x1a.Add(x1a, x1b) // x1a = a.x * b.y + b.x * a.y - x2 := new(big.Int).Set(D) + x2 := ff.NewElement().Set(D) x2.Mul(x2, a.X) x2.Mul(x2, b.X) x2.Mul(x2, a.Y) x2.Mul(x2, b.Y) - x2.Add(constants.One, x2) - x2.Mod(x2, constants.Q) - x2.ModInverse(x2, constants.Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 + x2.Add(ff.NewElement().SetOne(), x2) + x2.Inverse(x2) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 // y = (a.y * b.y - A * a.x * b.x) * (1 - D * a.x * b.x * a.y * b.y)^-1 mod q - y1a := new(big.Int).Mul(a.Y, b.Y) - y1b := new(big.Int).Set(A) + y1a := ff.NewElement().Mul(a.Y, b.Y) + y1b := ff.NewElement().Set(A) y1b.Mul(y1b, a.X) y1b.Mul(y1b, b.X) y1a.Sub(y1a, y1b) // y1a = a.y * b.y - A * a.x * b.x - y2 := new(big.Int).Set(D) + y2 := ff.NewElement().Set(D) y2.Mul(y2, a.X) y2.Mul(y2, b.X) y2.Mul(y2, a.Y) y2.Mul(y2, b.Y) - y2.Sub(constants.One, y2) - y2.Mod(y2, constants.Q) - y2.ModInverse(y2, constants.Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 + y2.Sub(ff.NewElement().SetOne(), y2) + y2.Inverse(y2) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 res.X = x1a.Mul(x1a, x2) - res.X = res.X.Mod(res.X, constants.Q) res.Y = y1a.Mul(y1a, y2) - res.Y = res.Y.Mod(res.Y, constants.Q) return res } @@ -104,8 +129,8 @@ func (res *Point) Add(a *Point, b *Point) *Point { // Mul multiplies the Point p by the scalar s and stores the result in res, // which is also returned. func (res *Point) Mul(s *big.Int, p *Point) *Point { - res.X = big.NewInt(0) - res.Y = big.NewInt(1) + res.X = ff.NewElement().SetZero() + res.Y = ff.NewElement().SetOne() exp := NewPoint().Set(p) for i := 0; i < s.BitLen(); i++ { @@ -120,25 +145,21 @@ func (res *Point) Mul(s *big.Int, p *Point) *Point { // InCurve returns true when the Point p is in the babyjub curve. func (p *Point) InCurve() bool { - x2 := new(big.Int).Set(p.X) + x2 := ff.NewElement().Set(p.X) x2.Mul(x2, x2) - x2.Mod(x2, constants.Q) - y2 := new(big.Int).Set(p.Y) + y2 := ff.NewElement().Set(p.Y) y2.Mul(y2, y2) - y2.Mod(y2, constants.Q) - a := new(big.Int).Mul(A, x2) + a := ff.NewElement().Mul(A, x2) a.Add(a, y2) - a.Mod(a, constants.Q) - b := new(big.Int).Set(D) + b := ff.NewElement().Set(D) b.Mul(b, x2) b.Mul(b, y2) - b.Add(constants.One, b) - b.Mod(b, constants.Q) + b.Add(ff.NewElement().SetOne(), b) - return a.Cmp(b) == 0 + return a.Equal(b) } // InSubGroup returns true when the Point p is in the subgroup of the babyjub @@ -148,7 +169,7 @@ func (p *Point) InSubGroup() bool { return false } res := NewPoint().Mul(SubOrder, p) - return (res.X.Cmp(constants.Zero) == 0) && (res.Y.Cmp(constants.One) == 0) + return res.X.Equal(ff.NewElement().SetZero()) && res.Y.Equal(ff.NewElement().SetOne()) } // PointCoordSign returns the sign of the curve point coordinate. It returns @@ -171,8 +192,9 @@ func PackPoint(ay *big.Int, sign bool) [32]byte { // Compress the point into a 32 byte array that contains the y coordinate in // little endian and the sign of the x coordinate. func (p *Point) Compress() [32]byte { - sign := PointCoordSign(p.X) - return PackPoint(p.Y, sign) + pBI := PointToPointBI(p) + sign := PointCoordSign(pBI.X) + return PackPoint(pBI.Y, sign) } // Decompress a compressed Point into p, and also returns the decompressed @@ -183,34 +205,37 @@ func (p *Point) Decompress(leBuf [32]byte) (*Point, error) { sign = true leBuf[31] = leBuf[31] & 0x7F } - utils.SetBigIntFromLEBytes(p.Y, leBuf[:]) - if p.Y.Cmp(constants.Q) >= 0 { + y := big.NewInt(0) + utils.SetBigIntFromLEBytes(y, leBuf[:]) + if y.Cmp(constants.Q) >= 0 { return nil, fmt.Errorf("p.y >= Q") } + p.Y = ff.NewElement().SetBigInt(y) - y2 := new(big.Int).Mul(p.Y, p.Y) - y2.Mod(y2, constants.Q) - xa := big.NewInt(1) + y2 := ff.NewElement().Mul(p.Y, p.Y) + xa := ff.NewElement().SetOne() xa.Sub(xa, y2) // xa == 1 - y^2 - xb := new(big.Int).Mul(D, y2) - xb.Mod(xb, constants.Q) + xb := ff.NewElement().Mul(D, y2) xb.Sub(A, xb) // xb = A - d * y^2 - if xb.Cmp(big.NewInt(0)) == 0 { + if xb.Equal(ff.NewElement().SetZero()) { return nil, fmt.Errorf("division by 0") } - xb.ModInverse(xb, constants.Q) + xb.Inverse(xb) p.X.Mul(xa, xb) // xa / xb - p.X.Mod(p.X, constants.Q) - noSqrt := p.X.ModSqrt(p.X, constants.Q) + + q := PointToPointBI(p) + noSqrt := q.X.ModSqrt(q.X, constants.Q) if noSqrt == nil { return nil, fmt.Errorf("x is not a square mod q") } - if (sign && !PointCoordSign(p.X)) || (!sign && PointCoordSign(p.X)) { - p.X.Mul(p.X, constants.MinusOne) + if (sign && !PointCoordSign(q.X)) || (!sign && PointCoordSign(q.X)) { + q.X.Mul(q.X, constants.MinusOne) } - p.X.Mod(p.X, constants.Q) + q.X.Mod(q.X, constants.Q) + + p = PointBIToPoint(q) return p, nil } diff --git a/babyjub/babyjub_test.go b/babyjub/babyjub_test.go index 01f8589..a667b45 100644 --- a/babyjub/babyjub_test.go +++ b/babyjub/babyjub_test.go @@ -7,13 +7,21 @@ import ( "testing" "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" ) +func zero() *ff.Element { + return ff.NewElement().SetZero() +} +func one() *ff.Element { + return ff.NewElement().SetOne() +} + func TestAdd1(t *testing.T) { - a := &Point{X: big.NewInt(0), Y: big.NewInt(1)} - b := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + a := &Point{X: zero(), Y: one()} + b := &Point{X: zero(), Y: one()} c := NewPoint().Add(a, b) // fmt.Printf("%v = 2 * %v", *c, *a) @@ -22,15 +30,15 @@ func TestAdd1(t *testing.T) { } func TestAdd2(t *testing.T) { - aX := utils.NewIntFromString( + aX := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := utils.NewIntFromString( + aY := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - bY := utils.NewIntFromString( + bY := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") b := &Point{X: bX, Y: bY} @@ -45,15 +53,15 @@ func TestAdd2(t *testing.T) { } func TestAdd3(t *testing.T) { - aX := utils.NewIntFromString( + aX := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := utils.NewIntFromString( + aY := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX := ff.NewElement().SetString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := utils.NewIntFromString( + bY := ff.NewElement().SetString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -68,15 +76,15 @@ func TestAdd3(t *testing.T) { } func TestAdd4(t *testing.T) { - aX := utils.NewIntFromString( + aX := ff.NewElement().SetString( "0") - aY := utils.NewIntFromString( + aY := ff.NewElement().SetString( "1") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX := ff.NewElement().SetString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := utils.NewIntFromString( + bY := ff.NewElement().SetString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -91,19 +99,19 @@ func TestAdd4(t *testing.T) { } func TestInCurve1(t *testing.T) { - p := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + p := &Point{X: zero(), Y: one()} assert.Equal(t, true, p.InCurve()) } func TestInCurve2(t *testing.T) { - p := &Point{X: big.NewInt(1), Y: big.NewInt(0)} + p := &Point{X: one(), Y: zero()} assert.Equal(t, false, p.InCurve()) } func TestMul0(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} s := utils.NewIntFromString("3") @@ -123,9 +131,9 @@ func TestMul0(t *testing.T) { } func TestMul1(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} s := utils.NewIntFromString( @@ -140,9 +148,9 @@ func TestMul1(t *testing.T) { } func TestMul2(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} s := utils.NewIntFromString( @@ -157,45 +165,45 @@ func TestMul2(t *testing.T) { } func TestInCurve3(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInCurve4(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInSubGroup1(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestInSubGroup2(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestCompressDecompress1(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} @@ -209,9 +217,9 @@ func TestCompressDecompress1(t *testing.T) { } func TestCompressDecompress2(t *testing.T) { - x := utils.NewIntFromString( + x := ff.NewElement().SetString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y := ff.NewElement().SetString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} @@ -230,7 +238,8 @@ func TestCompressDecompressRnd(t *testing.T) { buf := p1.Compress() p2, err := NewPoint().Decompress(buf) assert.Equal(t, nil, err) - assert.Equal(t, p1, p2) + // assert.Equal(t, p1, p2) + assert.True(t, p1.Equal(p2)) } } @@ -241,15 +250,15 @@ func BenchmarkBabyjub(b *testing.B) { var badpoints [n]*Point for i := 0; i < n; i++ { - x := new(big.Int).Rand(rnd, constants.Q) - y := new(big.Int).Rand(rnd, constants.Q) + x := ff.NewElement().SetRandom() + y := ff.NewElement().SetRandom() badpoints[i] = &Point{X: x, Y: y} } var points [n]*Point - baseX := utils.NewIntFromString( + baseX := ff.NewElement().SetString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - baseY := utils.NewIntFromString( + baseY := ff.NewElement().SetString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") base := &Point{X: baseX, Y: baseY} for i := 0; i < n; i++ { @@ -263,8 +272,8 @@ func BenchmarkBabyjub(b *testing.B) { } b.Run("AddConst", func(b *testing.B) { - p0 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} - p1 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + p0 := &Point{X: zero(), Y: one()} + p1 := &Point{X: zero(), Y: one()} p2 := NewPoint() for i := 0; i < b.N; i++ { diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index 3093a4a..fef2922 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -180,7 +180,7 @@ func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { r.Mod(r, SubOrder) R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B A := k.Public().Point() - hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg} + hmInput := []*big.Int{R8.X.BigInt(), R8.Y.BigInt(), A.X.BigInt(), A.Y.BigInt(), msg} hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { panic(err) @@ -196,7 +196,7 @@ func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { // VerifyMimc7 verifies the signature of a message encoded as a big.Int in Zq // using blake-512 hash for buffer hashing and mimc7 for big.Int hashing. func (p *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { - hmInput := []*big.Int{sig.R8.X, sig.R8.Y, p.X, p.Y, msg} + hmInput := []*big.Int{sig.R8.X.BigInt(), sig.R8.Y.BigInt(), p.X.BigInt(), p.Y.BigInt(), msg} hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { panic(err) @@ -207,7 +207,7 @@ func (p *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { r1.Mul(r1, hm) right := NewPoint().Mul(r1, p.Point()) right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A - return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) + return left.X.Equal(right.X) && left.Y.Equal(right.Y) } // SignPoseidon signs a message encoded as a big.Int in Zq using blake-512 hash @@ -223,7 +223,7 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B A := k.Public().Point() - hmInput := [poseidon.T]*big.Int{R8.X, R8.Y, A.X, A.Y, msg, big.NewInt(int64(0))} + hmInput := [poseidon.T]*big.Int{R8.X.BigInt(), R8.Y.BigInt(), A.X.BigInt(), A.Y.BigInt(), msg, big.NewInt(int64(0))} hm, err := poseidon.PoseidonHash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { panic(err) @@ -240,7 +240,7 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { // VerifyPoseidon verifies the signature of a message encoded as a big.Int in Zq // using blake-512 hash for buffer hashing and Poseidon for big.Int hashing. func (p *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { - hmInput := [poseidon.T]*big.Int{sig.R8.X, sig.R8.Y, p.X, p.Y, msg, big.NewInt(int64(0))} + hmInput := [poseidon.T]*big.Int{sig.R8.X.BigInt(), sig.R8.Y.BigInt(), p.X.BigInt(), p.Y.BigInt(), msg, big.NewInt(int64(0))} hm, err := poseidon.PoseidonHash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { panic(err) @@ -251,5 +251,5 @@ func (p *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { r1.Mul(r1, hm) right := NewPoint().Mul(r1, p.Point()) right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A - return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) + return left.X.Equal(right.X) && left.Y.Equal(right.Y) } diff --git a/babyjub/eddsa_test.go b/babyjub/eddsa_test.go index b7f68bc..8065a61 100644 --- a/babyjub/eddsa_test.go +++ b/babyjub/eddsa_test.go @@ -31,8 +31,8 @@ func TestPublicKey(t *testing.T) { hex.Decode(k[:], []byte{byte(i)}) } pk := k.Public() - assert.True(t, pk.X.Cmp(constants.Q) == -1) - assert.True(t, pk.Y.Cmp(constants.Q) == -1) + assert.True(t, pk.X.BigInt().Cmp(constants.Q) == -1) + assert.True(t, pk.Y.BigInt().Cmp(constants.Q) == -1) } func TestSignVerifyMimc7(t *testing.T) { diff --git a/ff/util.go b/ff/util.go index 501b6ae..0b56871 100644 --- a/ff/util.go +++ b/ff/util.go @@ -1,5 +1,13 @@ package ff +import "math/big" + func NewElement() *Element { return &Element{} } + +func (e *Element) BigInt() *big.Int { + b := big.NewInt(0) + e.ToBigIntRegular(b) + return b +}