From 6ff38d47db7949ac5937fa866989e130d74b2711 Mon Sep 17 00:00:00 2001 From: Ilya Date: Thu, 18 May 2023 13:37:31 +0300 Subject: [PATCH] add frame size to poseidon hasher --- babyjub/babyjub_wrapper_test.go | 3 ++- poseidon/poseidon_wrapper.go | 19 +++++++++++++------ poseidon/poseidon_wrapper_test.go | 8 ++++---- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/babyjub/babyjub_wrapper_test.go b/babyjub/babyjub_wrapper_test.go index 623c6e7..284f7ae 100644 --- a/babyjub/babyjub_wrapper_test.go +++ b/babyjub/babyjub_wrapper_test.go @@ -29,7 +29,8 @@ func TestBjjWrappedPrivateKeyInterfaceImpl(t *testing.T) { func TestBjjWrappedPrivateKey(t *testing.T) { pk := RandomBjjWrappedKey() - hasher := poseidon.New() + hasher, err := poseidon.New(16) + require.NoError(t, err) hasher.Write([]byte("test")) digest := hasher.Sum(nil) diff --git a/poseidon/poseidon_wrapper.go b/poseidon/poseidon_wrapper.go index 462e776..411bc6a 100644 --- a/poseidon/poseidon_wrapper.go +++ b/poseidon/poseidon_wrapper.go @@ -2,25 +2,32 @@ package poseidon import ( "bytes" + "errors" "hash" ) type digest struct { - buf *bytes.Buffer + buf *bytes.Buffer + frameSize int } // NewPoseidon returns the Poseidon hash of the input bytes. +// use frame size of 16 inputs by default func NewPoseidon(b []byte) []byte { - h := New() + h, _ := New(16) h.Write(b) return h.Sum(nil) } // New returns a new hash.Hash computing the Poseidon hash. -func New() hash.Hash { - return &digest{ - buf: bytes.NewBuffer([]byte{}), +func New(frameSize int) (hash.Hash, error) { + if frameSize < 2 || frameSize > 16 { + return nil, errors.New("incorrect frame size") } + return &digest{ + buf: bytes.NewBuffer([]byte{}), + frameSize: frameSize, + }, nil } // Write (via the embedded io.Writer interface) adds more data to the running hash. @@ -30,7 +37,7 @@ func (d *digest) Write(p []byte) (n int, err error) { // Sum returns the Poseidon checksum of the data. func (d *digest) Sum(b []byte) []byte { - hahs, err := HashBytes(d.buf.Bytes()) + hahs, err := HashBytesX(d.buf.Bytes(), d.frameSize) if err != nil { panic(err) } diff --git a/poseidon/poseidon_wrapper_test.go b/poseidon/poseidon_wrapper_test.go index 4f5a588..5f93d6e 100644 --- a/poseidon/poseidon_wrapper_test.go +++ b/poseidon/poseidon_wrapper_test.go @@ -5,7 +5,6 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,12 +57,13 @@ func TestPoseidonWrapperSum(t *testing.T) { inputBytes, err := hex.DecodeString(vector.bytes) require.NoError(t, err) - hasher := New() + hasher, err := New(16) + require.NoError(t, err) hasher.Write(inputBytes) res := hasher.Sum(nil) require.NotEmpty(t, res) - assert.Equal(t, vector.expectedHash, hex.EncodeToString(res)) + require.Equal(t, vector.expectedHash, hex.EncodeToString(res)) }) } } @@ -120,7 +120,7 @@ func TestPoseidonNewPoseidon(t *testing.T) { res := NewPoseidon(inputBytes) require.NotEmpty(t, res) - assert.Equal(t, vector.expectedHash, hex.EncodeToString(res)) + require.Equal(t, vector.expectedHash, hex.EncodeToString(res)) }) } }