Upgrade to gnark 0.8 (#18)

* make proof with PIS public input

* upgraded to 0.8 gnark

* reduced pow witness

* fixed bug

* fixed test

* fixed bug

* adding profiling

* changed everything to be pointers

* convert remaining poseidon constants

* added the recursive_very_small

* added more outputs for benchmark
This commit is contained in:
Kevin Jue
2023-05-25 07:39:06 -07:00
committed by GitHub
parent cf84b032e2
commit 302b5f5bf1
31 changed files with 5336 additions and 2089 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -8,11 +8,11 @@ import (
type Hash = [4]field.F
type HashAPI struct {
fieldAPI frontend.API
fieldAPI field.FieldAPI
}
func NewHashAPI(
fieldAPI frontend.API,
fieldAPI field.FieldAPI,
) *HashAPI {
return &HashAPI{
fieldAPI: fieldAPI,
@@ -22,7 +22,7 @@ func NewHashAPI(
func (h *HashAPI) SelectHash(bit frontend.Variable, leftHash, rightHash Hash) Hash {
var returnHash Hash
for i := 0; i < 4; i++ {
returnHash[i] = h.fieldAPI.Select(bit, leftHash[i], rightHash[i]).(field.F)
returnHash[i] = h.fieldAPI.Select(bit, leftHash[i], rightHash[i])
}
return returnHash
@@ -32,7 +32,7 @@ func (h *HashAPI) Lookup2Hash(b0 frontend.Variable, b1 frontend.Variable, h0, h1
var returnHash Hash
for i := 0; i < 4; i++ {
returnHash[i] = h.fieldAPI.Lookup2(b0, b1, h0[i], h1[i], h2[i], h3[i]).(field.F)
returnHash[i] = h.fieldAPI.Lookup2(b0, b1, h0[i], h1[i], h2[i], h3[i])
}
return returnHash
@@ -47,7 +47,7 @@ func (h *HashAPI) AssertIsEqualHash(h1, h2 Hash) {
func Uint64ArrayToHashArray(input [][]uint64) []Hash {
var output []Hash
for i := 0; i < len(input); i++ {
output = append(output, [4]field.F{field.NewFieldElement(input[i][0]), field.NewFieldElement(input[i][1]), field.NewFieldElement(input[i][2]), field.NewFieldElement(input[i][3])})
output = append(output, [4]field.F{field.NewFieldConst(input[i][0]), field.NewFieldConst(input[i][1]), field.NewFieldConst(input[i][2]), field.NewFieldConst(input[i][3])})
}
return output
}

View File

@@ -18,11 +18,11 @@ type PoseidonStateExtension = [SPONGE_WIDTH]field.QuadraticExtension
type PoseidonChip struct {
api frontend.API `gnark:"-"`
fieldAPI frontend.API `gnark:"-"`
fieldAPI field.FieldAPI `gnark:"-"`
qeAPI *field.QuadraticExtensionAPI `gnark:"-"`
}
func NewPoseidonChip(api frontend.API, fieldAPI frontend.API, qeAPI *field.QuadraticExtensionAPI) *PoseidonChip {
func NewPoseidonChip(api frontend.API, fieldAPI field.FieldAPI, qeAPI *field.QuadraticExtensionAPI) *PoseidonChip {
return &PoseidonChip{api: api, fieldAPI: fieldAPI, qeAPI: qeAPI}
}
@@ -86,7 +86,7 @@ func (c *PoseidonChip) PartialRounds(state PoseidonState, roundCounter *int) Pos
for i := 0; i < N_PARTIAL_ROUNDS; i++ {
state[0] = c.SBoxMonomial(state[0])
state[0] = c.fieldAPI.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i]).(field.F)
state[0] = c.fieldAPI.Add(state[0], FAST_PARTIAL_ROUND_CONSTANTS[i])
state = c.MdsPartialLayerFast(state, i)
}
@@ -98,8 +98,8 @@ func (c *PoseidonChip) PartialRounds(state PoseidonState, roundCounter *int) Pos
func (c *PoseidonChip) ConstantLayer(state PoseidonState, roundCounter *int) PoseidonState {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
roundConstant := field.NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])
state[i] = c.fieldAPI.Add(state[i], roundConstant).(field.F)
roundConstant := ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]
state[i] = c.fieldAPI.Add(state[i], roundConstant)
}
}
return state
@@ -108,7 +108,7 @@ func (c *PoseidonChip) ConstantLayer(state PoseidonState, roundCounter *int) Pos
func (c *PoseidonChip) ConstantLayerExtension(state PoseidonStateExtension, roundCounter *int) PoseidonStateExtension {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
roundConstant := c.qeAPI.FieldToQE(field.NewFieldElement(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)]))
roundConstant := c.qeAPI.FieldToQE(ALL_ROUND_CONSTANTS[i+SPONGE_WIDTH*(*roundCounter)])
state[i] = c.qeAPI.AddExtension(state[i], roundConstant)
}
}
@@ -119,7 +119,7 @@ func (c *PoseidonChip) SBoxMonomial(x field.F) field.F {
x2 := c.fieldAPI.Mul(x, x)
x4 := c.fieldAPI.Mul(x2, x2)
x3 := c.fieldAPI.Mul(x, x2)
return c.fieldAPI.Mul(x3, x4).(field.F)
return c.fieldAPI.Mul(x3, x4)
}
func (c *PoseidonChip) SBoxMonomialExtension(x field.QuadraticExtension) field.QuadraticExtension {
@@ -148,31 +148,31 @@ func (c *PoseidonChip) SBoxLayerExtension(state PoseidonStateExtension) Poseidon
}
func (c *PoseidonChip) MdsRowShf(r int, v [SPONGE_WIDTH]frontend.Variable) frontend.Variable {
res := frontend.Variable(0)
res := ZERO_VAR
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
res1 := c.api.Mul(v[(i+r)%SPONGE_WIDTH], frontend.Variable(MDS_MATRIX_CIRC[i]))
res1 := c.api.Mul(v[(i+r)%SPONGE_WIDTH], MDS_MATRIX_CIRC_VARS[i])
res = c.api.Add(res, res1)
}
}
res = c.api.Add(res, c.api.Mul(v[r], MDS_MATRIX_DIAG[r]))
res = c.api.Add(res, c.api.Mul(v[r], MDS_MATRIX_DIAG_VARS[r]))
return res
}
func (c *PoseidonChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]field.QuadraticExtension) field.QuadraticExtension {
res := c.qeAPI.FieldToQE(field.NewFieldElement(0))
res := c.qeAPI.FieldToQE(field.ZERO_F)
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
matrixVal := c.qeAPI.FieldToQE(field.NewFieldElement(MDS_MATRIX_CIRC[i]))
matrixVal := c.qeAPI.FieldToQE(MDS_MATRIX_CIRC[i])
res1 := c.qeAPI.MulExtension(v[(i+r)%SPONGE_WIDTH], matrixVal)
res = c.qeAPI.AddExtension(res, res1)
}
}
matrixVal := c.qeAPI.FieldToQE(field.NewFieldElement(MDS_MATRIX_DIAG[r]))
matrixVal := c.qeAPI.FieldToQE(MDS_MATRIX_DIAG[r])
res = c.qeAPI.AddExtension(res, c.qeAPI.MulExtension(v[r], matrixVal))
return res
}
@@ -180,19 +180,21 @@ func (c *PoseidonChip) MdsRowShfExtension(r int, v [SPONGE_WIDTH]field.Quadratic
func (c *PoseidonChip) MdsLayer(state_ PoseidonState) PoseidonState {
var result PoseidonState
for i := 0; i < SPONGE_WIDTH; i++ {
result[i] = field.NewFieldElement(0)
result[i] = field.ZERO_F
}
var state [SPONGE_WIDTH]frontend.Variable
for i := 0; i < SPONGE_WIDTH; i++ {
state[i] = c.api.FromBinary(c.fieldAPI.ToBinary(state_[i])...)
reducedState := c.fieldAPI.Reduce(state_[i])
//state[i] = c.api.FromBinary(c.fieldAPI.ToBits(reducedState)...)
state[i] = reducedState.Limbs[0]
}
for r := 0; r < 12; r++ {
if r < SPONGE_WIDTH {
sum := c.MdsRowShf(r, state)
bits := c.api.ToBinary(sum)
result[r] = c.fieldAPI.FromBinary(bits).(field.F)
result[r] = c.fieldAPI.FromBits(bits...)
}
}
@@ -215,7 +217,7 @@ func (c *PoseidonChip) MdsLayerExtension(state_ PoseidonStateExtension) Poseidon
func (c *PoseidonChip) PartialFirstConstantLayer(state PoseidonState) PoseidonState {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
state[i] = c.fieldAPI.Add(state[i], field.NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])).(field.F)
state[i] = c.fieldAPI.Add(state[i], FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])
}
}
return state
@@ -224,7 +226,7 @@ func (c *PoseidonChip) PartialFirstConstantLayer(state PoseidonState) PoseidonSt
func (c *PoseidonChip) PartialFirstConstantLayerExtension(state PoseidonStateExtension) PoseidonStateExtension {
for i := 0; i < 12; i++ {
if i < SPONGE_WIDTH {
state[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.FieldToQE(field.NewFieldElement(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i])))
state[i] = c.qeAPI.AddExtension(state[i], c.qeAPI.FieldToQE(FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]))
}
}
return state
@@ -233,7 +235,7 @@ func (c *PoseidonChip) PartialFirstConstantLayerExtension(state PoseidonStateExt
func (c *PoseidonChip) MdsPartialLayerInit(state PoseidonState) PoseidonState {
var result PoseidonState
for i := 0; i < 12; i++ {
result[i] = field.NewFieldElement(0)
result[i] = field.ZERO_F
}
result[0] = state[0]
@@ -242,8 +244,8 @@ func (c *PoseidonChip) MdsPartialLayerInit(state PoseidonState) PoseidonState {
if r < SPONGE_WIDTH {
for d := 1; d < 12; d++ {
if d < SPONGE_WIDTH {
t := field.NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
result[d] = c.fieldAPI.Add(result[d], c.fieldAPI.Mul(state[r], t)).(field.F)
t := FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]
result[d] = c.fieldAPI.Add(result[d], c.fieldAPI.Mul(state[r], t))
}
}
}
@@ -255,7 +257,7 @@ func (c *PoseidonChip) MdsPartialLayerInit(state PoseidonState) PoseidonState {
func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension) PoseidonStateExtension {
var result PoseidonStateExtension
for i := 0; i < 12; i++ {
result[i] = c.qeAPI.FieldToQE(field.NewFieldElement(0))
result[i] = c.qeAPI.FieldToQE(field.ZERO_F)
}
result[0] = state[0]
@@ -264,7 +266,7 @@ func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension
if r < SPONGE_WIDTH {
for d := 1; d < 12; d++ {
if d < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(field.NewFieldElement(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1]))
t := c.qeAPI.FieldToQE(FAST_PARTIAL_ROUND_INITIAL_MATRIX[r-1][d-1])
result[d] = c.qeAPI.AddExtension(result[d], c.qeAPI.MulExtension(state[r], t))
}
}
@@ -275,31 +277,35 @@ func (c *PoseidonChip) MdsPartialLayerInitExtension(state PoseidonStateExtension
}
func (c *PoseidonChip) MdsPartialLayerFast(state PoseidonState, r int) PoseidonState {
dSum := frontend.Variable(0)
dSum := ZERO_VAR
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := frontend.Variable(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
si := c.api.FromBinary(c.fieldAPI.ToBinary(state[i])...)
t := FAST_PARTIAL_ROUND_W_HATS_VARS[r][i-1]
reducedState := c.fieldAPI.Reduce(state[i])
//si := c.api.FromBinary(c.fieldAPI.ToBits(reducedState)...)
si := reducedState.Limbs[0]
dSum = c.api.Add(dSum, c.api.Mul(si, t))
}
}
s0 := c.api.FromBinary(c.fieldAPI.ToBinary(state[0])...)
mds0to0 := frontend.Variable(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0])
dSum = c.api.Add(dSum, c.api.Mul(s0, mds0to0))
d := c.fieldAPI.FromBinary(c.api.ToBinary(dSum))
reducedState := c.fieldAPI.Reduce(state[0])
//s0 := c.api.FromBinary(c.fieldAPI.ToBits(reducedState)...)
s0 := reducedState.Limbs[0]
dSum = c.api.Add(dSum, c.api.Mul(s0, MDS0TO0_VAR))
d := c.fieldAPI.FromBits(c.api.ToBinary(dSum)...)
//d := c.fieldAPI.NewElement(dSum)
var result PoseidonState
for i := 0; i < SPONGE_WIDTH; i++ {
result[i] = field.NewFieldElement(0)
result[i] = field.ZERO_F
}
result[0] = d.(field.F)
result[0] = d
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := field.NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1])
result[i] = c.fieldAPI.Add(state[i], c.fieldAPI.Mul(state[0], t)).(field.F)
t := FAST_PARTIAL_ROUND_VS[r][i-1]
result[i] = c.fieldAPI.Add(state[i], c.fieldAPI.Mul(state[0], t))
}
}
@@ -308,11 +314,11 @@ func (c *PoseidonChip) MdsPartialLayerFast(state PoseidonState, r int) PoseidonS
func (c *PoseidonChip) MdsPartialLayerFastExtension(state PoseidonStateExtension, r int) PoseidonStateExtension {
s0 := state[0]
mds0to0 := c.qeAPI.FieldToQE(field.NewFieldElement(MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]))
mds0to0 := c.qeAPI.FieldToQE(MDS0TO0)
d := c.qeAPI.MulExtension(s0, mds0to0)
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(field.NewFieldElement(FAST_PARTIAL_ROUND_W_HATS[r][i-1]))
t := c.qeAPI.FieldToQE(FAST_PARTIAL_ROUND_W_HATS[r][i-1])
d = c.qeAPI.AddExtension(d, c.qeAPI.MulExtension(state[i], t))
}
}
@@ -321,7 +327,7 @@ func (c *PoseidonChip) MdsPartialLayerFastExtension(state PoseidonStateExtension
result[0] = d
for i := 1; i < 12; i++ {
if i < SPONGE_WIDTH {
t := c.qeAPI.FieldToQE(field.NewFieldElement(FAST_PARTIAL_ROUND_VS[r][i-1]))
t := c.qeAPI.FieldToQE(FAST_PARTIAL_ROUND_VS[r][i-1])
result[i] = c.qeAPI.AddExtension(c.qeAPI.MulExtension(state[0], t), state[i])
}
}

View File

@@ -18,11 +18,11 @@ type TestPoseidonCircuit struct {
func (circuit *TestPoseidonCircuit) Define(api frontend.API) error {
goldilocksApi := field.NewFieldAPI(api)
qeAPI := field.NewQuadraticExtensionAPI(goldilocksApi, 3)
qeAPI := field.NewQuadraticExtensionAPI(api, goldilocksApi, 3)
var input PoseidonState
for i := 0; i < 12; i++ {
input[i] = goldilocksApi.FromBinary(api.ToBinary(circuit.In[i], 64)).(field.F)
input[i] = goldilocksApi.FromBits(api.ToBinary(circuit.In[i], 64)...)
}
poseidonChip := NewPoseidonChip(api, goldilocksApi, qeAPI)
@@ -31,7 +31,7 @@ func (circuit *TestPoseidonCircuit) Define(api frontend.API) error {
for i := 0; i < 12; i++ {
goldilocksApi.AssertIsEqual(
output[i],
goldilocksApi.FromBinary(api.ToBinary(circuit.Out[i])).(field.F),
goldilocksApi.FromBits(api.ToBinary(circuit.Out[i])...),
)
}

View File

@@ -23,7 +23,7 @@ func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error {
// BN254 -> Binary(64) -> F
var input [3]field.F
for i := 0; i < 3; i++ {
input[i] = fieldAPI.FromBinary(api.ToBinary(circuit.In[i], 64)).(field.F)
input[i] = fieldAPI.FromBits(api.ToBinary(circuit.In[i], 64)...)
}
poseidonChip := &PoseidonChip{api: api, fieldAPI: fieldAPI}
@@ -33,7 +33,7 @@ func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error {
for i := 0; i < 4; i++ {
fieldAPI.AssertIsEqual(
output[i],
fieldAPI.FromBinary(api.ToBinary(circuit.Out[i])).(field.F),
fieldAPI.FromBits(api.ToBinary(circuit.Out[i])...),
)
}