You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
3.1 KiB

  1. // Copyright 2017 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build !plan9
  5. package main
  6. import (
  7. "bufio"
  8. "context"
  9. "fmt"
  10. "os"
  11. "os/user"
  12. "path/filepath"
  13. "runtime"
  14. "strings"
  15. )
  // Names of the per-user shell startup files, relative to the home
  // directory, that shellConfigFile chooses between.
  const (
  	bashConfig = ".bash_profile" // used when the current shell is bash
  	zshConfig  = ".zshrc"        // used when the current shell is zsh
  )
  20. // appendToPATH adds the given path to the PATH environment variable and
  21. // persists it for future sessions.
  22. func appendToPATH(value string) error {
  23. if isInPATH(value) {
  24. return nil
  25. }
  26. return persistEnvVar("PATH", pathVar+envSeparator+value)
  27. }
  28. func isInPATH(dir string) bool {
  29. p := os.Getenv("PATH")
  30. paths := strings.Split(p, envSeparator)
  31. for _, d := range paths {
  32. if d == dir {
  33. return true
  34. }
  35. }
  36. return false
  37. }
  38. func getHomeDir() (string, error) {
  39. home := os.Getenv(homeKey)
  40. if home != "" {
  41. return home, nil
  42. }
  43. u, err := user.Current()
  44. if err != nil {
  45. return "", err
  46. }
  47. return u.HomeDir, nil
  48. }
  // checkStringExistsFile reports whether filename contains a line exactly
  // equal to value. bufio.ScanLines strips a trailing "\r", so CRLF files
  // compare the same as LF files. A file that does not exist is treated as
  // not containing the line and is not an error. Other open or read errors
  // are returned.
  func checkStringExistsFile(filename, value string) (bool, error) {
  	file, err := os.Open(filename)
  	if err != nil {
  		if os.IsNotExist(err) {
  			return false, nil
  		}
  		return false, err
  	}
  	defer file.Close()

  	// Shell startup files can hold very long lines. With bufio.Scanner's
  	// default 64 KiB token limit, Scan would stop with bufio.ErrTooLong and
  	// the whole check would fail, so allow much longer lines.
  	const maxLineSize = 16 << 20
  	scanner := bufio.NewScanner(file)
  	scanner.Buffer(make([]byte, 0, 4096), maxLineSize)
  	for scanner.Scan() {
  		// string(b) == s in a comparison does not allocate.
  		if string(scanner.Bytes()) == value {
  			return true, nil
  		}
  	}
  	return false, scanner.Err()
  }
  67. func appendToFile(filename, value string) error {
  68. verbosef("Adding %q to %s", value, filename)
  69. ok, err := checkStringExistsFile(filename, value)
  70. if err != nil {
  71. return err
  72. }
  73. if ok {
  74. // Nothing to do.
  75. return nil
  76. }
  77. f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
  78. if err != nil {
  79. return err
  80. }
  81. defer f.Close()
  82. _, err = f.WriteString(lineEnding + value + lineEnding)
  83. return err
  84. }
  85. func isShell(name string) bool {
  86. return strings.Contains(currentShell(), name)
  87. }
  88. // persistEnvVarWindows sets an environment variable in the Windows
  89. // registry.
  90. func persistEnvVarWindows(name, value string) error {
  91. _, err := runCommand(context.Background(), "powershell", "-command",
  92. fmt.Sprintf(`[Environment]::SetEnvironmentVariable("%s", "%s", "User")`, name, value))
  93. return err
  94. }
  95. func persistEnvVar(name, value string) error {
  96. if runtime.GOOS == "windows" {
  97. if err := persistEnvVarWindows(name, value); err != nil {
  98. return err
  99. }
  100. if isShell("cmd.exe") || isShell("powershell.exe") {
  101. return os.Setenv(strings.ToUpper(name), value)
  102. }
  103. // User is in bash, zsh, etc.
  104. // Also set the environment variable in their shell config.
  105. }
  106. rc, err := shellConfigFile()
  107. if err != nil {
  108. return err
  109. }
  110. line := fmt.Sprintf("export %s=%s", strings.ToUpper(name), value)
  111. if err := appendToFile(rc, line); err != nil {
  112. return err
  113. }
  114. return os.Setenv(strings.ToUpper(name), value)
  115. }
  116. func shellConfigFile() (string, error) {
  117. home, err := getHomeDir()
  118. if err != nil {
  119. return "", err
  120. }
  121. switch {
  122. case isShell("bash"):
  123. return filepath.Join(home, bashConfig), nil
  124. case isShell("zsh"):
  125. return filepath.Join(home, zshConfig), nil
  126. default:
  127. return "", fmt.Errorf("%q is not a supported shell", currentShell())
  128. }
  129. }