diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..4076eeb --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,127 @@ +package utils + +import ( + "errors" + "math/big" +) + +var ( + // ErrRoundingLoss is used when converted big.Int to Float16 causes rounding loss + ErrRoundingLoss = errors.New("input value causes rounding loss") +) + +// Float16 represents a float in a 16 bit format +type Float16 uint16 + +// BigInt converts the Float16 to a big.Int integer +func (fl16 *Float16) BigInt() *big.Int { + + fl := int64(*fl16) + + m := big.NewInt(fl & 0x3FF) + e := big.NewInt(fl >> 11) + e5 := (fl >> 10) & 0x01 + + exp := big.NewInt(0).Exp(big.NewInt(10), e, nil) + res := m.Mul(m, exp) + + if e5 != 0 && e.Cmp(big.NewInt(0)) != 0 { + + res.Add(res, exp.Div(exp, big.NewInt(2))) + + } + + return res + +} + +// floorFix2Float converts a fix to a float, always rounding down +func floorFix2Float(_f *big.Int) Float16 { + + zero := big.NewInt(0) + ten := big.NewInt(10) + e := int64(0) + + m := big.NewInt(0) + m.Set(_f) + + if m.Cmp(zero) == 0 { + return 0 + } + + s := big.NewInt(0).Rsh(m, 10) + + for s.Cmp(zero) != 0 { + + m.Div(m, ten) + s.Rsh(m, 10) + e++ + + } + + return Float16(m.Int64() | e<<11) + +} + +// NewFloat16 encodes a big.Int integer as a Float16, returning error in case +// of loss during the encoding. +func NewFloat16(f *big.Int) (Float16, error) { + + fl1 := floorFix2Float(f) + fi1 := fl1.BigInt() + fl2 := fl1 | 0x400 + fi2 := fl2.BigInt() + + m3 := (fl1 & 0x3FF) + 1 + e3 := fl1 >> 11 + + if m3&0x400 == 0 { + m3 = 0x66 + e3++ + } + + fl3 := m3 + e3<<11 + fi3 := fl3.BigInt() + + res := fl1 + + d := big.NewInt(0).Abs(fi1.Sub(fi1, f)) + d2 := big.NewInt(0).Abs(fi2.Sub(fi2, f)) + + if d.Cmp(d2) == 1 { + res = fl2 + d = d2 + } + + d3 := big.NewInt(0).Abs(fi3.Sub(fi3, f)) + + if d.Cmp(d3) == 1 { + + res = fl3 + } + + // Do rounding check + + if res.BigInt().Cmp(f) == 0 { + + return res, nil + } + + return res, ErrRoundingLoss + +} + +// NewFloat16Floor encodes a big.Int integer as a Float16, rounding down in +// case of loss during the encoding. +func NewFloat16Floor(f *big.Int) Float16 { + + fl1 := floorFix2Float(f) + fl2 := fl1 | 0x400 + fi2 := fl2.BigInt() + + if fi2.Cmp(f) < 1 { + return fl2 + } + return fl1 + +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 0000000..ac05fba --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,140 @@ +package utils + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConversions(t *testing.T) { + + testVector := map[Float16]string{ + 0x307B: "123000000", + 0x1DC6: "454500", + 0xFFFF: "10235000000000000000000000000000000", + 0x0000: "0", + 0x0400: "0", + 0x0001: "1", + 0x0401: "1", + 0x0800: "0", + 0x0c00: "5", + 0x0801: "10", + 0x0c01: "15", + } + + for test := range testVector { + + fix := test.BigInt() + + assert.Equal(t, fix.String(), testVector[test]) + + bi := big.NewInt(0) + bi.SetString(testVector[test], 10) + + fl, err := NewFloat16(bi) + assert.Equal(t, nil, err) + + fx2 := fl.BigInt() + assert.Equal(t, fx2.String(), testVector[test]) + + } + +} + +func TestFloorFix2Float(t *testing.T) { + + testVector := map[string]Float16{ + "87999990000000000": 0x776f, + "87950000000000001": 0x776f, + "87950000000000000": 0x776f, + "87949999999999999": 0x736f, + } + + for test := range testVector { + + bi := big.NewInt(0) + bi.SetString(test, 10) + + testFloat := NewFloat16Floor(bi) + + assert.Equal(t, testFloat, testVector[test]) + + } + +} + +func TestConversionLosses(t *testing.T) { + a := big.NewInt(1000) + b, err := NewFloat16(a) + assert.Equal(t, nil, err) + c := b.BigInt() + assert.Equal(t, c, a) + + a = big.NewInt(1024) + b, err = NewFloat16(a) + assert.Equal(t, ErrRoundingLoss, err) + c = b.BigInt() + assert.NotEqual(t, c, a) + + a = big.NewInt(32767) + b, err = NewFloat16(a) + assert.Equal(t, ErrRoundingLoss, err) + c = b.BigInt() + assert.NotEqual(t, c, a) + + a = big.NewInt(32768) + b, err = NewFloat16(a) + assert.Equal(t, ErrRoundingLoss, err) + c = b.BigInt() + assert.NotEqual(t, c, a) + + a = big.NewInt(65536000) + b, err = NewFloat16(a) + assert.Equal(t, ErrRoundingLoss, err) + c = b.BigInt() + assert.NotEqual(t, c, a) + +} + +func BenchmarkFloat16(b *testing.B) { + newBigInt := func(s string) *big.Int { + bigInt, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("Bad big int") + } + return bigInt + } + type pair struct { + Float16 Float16 + BigInt *big.Int + } + testVector := []pair{ + pair{0x307B, newBigInt("123000000")}, + pair{0x1DC6, newBigInt("454500")}, + pair{0xFFFF, newBigInt("10235000000000000000000000000000000")}, + pair{0x0000, newBigInt("0")}, + pair{0x0400, newBigInt("0")}, + pair{0x0001, newBigInt("1")}, + pair{0x0401, newBigInt("1")}, + pair{0x0800, newBigInt("0")}, + pair{0x0c00, newBigInt("5")}, + pair{0x0801, newBigInt("10")}, + pair{0x0c01, newBigInt("15")}, + } + b.Run("floorFix2Float()", func(b *testing.B) { + for i := 0; i < b.N; i++ { + NewFloat16Floor(testVector[i%len(testVector)].BigInt) + } + }) + b.Run("NewFloat16()", func(b *testing.B) { + for i := 0; i < b.N; i++ { + NewFloat16(testVector[i%len(testVector)].BigInt) //nolint:errcheck + } + }) + b.Run("Float16.BigInt()", func(b *testing.B) { + for i := 0; i < b.N; i++ { + testVector[i%len(testVector)].Float16.BigInt() + } + }) +}