diff --git a/babyjub/babyjub_wrapper.go b/babyjub/babyjub_wrapper.go index 6bd6a20..66e40ee 100644 --- a/babyjub/babyjub_wrapper.go +++ b/babyjub/babyjub_wrapper.go @@ -62,7 +62,10 @@ func (w *BjjWrappedPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.S } digestBI := big.NewInt(0).SetBytes(digest) - sig := w.privKey.SignPoseidon(digestBI) + sig, err := w.privKey.SignPoseidon(digestBI) + if err != nil { + return nil, err + } return sig.Compress().MarshalText() } diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index 60035e8..af77929 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -243,7 +243,7 @@ func (s Signature) Value() (driver.Value, error) { // SignMimc7 signs a message encoded as a big.Int in Zq using blake-512 hash // for buffer hashing and mimc7 for big.Int hashing. -func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { +func (k *PrivateKey) SignMimc7(msg *big.Int) (*Signature, error) { h1 := Blake512(k[:]) msgBuf := utils.BigIntLEBytes(msg) msgBuf32 := [32]byte{} @@ -256,23 +256,23 @@ func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg} hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { - panic(err) + return nil, err } S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) S = S.Mul(hm, S) S.Add(r, S) S.Mod(S, SubOrder) // S = r + hm * 8 * s - return &Signature{R8: R8, S: S} + return &Signature{R8: R8, S: S}, nil } // VerifyMimc7 verifies the signature of a message encoded as a big.Int in Zq // using blake-512 hash for buffer hashing and mimc7 for big.Int hashing. -func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { +func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) error { hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg} hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { - return false + return err } left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B @@ -282,12 +282,15 @@ func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { rightProj := right.Projective() rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A right = rightProj.Affine() - return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) + if (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) { + return nil + } + return fmt.Errorf("verifyMimc7 failed") } // SignPoseidon signs a message encoded as a big.Int in Zq using blake-512 hash // for buffer hashing and Poseidon for big.Int hashing. -func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { +func (k *PrivateKey) SignPoseidon(msg *big.Int) (*Signature, error) { h1 := Blake512(k[:]) msgBuf := utils.BigIntLEBytes(msg) msgBuf32 := [32]byte{} @@ -301,7 +304,7 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg} hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { - panic(err) + return nil, err } S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) @@ -309,16 +312,16 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { S.Add(r, S) S.Mod(S, SubOrder) // S = r + hm * 8 * s - return &Signature{R8: R8, S: S} + return &Signature{R8: R8, S: S}, nil } // VerifyPoseidon verifies the signature of a message encoded as a big.Int in Zq // using blake-512 hash for buffer hashing and Poseidon for big.Int hashing. -func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { +func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) error { hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg} hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) if err != nil { - return false + return err } left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B @@ -328,7 +331,10 @@ func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { rightProj := right.Projective() rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A right = rightProj.Affine() - return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) + if (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) { + return nil + } + return fmt.Errorf("verifyPoseidon failed") } // Scan implements Scanner for database/sql. diff --git a/babyjub/eddsa_test.go b/babyjub/eddsa_test.go index f2c7964..e048273 100644 --- a/babyjub/eddsa_test.go +++ b/babyjub/eddsa_test.go @@ -43,7 +43,8 @@ func TestSignVerifyMimc7(t *testing.T) { "13622229784656158136036771217484571176836296686641868549125388198837476602820", pk.Y.String()) - sig := k.SignMimc7(msg) + sig, err := k.SignMimc7(msg) + assert.NoError(t, err) assert.Equal(t, "11384336176656855268977457483345535180380036354188103142384839473266348197733", sig.R8.X.String()) @@ -54,20 +55,20 @@ func TestSignVerifyMimc7(t *testing.T) { "2523202440825208709475937830811065542425109372212752003460238913256192595070", sig.S.String()) - ok := pk.VerifyMimc7(msg, sig) - assert.Equal(t, true, ok) + err = pk.VerifyMimc7(msg, sig) + assert.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.Equal(t, nil, err) + assert.NoError(t, err) assert.Equal(t, ""+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ "7ed40dab29bf993c928e789d007387998901a24913d44fddb64b1f21fc149405", hex.EncodeToString(sigBuf[:])) - ok = pk.VerifyMimc7(msg, sig2) - assert.Equal(t, true, ok) + err = pk.VerifyMimc7(msg, sig2) + assert.NoError(t, err) } func TestSignVerifyPoseidon(t *testing.T) { @@ -89,7 +90,8 @@ func TestSignVerifyPoseidon(t *testing.T) { "13622229784656158136036771217484571176836296686641868549125388198837476602820", pk.Y.String()) - sig := k.SignPoseidon(msg) + sig, err := k.SignPoseidon(msg) + assert.NoError(t, err) assert.Equal(t, "11384336176656855268977457483345535180380036354188103142384839473266348197733", sig.R8.X.String()) @@ -100,20 +102,20 @@ func TestSignVerifyPoseidon(t *testing.T) { "1672775540645840396591609181675628451599263765380031905495115170613215233181", sig.S.String()) - ok := pk.VerifyPoseidon(msg, sig) - assert.Equal(t, true, ok) + err = pk.VerifyPoseidon(msg, sig) + assert.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.Equal(t, nil, err) + assert.NoError(t, err) assert.Equal(t, ""+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ "9d043ece562a8f82bfc0adb640c0107a7d3a27c1c7c1a6179a0da73de5c1b203", hex.EncodeToString(sigBuf[:])) - ok = pk.VerifyPoseidon(msg, sig2) - assert.Equal(t, true, ok) + err = pk.VerifyPoseidon(msg, sig2) + assert.NoError(t, err) } func TestCompressDecompress(t *testing.T) { @@ -128,22 +130,28 @@ func TestCompressDecompress(t *testing.T) { panic(err) } msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) - sig := k.SignMimc7(msg) + sig, err := k.SignMimc7(msg) + assert.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.Equal(t, nil, err) - ok := pk.VerifyMimc7(msg, sig2) - assert.Equal(t, true, ok) + assert.NoError(t, err) + err = pk.VerifyMimc7(msg, sig2) + assert.NoError(t, err) } } func TestSignatureCompScannerValuer(t *testing.T) { privK := NewRandPrivKey() + var err error + sig, err := privK.SignPoseidon(big.NewInt(674238462)) + assert.NoError(t, err) var value driver.Valuer //nolint:gosimple // this is done to ensure interface compatibility - value = privK.SignPoseidon(big.NewInt(674238462)).Compress() - scan := privK.SignPoseidon(big.NewInt(1)).Compress() + value = sig.Compress() + sig, err = privK.SignPoseidon(big.NewInt(1)) + assert.NoError(t, err) + scan := sig.Compress() fromDB, err := value.Value() - assert.Nil(t, err) + assert.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } @@ -152,10 +160,13 @@ func TestSignatureScannerValuer(t *testing.T) { privK := NewRandPrivKey() var value driver.Valuer var scan sql.Scanner - value = privK.SignPoseidon(big.NewInt(674238462)) - scan = privK.SignPoseidon(big.NewInt(1)) + var err error + value, err = privK.SignPoseidon(big.NewInt(674238462)) + assert.NoError(t, err) + scan, err = privK.SignPoseidon(big.NewInt(1)) + assert.NoError(t, err) fromDB, err := value.Value() - assert.Nil(t, err) + assert.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } @@ -217,12 +228,12 @@ func BenchmarkBabyjubEddsa(b *testing.B) { }) for i := 0; i < n; i++ { - sigs[i%n] = k.SignMimc7(msgs[i%n]) + sigs[i%n], _ = k.SignMimc7(msgs[i%n]) } b.Run("VerifyMimc7", func(b *testing.B) { for i := 0; i < b.N; i++ { - pk.VerifyMimc7(msgs[i%n], sigs[i%n]) + _ = pk.VerifyMimc7(msgs[i%n], sigs[i%n]) } }) @@ -233,12 +244,12 @@ func BenchmarkBabyjubEddsa(b *testing.B) { }) for i := 0; i < n; i++ { - sigs[i%n] = k.SignPoseidon(msgs[i%n]) + sigs[i%n], _ = k.SignPoseidon(msgs[i%n]) } b.Run("VerifyPoseidon", func(b *testing.B) { for i := 0; i < b.N; i++ { - pk.VerifyPoseidon(msgs[i%n], sigs[i%n]) + _ = pk.VerifyPoseidon(msgs[i%n], sigs[i%n]) } }) }