You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
2.2 KiB

  1. package utils
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "math/big"
  6. )
  // ErrRoundingLoss is the error NewFloat16 returns when the *big.Int being
// encoded has no exact Float16 representation.
var ErrRoundingLoss = errors.New("input value causes rounding loss")
  // Float16 is a 16-bit decimal floating-point encoding of a non-negative
// integer. Bit layout, from most to least significant:
//
//	bits 15-11: exponent e (0-31)
//	bit  10:    half flag h
//	bits 9-0:   mantissa m (0-1023)
//
// The encoded value is m*10^e, plus 10^e/2 when h is set and e > 0 (the half
// flag has no effect when e == 0).
type Float16 uint16

// Bytes returns the Float16 as a 2-byte slice in little-endian byte order.
func (f16 Float16) Bytes() []byte {
	var b [2]byte
	binary.LittleEndian.PutUint16(b[:], uint16(f16))
	return b[:]
}

// BigInt decodes the Float16 into a newly allocated *big.Int.
//
// It uses a value receiver with the same name as Bytes, so the method set is
// consistent. It remains callable on both Float16 and *Float16 values.
func (f16 Float16) BigInt() *big.Int {
	fl := int64(f16)
	mantissa := fl & 0x3FF
	exponent := fl >> 11
	half := (fl>>10)&1 == 1

	pow := new(big.Int).Exp(big.NewInt(10), big.NewInt(exponent), nil)
	res := new(big.Int).Mul(big.NewInt(mantissa), pow)
	if half && exponent != 0 {
		// 10^e / 2 == 5*10^(e-1), which is exact for e >= 1.
		res.Add(res, pow.Quo(pow, big.NewInt(2)))
	}
	return res
}
  32. // floorFix2Float converts a fix to a float, always rounding down
  33. func floorFix2Float(_f *big.Int) Float16 {
  34. zero := big.NewInt(0)
  35. ten := big.NewInt(10)
  36. e := int64(0)
  37. m := big.NewInt(0)
  38. m.Set(_f)
  39. if m.Cmp(zero) == 0 {
  40. return 0
  41. }
  42. s := big.NewInt(0).Rsh(m, 10)
  43. for s.Cmp(zero) != 0 {
  44. m.Div(m, ten)
  45. s.Rsh(m, 10)
  46. e++
  47. }
  48. return Float16(m.Int64() | e<<11)
  49. }
  50. // NewFloat16 encodes a *big.Int integer as a Float16, returning error in case
  51. // of loss during the encoding.
  52. func NewFloat16(f *big.Int) (Float16, error) {
  53. fl1 := floorFix2Float(f)
  54. fi1 := fl1.BigInt()
  55. fl2 := fl1 | 0x400
  56. fi2 := fl2.BigInt()
  57. m3 := (fl1 & 0x3FF) + 1
  58. e3 := fl1 >> 11
  59. if m3&0x400 == 0 {
  60. m3 = 0x66
  61. e3++
  62. }
  63. fl3 := m3 + e3<<11
  64. fi3 := fl3.BigInt()
  65. res := fl1
  66. d := big.NewInt(0).Abs(fi1.Sub(fi1, f))
  67. d2 := big.NewInt(0).Abs(fi2.Sub(fi2, f))
  68. if d.Cmp(d2) == 1 {
  69. res = fl2
  70. d = d2
  71. }
  72. d3 := big.NewInt(0).Abs(fi3.Sub(fi3, f))
  73. if d.Cmp(d3) == 1 {
  74. res = fl3
  75. }
  76. // Do rounding check
  77. if res.BigInt().Cmp(f) == 0 {
  78. return res, nil
  79. }
  80. return res, ErrRoundingLoss
  81. }
  82. // NewFloat16Floor encodes a big.Int integer as a Float16, rounding down in
  83. // case of loss during the encoding.
  84. func NewFloat16Floor(f *big.Int) Float16 {
  85. fl1 := floorFix2Float(f)
  86. fl2 := fl1 | 0x400
  87. fi2 := fl2.BigInt()
  88. if fi2.Cmp(f) < 1 {
  89. return fl2
  90. }
  91. return fl1
  92. }