Browse Source

Add add-2008-bbjlp for point addition

Add `add-2008-bbjlp` for point addition

Benchmarks (On a Intel(R) Core(TM) i7-8705G CPU @ 3.10GHz, with 32 GB of RAM):

```
- Old:
BenchmarkBabyjub/AddConst-8              1000000              1072 ns/op
BenchmarkBabyjub/AddRnd-8                  93417             12943 ns/op
BenchmarkBabyjub/MulRnd-8                    252           4797810 ns/op
BenchmarkBabyjub/Compress-8              7291580               166 ns/op
BenchmarkBabyjub/InCurve-8                611137              1999 ns/op
BenchmarkBabyjub/InSubGroup-8             615792              2021 ns/op
BenchmarkBabyjubEddsa/SignMimc7-8            126           9358542 ns/op
BenchmarkBabyjubEddsa/VerifyMimc7-8          124           9484005 ns/op
BenchmarkBabyjubEddsa/SignPoseidon-8                 126           9486484 ns/op
BenchmarkBabyjubEddsa/VerifyPoseidon-8               126           9622807 ns/op

- With new point addition algorithm:
BenchmarkBabyjub/AddConst-8              1356836               881 ns/op
BenchmarkBabyjub/AddRnd-8                 274112              4220 ns/op
BenchmarkBabyjub/MulRnd-8                    492           2474412 ns/op
BenchmarkBabyjub/Compress-8              6964855               197 ns/op
BenchmarkBabyjub/InCurve-8                608169              2008 ns/op
BenchmarkBabyjub/InSubGroup-8             618772              1954 ns/op
BenchmarkBabyjubEddsa/SignMimc7-8            238           4962397 ns/op
BenchmarkBabyjubEddsa/VerifyMimc7-8          235           5234883 ns/op
BenchmarkBabyjubEddsa/SignPoseidon-8                 240           5028720 ns/op
BenchmarkBabyjubEddsa/VerifyPoseidon-8               243           5226654 ns/op
```

Point Addition: ~3x
Point scalar Mul: ~1.9x
Signature (poseidon): ~1.88x
Verification (poseidon): ~1.84x
feature/babyjubjub-optimization
arnaucube 4 years ago
parent
commit
aab1a681dd
3 changed files with 120 additions and 59 deletions
  1. +87
    -46
      babyjub/babyjub.go
  2. +27
    -11
      babyjub/babyjub_test.go
  3. +6
    -2
      babyjub/eddsa.go

+ 87
- 46
babyjub/babyjub.go

