You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

119 lines
2.2 KiB

  1. //nolint:gomnd
  2. package utils
  3. import (
  4. "encoding/binary"
  5. "errors"
  6. "math/big"
  7. )
  8. var (
  9. // ErrRoundingLoss is used when converted big.Int to Float16 causes rounding loss
  10. ErrRoundingLoss = errors.New("input value causes rounding loss")
  11. )
  12. // Float16 represents a float in a 16 bit format
  13. type Float16 uint16
  14. // Bytes return a byte array of length 2 with the Float16 value encoded in LittleEndian
  15. func (f16 Float16) Bytes() []byte {
  16. var b [2]byte
  17. binary.LittleEndian.PutUint16(b[:], uint16(f16))
  18. return b[:]
  19. }
  20. // BigInt converts the Float16 to a *big.Int integer
  21. func (fl16 *Float16) BigInt() *big.Int {
  22. fl := int64(*fl16)
  23. m := big.NewInt(fl & 0x3FF)
  24. e := big.NewInt(fl >> 11)
  25. e5 := (fl >> 10) & 0x01
  26. exp := big.NewInt(0).Exp(big.NewInt(10), e, nil)
  27. res := m.Mul(m, exp)
  28. if e5 != 0 && e.Cmp(big.NewInt(0)) != 0 {
  29. res.Add(res, exp.Div(exp, big.NewInt(2)))
  30. }
  31. return res
  32. }
  33. // floorFix2Float converts a fix to a float, always rounding down
  34. func floorFix2Float(_f *big.Int) Float16 {
  35. zero := big.NewInt(0)
  36. ten := big.NewInt(10)
  37. e := int64(0)
  38. m := big.NewInt(0)
  39. m.Set(_f)
  40. if m.Cmp(zero) == 0 {
  41. return 0
  42. }
  43. s := big.NewInt(0).Rsh(m, 10)
  44. for s.Cmp(zero) != 0 {
  45. m.Div(m, ten)
  46. s.Rsh(m, 10)
  47. e++
  48. }
  49. return Float16(m.Int64() | e<<11)
  50. }
  51. // NewFloat16 encodes a *big.Int integer as a Float16, returning error in case
  52. // of loss during the encoding.
  53. func NewFloat16(f *big.Int) (Float16, error) {
  54. fl1 := floorFix2Float(f)
  55. fi1 := fl1.BigInt()
  56. fl2 := fl1 | 0x400
  57. fi2 := fl2.BigInt()
  58. m3 := (fl1 & 0x3FF) + 1
  59. e3 := fl1 >> 11
  60. if m3&0x400 == 0 {
  61. m3 = 0x66
  62. e3++
  63. }
  64. fl3 := m3 + e3<<11
  65. fi3 := fl3.BigInt()
  66. res := fl1
  67. d := big.NewInt(0).Abs(fi1.Sub(fi1, f))
  68. d2 := big.NewInt(0).Abs(fi2.Sub(fi2, f))
  69. if d.Cmp(d2) == 1 {
  70. res = fl2
  71. d = d2
  72. }
  73. d3 := big.NewInt(0).Abs(fi3.Sub(fi3, f))
  74. if d.Cmp(d3) == 1 {
  75. res = fl3
  76. }
  77. // Do rounding check
  78. if res.BigInt().Cmp(f) == 0 {
  79. return res, nil
  80. }
  81. return res, ErrRoundingLoss
  82. }
  83. // NewFloat16Floor encodes a big.Int integer as a Float16, rounding down in
  84. // case of loss during the encoding.
  85. func NewFloat16Floor(f *big.Int) Float16 {
  86. fl1 := floorFix2Float(f)
  87. fl2 := fl1 | 0x400
  88. fi2 := fl2.BigInt()
  89. if fi2.Cmp(f) < 1 {
  90. return fl2
  91. }
  92. return fl1
  93. }