Browse Source

add frame size to poseidon hasher

fix/bbjj-err
Ilya 1 year ago
parent
commit
6ff38d47db
3 changed files with 19 additions and 11 deletions
  1. +2
    -1
      babyjub/babyjub_wrapper_test.go
  2. +13
    -6
      poseidon/poseidon_wrapper.go
  3. +4
    -4
      poseidon/poseidon_wrapper_test.go

+ 2
- 1
babyjub/babyjub_wrapper_test.go

@ -29,7 +29,8 @@ func TestBjjWrappedPrivateKeyInterfaceImpl(t *testing.T) {
func TestBjjWrappedPrivateKey(t *testing.T) {
pk := RandomBjjWrappedKey()
hasher := poseidon.New()
hasher, err := poseidon.New(16)
require.NoError(t, err)
hasher.Write([]byte("test"))
digest := hasher.Sum(nil)

+ 13
- 6
poseidon/poseidon_wrapper.go

@ -2,25 +2,32 @@ package poseidon
import (
"bytes"
"errors"
"hash"
)
type digest struct {
buf *bytes.Buffer
buf *bytes.Buffer
frameSize int
}
// NewPoseidon returns the Poseidon hash of the input bytes.
// use frame size of 16 inputs by default
func NewPoseidon(b []byte) []byte {
h := New()
h, _ := New(16)
h.Write(b)
return h.Sum(nil)
}
// New returns a new hash.Hash computing the Poseidon hash.
func New() hash.Hash {
return &digest{
buf: bytes.NewBuffer([]byte{}),
func New(frameSize int) (hash.Hash, error) {
if frameSize < 2 || frameSize > 16 {
return nil, errors.New("incorrect frame size")
}
return &digest{
buf: bytes.NewBuffer([]byte{}),
frameSize: frameSize,
}, nil
}
// Write (via the embedded io.Writer interface) adds more data to the running hash.
@ -30,7 +37,7 @@ func (d *digest) Write(p []byte) (n int, err error) {
// Sum returns the Poseidon checksum of the data.
func (d *digest) Sum(b []byte) []byte {
hahs, err := HashBytes(d.buf.Bytes())
hahs, err := HashBytesX(d.buf.Bytes(), d.frameSize)
if err != nil {
panic(err)
}

+ 4
- 4
poseidon/poseidon_wrapper_test.go

@ -5,7 +5,6 @@ import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -58,12 +57,13 @@ func TestPoseidonWrapperSum(t *testing.T) {
inputBytes, err := hex.DecodeString(vector.bytes)
require.NoError(t, err)
hasher := New()
hasher, err := New(16)
require.NoError(t, err)
hasher.Write(inputBytes)
res := hasher.Sum(nil)
require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res))
require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
})
}
}
@ -120,7 +120,7 @@ func TestPoseidonNewPoseidon(t *testing.T) {
res := NewPoseidon(inputBytes)
require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res))
require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
})
}
}

Loading…
Cancel
Save