@ -1,6 +1,7 @@
package babyjub package babyjub
import ( import (
"bytes"
"fmt" "fmt"
"math/big" "math/big"
@ -41,6 +42,76 @@ func init() {
"16950150798460657717958625567821834550301663161624707787222815936182638968203") "16950150798460657717958625567821834550301663161624707787222815936182638968203")
} }
// PointProjective is the Point representation in projective coordinates
type PointProjective struct {
X *big.Int
Y *big.Int
Z *big.Int
}
// NewPointProjective creates a new Point in projective coordinates.
func NewPointProjective() *PointProjective {
return &PointProjective{X: big.NewInt(0), Y: big.NewInt(1), Z: big.NewInt(1)}
}
// Affine returns the Point from the projective representation
func (p *PointProjective) Affine() *Point {
if bytes.Equal(p.Z.Bytes(), big.NewInt(0).Bytes()) {
return &Point{
X: big.NewInt(0),
Y: big.NewInt(0),
}
}
zinv := new(big.Int).ModInverse(p.Z, constants.Q)
x := new(big.Int).Mul(p.X, zinv)
x.Mod(x, constants.Q)
y := new(big.Int).Mul(p.Y, zinv)
y.Mod(y, constants.Q)
return &Point{
X: x,
Y: y,
}
}
// Add computes the addition of two points in projective coordinates representation
func (res *PointProjective) Add(p *PointProjective, q *PointProjective) *PointProjective {
// add-2008-bbjlp https://hyperelliptic.org/EFD/g1p/auto-twisted-projective.html#doubling-dbl-2008-bbjlp
a := new(big.Int).Mul(p.Z, q.Z)
b := new(big.Int).Set(a)
b.Exp(b, big.NewInt(2), constants.Q)
c := new(big.Int).Mul(p.X, q.X)
c.Mod(c, constants.Q) // apply Mod to reduce number file and speed computation
d := new(big.Int).Mul(p.Y, q.Y)
d.Mod(d, constants.Q)
e := new(big.Int).Mul(D, c)
e.Mul(e, d)
e.Mod(e, constants.Q)
f := new(big.Int).Sub(b, e)
f.Mod(f, constants.Q)
g := new(big.Int).Add(b, e)
g.Mod(g, constants.Q)
x1y1 := new(big.Int).Add(p.X, p.Y)
x2y2 := new(big.Int).Add(q.X, q.Y)
x3 := new(big.Int).Mul(x1y1, x2y2)
x3.Sub(x3, c)
x3.Sub(x3, d)
x3.Mul(x3, a)
x3.Mul(x3, f)
x3.Mod(x3, constants.Q)
ac := new(big.Int).Mul(A, c)
y3 := new(big.Int).Sub(d, ac)
y3.Mul(y3, a)
y3.Mul(y3, g)
y3.Mod(y3, constants.Q)
z3 := new(big.Int).Mul(f, g)
z3.Mod(z3, constants.Q)
res.X = x3
res.Y = y3
res.Z = z3
return res
}
// Point represents a point of the babyjub curve. // Point represents a point of the babyjub curve.
type Point struct { type Point struct {
X *big.Int X *big.Int
@ -59,62 +130,32 @@ func (p *Point) Set(c *Point) *Point {
return p return p
} }
// Add adds Point a and b into res
func (res *Point) Add(a *Point, b *Point) *Point {
// x = (a.x * b.y + b.x * a.y) * (1 + D * a.x * b.x * a.y * b.y)^-1 mod q
x1a := new(big.Int).Mul(a.X, b.Y)
x1b := new(big.Int).Mul(b.X, a.Y)
x1a.Add(x1a, x1b) // x1a = a.x * b.y + b.x * a.y
x2 := new(big.Int).Set(D)
x2.Mul(x2, a.X)
x2.Mul(x2, b.X)
x2.Mul(x2, a.Y)
x2.Mul(x2, b.Y)
x2.Add(constants.One, x2)
x2.Mod(x2, constants.Q)
x2.ModInverse(x2, constants.Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1
// y = (a.y * b.y - A * a.x * b.x) * (1 - D * a.x * b.x * a.y * b.y)^-1 mod q
y1a := new(big.Int).Mul(a.Y, b.Y)
y1b := new(big.Int).Set(A)
y1b.Mul(y1b, a.X)
y1b.Mul(y1b, b.X)
y1a.Sub(y1a, y1b) // y1a = a.y * b.y - A * a.x * b.x
y2 := new(big.Int).Set(D)
y2.Mul(y2, a.X)
y2.Mul(y2, b.X)
y2.Mul(y2, a.Y)
y2.Mul(y2, b.Y)
y2.Sub(constants.One, y2)
y2.Mod(y2, constants.Q)
y2.ModInverse(y2, constants.Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1
res.X = x1a.Mul(x1a, x2)
res.X = res.X.Mod(res.X, constants.Q)
res.Y = y1a.Mul(y1a, y2)
res.Y = res.Y.Mod(res.Y, constants.Q)
return res
// Projective returns a PointProjective from the Point
func (p *Point) Projective() *PointProjective {
return &PointProjective{
X: p.X,
Y: p.Y,
Z: big.NewInt(1),
}
} }
// Mul multiplies the Point p by the scalar s and stores the result in res, // Mul multiplies the Point p by the scalar s and stores the result in res,
// which is also returned. // which is also returned.
func (res *Point) Mul(s *big.Int, p *Point) *Point { func (res *Point) Mul(s *big.Int, p *Point) *Point {
res.X = big.NewInt(0)
res.Y = big.NewInt(1)
exp := NewPoint().Set(p)
resProj := &PointProjective{
X: big.NewInt(0),
Y: big.NewInt(1),
Z: big.NewInt(1),
}
exp := p.Projective()
for i := 0; i < s.BitLen(); i++ { for i := 0; i < s.BitLen(); i++ {
if s.Bit(i) == 1 { if s.Bit(i) == 1 {
res.Add(res, exp)
resProj.Add(resProj, exp)
} }
exp.Add(exp, exp)
exp = exp.Add(exp, exp)
} }
res = resProj.Affine()
return res return res
} }

+ 27
- 11
babyjub/babyjub_test.go

@ -15,7 +15,7 @@ func TestAdd1(t *testing.T) {
a := &Point{X: big.NewInt(0), Y: big.NewInt(1)} a := &Point{X: big.NewInt(0), Y: big.NewInt(1)}
b := &Point{X: big.NewInt(0), Y: big.NewInt(1)} b := &Point{X: big.NewInt(0), Y: big.NewInt(1)}
c := NewPoint().Add(a, b)
c := NewPoint().Projective().Add(a.Projective(), b.Projective())
// fmt.Printf("%v = 2 * %v", *c, *a) // fmt.Printf("%v = 2 * %v", *c, *a)
assert.Equal(t, "0", c.X.String()) assert.Equal(t, "0", c.X.String())
assert.Equal(t, "1", c.Y.String()) assert.Equal(t, "1", c.Y.String())
@ -34,7 +34,7 @@ func TestAdd2(t *testing.T) {
"2626589144620713026669568689430873010625803728049924121243784502389097019475") "2626589144620713026669568689430873010625803728049924121243784502389097019475")
b := &Point{X: bX, Y: bY} b := &Point{X: bX, Y: bY}
c := NewPoint().Add(a, b)
c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine()
// fmt.Printf("%v = 2 * %v", *c, *a) // fmt.Printf("%v = 2 * %v", *c, *a)
assert.Equal(t, assert.Equal(t,
"6890855772600357754907169075114257697580319025794532037257385534741338397365", "6890855772600357754907169075114257697580319025794532037257385534741338397365",
@ -42,6 +42,17 @@ func TestAdd2(t *testing.T) {
assert.Equal(t, assert.Equal(t,
"4338620300185947561074059802482547481416142213883829469920100239455078257889", "4338620300185947561074059802482547481416142213883829469920100239455078257889",
c.Y.String()) c.Y.String())
d := NewPointProjective().Add(c.Projective(), c.Projective()).Affine()
assert.Equal(t, "2f6458832049e917c95867185a96621336df33e13c98e81d1ef4928cdbb77772", hex.EncodeToString(d.X.Bytes()))
// Projective
aP := a.Projective()
bP := b.Projective()
cP := NewPointProjective().Add(aP, bP)
c2 := cP.Affine()
assert.Equal(t, c, c2)
} }
func TestAdd3(t *testing.T) { func TestAdd3(t *testing.T) {
@ -57,7 +68,7 @@ func TestAdd3(t *testing.T) {
"20819045374670962167435360035096875258406992893633759881276124905556507972311") "20819045374670962167435360035096875258406992893633759881276124905556507972311")
b := &Point{X: bX, Y: bY} b := &Point{X: bX, Y: bY}
c := NewPoint().Add(a, b)
c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine()
// fmt.Printf("%v = 2 * %v", *c, *a) // fmt.Printf("%v = 2 * %v", *c, *a)
assert.Equal(t, assert.Equal(t,
"7916061937171219682591368294088513039687205273691143098332585753343424131937", "7916061937171219682591368294088513039687205273691143098332585753343424131937",
@ -80,7 +91,7 @@ func TestAdd4(t *testing.T) {
"20819045374670962167435360035096875258406992893633759881276124905556507972311") "20819045374670962167435360035096875258406992893633759881276124905556507972311")
b := &Point{X: bX, Y: bY} b := &Point{X: bX, Y: bY}
c := NewPoint().Add(a, b)
c := NewPoint().Projective().Add(a.Projective(), b.Projective()).Affine()
// fmt.Printf("%v = 2 * %v", *c, *a) // fmt.Printf("%v = 2 * %v", *c, *a)
assert.Equal(t, assert.Equal(t,
"16540640123574156134436876038791482806971768689494387082833631921987005038935", "16540640123574156134436876038791482806971768689494387082833631921987005038935",
@ -108,8 +119,8 @@ func TestMul0(t *testing.T) {
p := &Point{X: x, Y: y} p := &Point{X: x, Y: y}
s := utils.NewIntFromString("3") s := utils.NewIntFromString("3")
r2 := NewPoint().Add(p, p)
r2 = NewPoint().Add(r2, p)
r2 := NewPoint().Projective().Add(p.Projective(), p.Projective()).Affine()
r2 = NewPoint().Projective().Add(r2.Projective(), p.Projective()).Affine()
r := NewPoint().Mul(s, p) r := NewPoint().Mul(s, p)
assert.Equal(t, r2.X.String(), r.X.String()) assert.Equal(t, r2.X.String(), r.X.String())
assert.Equal(t, r2.Y.String(), r.Y.String()) assert.Equal(t, r2.Y.String(), r.Y.String())
@ -244,7 +255,8 @@ func TestCompressDecompressRnd(t *testing.T) {
buf := p1.Compress() buf := p1.Compress()
p2, err := NewPoint().Decompress(buf) p2, err := NewPoint().Decompress(buf)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, p1, p2)
assert.Equal(t, p1.X.Bytes(), p2.X.Bytes())
assert.Equal(t, p1.Y.Bytes(), p2.Y.Bytes())
} }
} }
@ -261,6 +273,7 @@ func BenchmarkBabyjub(b *testing.B) {
} }
var points [n]*Point var points [n]*Point
var pointsProj [n]*PointProjective
baseX := utils.NewIntFromString( baseX := utils.NewIntFromString(
"17777552123799933955779906779655732241715742912184938656739573121738514868268") "17777552123799933955779906779655732241715742912184938656739573121738514868268")
baseY := utils.NewIntFromString( baseY := utils.NewIntFromString(
@ -269,6 +282,7 @@ func BenchmarkBabyjub(b *testing.B) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
s := new(big.Int).Rand(rnd, constants.Q) s := new(big.Int).Rand(rnd, constants.Q)
points[i] = NewPoint().Mul(s, base) points[i] = NewPoint().Mul(s, base)
pointsProj[i] = NewPoint().Mul(s, base).Projective()
} }
var scalars [n]*big.Int var scalars [n]*big.Int
@ -279,17 +293,19 @@ func BenchmarkBabyjub(b *testing.B) {
b.Run("AddConst", func(b *testing.B) { b.Run("AddConst", func(b *testing.B) {
p0 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} p0 := &Point{X: big.NewInt(0), Y: big.NewInt(1)}
p1 := &Point{X: big.NewInt(0), Y: big.NewInt(1)} p1 := &Point{X: big.NewInt(0), Y: big.NewInt(1)}
p0Proj := p0.Projective()
p1Proj := p1.Projective()
p2 := NewPoint()
p2 := NewPoint().Projective()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
p2.Add(p0, p1)
p2.Add(p0Proj, p1Proj)
} }
}) })
b.Run("AddRnd", func(b *testing.B) { b.Run("AddRnd", func(b *testing.B) {
res := NewPoint()
res := NewPoint().Projective()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
res.Add(points[i%(n/2)], points[i%(n/2)+1])
res.Add(pointsProj[i%(n/2)], pointsProj[i%(n/2)+1])
} }
}) })

+ 6
- 2
babyjub/eddsa.go

@ -236,7 +236,9 @@ func (p *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool {
r1 := big.NewInt(8) r1 := big.NewInt(8)
r1.Mul(r1, hm) r1.Mul(r1, hm)
right := NewPoint().Mul(r1, p.Point()) right := NewPoint().Mul(r1, p.Point())
right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A
rightProj := right.Projective()
rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A
right = rightProj.Affine()
return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0)
} }
@ -280,7 +282,9 @@ func (p *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool {
r1 := big.NewInt(8) r1 := big.NewInt(8)
r1.Mul(r1, hm) r1.Mul(r1, hm)
right := NewPoint().Mul(r1, p.Point()) right := NewPoint().Mul(r1, p.Point())
right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A
rightProj := right.Projective()
rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A
right = rightProj.Affine()
return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0)
} }

Loading…
Cancel
Save