diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 9a03332..ca913be 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -5,15 +5,22 @@ import ( "math/big" "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/ff" "github.com/iden3/go-iden3-crypto/utils" ) // A is one of the babyjub constants. var A *big.Int +// Aff is A value in *ff.Element representation +var Aff *ff.Element + // D is one of the babyjub constants. var D *big.Int +// Dff is D value in *ff.Element representation +var Dff *ff.Element + // Order of the babyjub curve. var Order *big.Int @@ -29,6 +36,8 @@ var B8 *Point func init() { A = utils.NewIntFromString("168700") D = utils.NewIntFromString("168696") + Aff = ff.NewElement().SetBigInt(A) + Dff = ff.NewElement().SetBigInt(D) Order = utils.NewIntFromString( "21888242871839275222246405745257275088614511777268538073601725287587578984328") @@ -41,6 +50,70 @@ func init() { "16950150798460657717958625567821834550301663161624707787222815936182638968203") } +// PointProjective is the Point representation in projective coordinates +type PointProjective struct { + X *ff.Element + Y *ff.Element + Z *ff.Element +} + +// NewPointProjective creates a new Point in projective coordinates. +func NewPointProjective() *PointProjective { + return &PointProjective{X: ff.NewElement().SetZero(), Y: ff.NewElement().SetOne(), Z: ff.NewElement().SetOne()} +} + +// Affine returns the Point from the projective representation +func (p *PointProjective) Affine() *Point { + if p.Z.Equal(ff.NewElement().SetZero()) { + return &Point{ + X: big.NewInt(0), + Y: big.NewInt(0), + } + } + zinv := ff.NewElement().Inverse(p.Z) + x := ff.NewElement().Mul(p.X, zinv) + + y := ff.NewElement().Mul(p.Y, zinv) + xBig := big.NewInt(0) + x.ToBigIntRegular(xBig) + yBig := big.NewInt(0) + y.ToBigIntRegular(yBig) + return &Point{ + X: xBig, + Y: yBig, + } +} + +// Add computes the addition of two points in projective coordinates representation +func (res *PointProjective) Add(p *PointProjective, q *PointProjective) *PointProjective { + // add-2008-bbjlp https://hyperelliptic.org/EFD/g1p/auto-twisted-projective.html#doubling-dbl-2008-bbjlp + a := ff.NewElement().Mul(p.Z, q.Z) + b := ff.NewElement().Square(a) + c := ff.NewElement().Mul(p.X, q.X) + d := ff.NewElement().Mul(p.Y, q.Y) + e := ff.NewElement().Mul(Dff, c) + e.MulAssign(d) + f := ff.NewElement().Sub(b, e) + g := ff.NewElement().Add(b, e) + x1y1 := ff.NewElement().Add(p.X, p.Y) + x2y2 := ff.NewElement().Add(q.X, q.Y) + x3 := ff.NewElement().Mul(x1y1, x2y2) + x3.SubAssign(c) + x3.SubAssign(d) + x3.MulAssign(a) + x3.MulAssign(f) + ac := ff.NewElement().Mul(Aff, c) + y3 := ff.NewElement().Sub(d, ac) + y3.MulAssign(a) + y3.MulAssign(g) + z3 := ff.NewElement().Mul(f, g) + + res.X = x3 + res.Y = y3 + res.Z = z3 + return res +} + // Point represents a point of the babyjub curve. type Point struct { X *big.Int @@ -59,62 +132,32 @@ func (p *Point) Set(c *Point) *Point { return p } -// Add adds Point a and b into res -func (res *Point) Add(a *Point, b *Point) *Point { - // x = (a.x * b.y + b.x * a.y) * (1 + D * a.x * b.x * a.y * b.y)^-1 mod q - x1a := new(big.Int).Mul(a.X, b.Y) - x1b := new(big.Int).Mul(b.X, a.Y) - x1a.Add(x1a, x1b) // x1a = a.x * b.y + b.x * a.y - - x2 := new(big.Int).Set(D) - x2.Mul(x2, a.X) - x2.Mul(x2, b.X) - x2.Mul(x2, a.Y) - x2.Mul(x2, b.Y) - x2.Add(constants.One, x2) - x2.Mod(x2, constants.Q) - x2.ModInverse(x2, constants.Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 - - // y = (a.y * b.y - A * a.x * b.x) * (1 - D * a.x * b.x * a.y * b.y)^-1 mod q - y1a := new(big.Int).Mul(a.Y, b.Y) - y1b := new(big.Int).Set(A) - y1b.Mul(y1b, a.X) - y1b.Mul(y1b, b.X) - - y1a.Sub(y1a, y1b) // y1a = a.y * b.y - A * a.x * b.x - - y2 := new(big.Int).Set(D) - y2.Mul(y2, a.X) - y2.Mul(y2, b.X) - y2.Mul(y2, a.Y) - y2.Mul(y2, b.Y) - y2.Sub(constants.One, y2) - y2.Mod(y2, constants.Q) - y2.ModInverse(y2, constants.Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 - - res.X = x1a.Mul(x1a, x2) - res.X = res.X.Mod(res.X, constants.Q) - - res.Y = y1a.Mul(y1a, y2) - res.Y = res.Y.Mod(res.Y, constants.Q) - - return res +// Projective returns a PointProjective from the Point +func (p *Point) Projective() *PointProjective { + return &PointProjective{ + X: ff.NewElement().SetBigInt(p.X), + Y: ff.NewElement().SetBigInt(p.Y), + Z: ff.NewElement().SetOne(), + } } // Mul multiplies the Point p by the scalar s and stores the result in res, // which is also returned. func (res *Point) Mul(s *big.Int, p *Point) *Point { - res.X = big.NewInt(0) - res.Y = big.NewInt(1) - exp := NewPoint().Set(p) + resProj := &PointProjective{ + X: ff.NewElement().SetZero(), + Y: ff.NewElement().SetOne(), + Z: ff.NewElement().SetOne(), + } + exp := p.Projective() for i := 0; i < s.BitLen(); i++ { if s.Bit(i) == 1 { - res.Add(res, exp) + resProj.Add(resProj, exp) } - exp.Add(exp, exp) + exp = exp.Add(exp, exp) } - + res = resProj.Affine() return res } diff --git a/babyjub/babyjub_test.go b/babyjub/babyjub_test.go index cb7af64..2f31dbb 100644 --- a/babyjub/babyjub_test.go +++ b/babyjub/babyjub_test.go @@ -15,7 +15,7 @@ func TestAdd1(t *testing.T) { a := &Point{X: big.NewInt(0), Y: big.NewInt(1)} b := &Point{X: big.NewInt(0), Y: big.NewInt(1)} - c := NewPoint().Add(a, b) + c := NewPoint().Projective().Add(a.Projective(), b.Projective()) // fmt.Printf("%v = 2 * %v", *c, *a) assert.Equal(t, "0", c.X.String()) assert.Equal(t, "1", c.Y.String()) @@ -34,7 +34,7 @@ func TestAdd2(t *testing.T) { "2626589144620713026669568689430873010625803728049924121243784502389097019475") b := &Point{X: bX, Y: bY} - c := NewPoint().Add(a, b) + c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine() // fmt.Printf("%v = 2 * %v", *c, *a) assert.Equal(t, "6890855772600357754907169075114257697580319025794532037257385534741338397365", @@ -42,6 +42,17 @@ func TestAdd2(t *testing.T) { assert.Equal(t, "4338620300185947561074059802482547481416142213883829469920100239455078257889", c.Y.String()) + + d := NewPointProjective().Add(c.Projective(), c.Projective()).Affine() + assert.Equal(t, "2f6458832049e917c95867185a96621336df33e13c98e81d1ef4928cdbb77772", hex.EncodeToString(d.X.Bytes())) + + // Projective + aP := a.Projective() + bP := b.Projective() + cP := NewPointProjective().Add(aP, bP) + c2 := cP.Affine() + assert.Equal(t, c, c2) + } func TestAdd3(t *testing.T) { @@ -57,7 +68,7 @@ func TestAdd3(t *testing.T) { "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} - c := NewPoint().Add(a, b) + c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine() // fmt.Printf("%v = 2 * %v", *c, *a) assert.Equal(t, "7916061937171219682591368294088513039687205273691143098332585753343424131937", @@ -80,7 +91,7 @@ func TestAdd4(t *testing.T) { "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} - c := NewPoint().Add(a, b) + c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine() // fmt.Printf("%v = 2 * %v", *c, *a) assert.Equal(t, "16540640123574156134436876038791482806971768689494387082833631921987005038935", @@ -108,8 +119,8 @@ func TestMul0(t *testing.T) { p := &Point{X: x, Y: y} s := utils.NewIntFromString("3") - r2 := NewPoint().Add(p, p) - r2 = NewPoint().Add(r2, p) + r2 := NewPoint().Projective().Add(p.Projective(), p.Projective()).Affine() + r2 = NewPoint().Projective().Add(r2.Projective(), p.Projective()).Affine() r := NewPoint().Mul(s, p) assert.Equal(t, r2.X.String(), r.X.String()) assert.Equal(t, r2.Y.String(), r.Y.String()) @@ -244,7 +255,8 @@ func TestCompressDecompressRnd(t *testing.T) { buf := p1.Compress() p2, err := NewPoint().Decompress(buf) assert.Equal(t, nil, err) - assert.Equal(t, p1, p2) + assert.Equal(t, p1.X.Bytes(), p2.X.Bytes()) + assert.Equal(t, p1.Y.Bytes(), p2.Y.Bytes()) } } @@ -261,6 +273,7 @@ func BenchmarkBabyjub(b *testing.B) { } var points [n]*Point + var pointsProj [n]*PointProjective baseX := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") baseY := utils.NewIntFromString( @@ -269,6 +282,7 @@ func BenchmarkBabyjub(b *testing.B) { for i := 0; i < n; i++ { s := new(big.Int).Rand(rnd, constants.Q) points[i] = NewPoint().Mul(s, base) + pointsProj[i] = NewPoint().Mul(s, base).Projective() } var scalars [n]*big.Int @@ -279,17 +293,19 @@ func BenchmarkBabyjub(b *testing.B) { b.Run("AddConst", func(b *testing.B) { p0 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} p1 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + p0Proj := p0.Projective() + p1Proj := p1.Projective() - p2 := NewPoint() + p2 := NewPoint().Projective() for i := 0; i < b.N; i++ { - p2.Add(p0, p1) + p2.Add(p0Proj, p1Proj) } }) b.Run("AddRnd", func(b *testing.B) { - res := NewPoint() + res := NewPoint().Projective() for i := 0; i < b.N; i++ { - res.Add(points[i%(n/2)], points[i%(n/2)+1]) + res.Add(pointsProj[i%(n/2)], pointsProj[i%(n/2)+1]) } }) diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index 6d68e8f..b931cd6 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -236,7 +236,9 @@ func (p *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { r1 := big.NewInt(8) r1.Mul(r1, hm) right := NewPoint().Mul(r1, p.Point()) - right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A + rightProj := right.Projective() + rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A + right = rightProj.Affine() return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) } @@ -280,7 +282,9 @@ func (p *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { r1 := big.NewInt(8) r1.Mul(r1, hm) right := NewPoint().Mul(r1, p.Point()) - right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A + rightProj := right.Projective() + rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A + right = rightProj.Affine() return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) }