diff --git a/ffg/element.go b/ffg/element.go index 14190cf..1a07254 100644 --- a/ffg/element.go +++ b/ffg/element.go @@ -559,6 +559,12 @@ func (z Element) ToBigIntRegular(res *big.Int) *big.Int { return z.ToBigInt(res) } +// ToUint64Regular returns z as a uint64 in regular form +func (z Element) ToUint64Regular() uint64 { + z.FromMont() + return z[0] +} + // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. func (z *Element) Bytes() (res [Limbs * 8]byte) { diff --git a/goldenposeidon/constants.go b/goldenposeidon/constants.go index 88bfb34..2ad93de 100644 --- a/goldenposeidon/constants.go +++ b/goldenposeidon/constants.go @@ -13,7 +13,7 @@ var ( mcirc = []uint64{17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20} mdiag = []uint64{8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - c = [360]uint64{ + c = []uint64{ 0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0x0f6760a4803427d7, 0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d, 0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76, @@ -115,18 +115,12 @@ func init() { C = append(C, ffg.NewElementFromUint64(c[i])) } - var mFFCirc, mFFDiag []*ffg.Element - for i := 0; i < mLen; i++ { - mFFCirc = append(mFFCirc, ffg.NewElementFromUint64(mcirc[i])) - mFFDiag = append(mFFDiag, ffg.NewElementFromUint64(mdiag[i])) - } - for i := 0; i < mLen; i++ { var row []*ffg.Element for j := 0; j < mLen; j++ { - ele := mFFCirc[(-i+j+mLen)%mLen] + ele := ffg.NewElementFromUint64(mcirc[(-i+j+mLen)%mLen]) if i == j { - ele = ele.Add(ele, mFFDiag[i]) + ele = ffg.NewElementFromUint64(mcirc[0] + mdiag[i]) } row = append(row, ele) } diff --git a/goldenposeidon/poseidon.go b/goldenposeidon/poseidon.go index a7b14cc..b64c5f2 100644 --- a/goldenposeidon/poseidon.go +++ b/goldenposeidon/poseidon.go @@ -43,7 +43,7 @@ func mix(state []*ffg.Element) []*ffg.Element { for i := 0; i < mLen; i++ { newState[i].SetUint64(0) for j := 0; j < mLen; j++ { - mul.Mul(M[j][i], state[j]) + mul.Mul(M[i][j], state[j]) newState[i].Add(newState[i], mul) } } @@ -51,12 +51,12 @@ func mix(state []*ffg.Element) []*ffg.Element { } // Hash computes the Poseidon hash for the given inputs -func Hash(inpBI []*big.Int, capBI []*big.Int) (*big.Int, error) { +func Hash(inpBI []*big.Int, capBI []*big.Int) ([CAPLEN]uint64, error) { if len(inpBI) != NROUNDSF { - return nil, fmt.Errorf("invalid inputs length %d, must be 8", len(inpBI)) + return [CAPLEN]uint64{}, fmt.Errorf("invalid inputs length %d, must be 8", len(inpBI)) } if len(capBI) != CAPLEN { - return nil, fmt.Errorf("invalid capcity length %d, must be 4", len(capBI)) + return [CAPLEN]uint64{}, fmt.Errorf("invalid capcity length %d, must be 4", len(capBI)) } state := make([]*ffg.Element, mLen) @@ -69,73 +69,15 @@ func Hash(inpBI []*big.Int, capBI []*big.Int) (*big.Int, error) { for r := 0; r < NROUNDSF+NROUNDSP; r++ { ark(state, r*mLen) + if r < NROUNDSF/2 || r >= NROUNDSF/2+NROUNDSP { exp7state(state) } else { exp7(state[0]) } - state = mix(state) - } - - r := big.NewInt(0) - for i := 0; i < CAPLEN; i++ { - res := big.NewInt(0) - state[i].ToBigIntRegular(res) - r.Add(r.Lsh(r, 64), res) - } - - return r, nil -} - -// HashBytes returns a sponge hash of a msg byte slice split into blocks of 31 bytes -func HashBytes(msg []byte) (*big.Int, error) { - // not used inputs default to zero - inputs := make([]*big.Int, spongeInputs) - for j := 0; j < spongeInputs; j++ { - inputs[j] = new(big.Int) - } - dirty := false - var hash *big.Int - var err error - - k := 0 - for i := 0; i < len(msg)/spongeChunkSize; i++ { - dirty = true - inputs[k].SetBytes(msg[spongeChunkSize*i : spongeChunkSize*(i+1)]) - if k == spongeInputs-1 { - hash, err = Hash(inputs, []*big.Int{big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}) - dirty = false - if err != nil { - return nil, err - } - inputs = make([]*big.Int, spongeInputs) - inputs[0] = hash - for j := 1; j < spongeInputs; j++ { - inputs[j] = new(big.Int) - } - k = 1 - } else { - k++ - } - } - if len(msg)%spongeChunkSize != 0 { - // the last chunk of the message is less than 31 bytes - // zero padding it, so that 0xdeadbeaf becomes - // 0xdeadbeaf000000000000000000000000000000000000000000000000000000 - var buf [spongeChunkSize]byte - copy(buf[:], msg[(len(msg)/spongeChunkSize)*spongeChunkSize:]) - inputs[k] = new(big.Int).SetBytes(buf[:]) - dirty = true - } - - if dirty { - // we haven't hashed something in the main sponge loop and need to do hash here - hash, err = Hash(inputs, []*big.Int{big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}) - if err != nil { - return nil, err - } + state = mix(state) } - return hash, nil + return [CAPLEN]uint64{state[0].ToUint64Regular(), state[1].ToUint64Regular(), state[2].ToUint64Regular(), state[3].ToUint64Regular()}, nil } diff --git a/goldenposeidon/poseidon_test.go b/goldenposeidon/poseidon_test.go index 31665d0..6d57423 100644 --- a/goldenposeidon/poseidon_test.go +++ b/goldenposeidon/poseidon_test.go @@ -9,10 +9,82 @@ import ( func TestPoseidonHash(t *testing.T) { b0 := big.NewInt(0) + b1 := big.NewInt(1) + b_1 := big.NewInt(-1) + bM := new(big.Int).SetUint64(18446744069414584321) h, err := Hash([]*big.Int{b0, b0, b0, b0, b0, b0, b0, b0}, []*big.Int{b0, b0, b0, b0}) assert.Nil(t, err) assert.Equal(t, - "18586133768512220936620570745912940619677854269274689475585506675881198879027", - h.String()) + [CAPLEN]uint64{ + 4330397376401421145, + 14124799381142128323, + 8742572140681234676, + 14345658006221440202, + }, h, + ) + + h, err = Hash([]*big.Int{b1, b1, b1, b1, b1, b1, b1, b1}, []*big.Int{b1, b1, b1, b1}) + assert.Nil(t, err) + assert.Equal(t, + [CAPLEN]uint64{ + 16428316519797902711, + 13351830238340666928, + 682362844289978626, + 12150588177266359240, + }, h, + ) + + h, err = Hash([]*big.Int{b1, b1, b1, b1, b1, b1, b1, b1}, []*big.Int{b1, b1, b1, b1}) + assert.Nil(t, err) + assert.Equal(t, + [CAPLEN]uint64{ + 16428316519797902711, + 13351830238340666928, + 682362844289978626, + 12150588177266359240, + }, h, + ) + + h, err = Hash([]*big.Int{b_1, b_1, b_1, b_1, b_1, b_1, b_1, b_1}, []*big.Int{b_1, b_1, b_1, b_1}) + assert.Nil(t, err) + assert.Equal(t, + [CAPLEN]uint64{ + 13691089994624172887, + 15662102337790434313, + 14940024623104903507, + 10772674582659927682, + }, h, + ) + + h, err = Hash([]*big.Int{bM, bM, bM, bM, bM, bM, bM, bM}, []*big.Int{b0, b0, b0, b0}) + assert.Nil(t, err) + assert.Equal(t, + [CAPLEN]uint64{ + 4330397376401421145, + 14124799381142128323, + 8742572140681234676, + 14345658006221440202, + }, h, + ) + + h, err = Hash([]*big.Int{ + new(big.Int).SetUint64(923978), + new(big.Int).SetUint64(235763497586), + new(big.Int).SetUint64(9827635653498), + new(big.Int).SetUint64(112870), + new(big.Int).SetUint64(289273673480943876), + new(big.Int).SetUint64(230295874986745876), + new(big.Int).SetUint64(6254867324987), + new(big.Int).SetUint64(2087), + }, []*big.Int{b0, b0, b0, b0}) + assert.Nil(t, err) + assert.Equal(t, + [CAPLEN]uint64{ + 1892171027578617759, + 984732815927439256, + 7866041765487844082, + 8161503938059336191, + }, h, + ) }