You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grampus.go 4.3 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package grampus
  2. import (
  3. "code.gitea.io/gitea/models"
  4. "code.gitea.io/gitea/modules/context"
  5. "code.gitea.io/gitea/modules/log"
  6. "code.gitea.io/gitea/modules/notification"
  7. "code.gitea.io/gitea/modules/timeutil"
  8. "strings"
  9. )
  const (
	// Notebook defaults.
	storageTypeOBS = "obs"
	// autoStopDuration is 4 hours; presumably measured in seconds, given the
	// millisecond counterpart below — confirm against its users.
	autoStopDuration   = 4 * 60 * 60
	autoStopDurationMs = autoStopDuration * 1000
	DataSetMountPath   = "/home/ma-user/work"
	NotebookEnv        = "Python3"
	NotebookType       = "Ascend"
	FlavorInfo         = "Ascend: 1*Ascend 910 CPU: 24 核 96GiB (modelarts.kat1.xlarge)"

	// Path segments, each wrapped in leading and trailing slashes.
	CodePath   = "/code/"
	OutputPath = "/output/"
	ResultPath = "/result/"
	LogPath    = "/log/"
	JobPath    = "/job/"

	// Query direction and line count.
	OrderDesc = "desc" // query downward
	OrderAsc  = "asc"  // query upward
	Lines     = 500

	// Named keys (presumably job parameter names — confirm at call sites)
	// plus the Ascend device identifier.
	TrainUrl     = "train_url"
	DataUrl      = "data_url"
	ResultUrl    = "result_url"
	CkptUrl      = "ckpt_url"
	DeviceTarget = "device_target"
	Ascend       = "Ascend"

	// Paging, versioning and configuration markers.
	PerPage           = 10
	IsLatestVersion   = "1"
	NotLatestVersion  = "0"
	VersionCount      = 1
	SortByCreateTime  = "create_time"
	ConfigTypeCustom  = "custom"
	TotalVersionCount = 1
)
  41. var (
  42. poolInfos *models.PoolInfos
  43. FlavorInfos *models.FlavorInfos
  44. ImageInfos *models.ImageInfosModelArts
  45. )
  // GenerateTrainJobReq bundles the inputs consumed by GenerateTrainJob.
//
// Field order and types are significant for positional composite literals
// and must not be rearranged.
type GenerateTrainJobReq struct {
	// Copied onto the single Grampus task; JobName also names the job itself.
	JobName, Command, ResourceSpecId, ImageUrl, ImageId string

	// Recorded on the Cloudbrain row unless marked otherwise.
	DisplayJobName, Uuid, Description string
	CodeObsPath                       string // not read by GenerateTrainJob
	BootFile                          string
	BootFileUrl                       string // not read by GenerateTrainJob
	DataUrl, TrainUrl                 string
	WorkServerNumber                  int
	EngineID                          int64 // not read by GenerateTrainJob

	CommitID, IsLatestVersion, BranchName string
	PreVersionId                          int64  // not read by GenerateTrainJob
	PreVersionName                        string // not read by GenerateTrainJob
	FlavorName                            string
	VersionCount                          int
	EngineName                            string
	TotalVersionCount                     int

	// ComputeResource also selects the notification action type.
	ComputeResource, DatasetName string
}
  74. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  75. createTime := timeutil.TimeStampNow()
  76. jobResult, err := createJob(models.CreateGrampusJobRequest{
  77. Name: req.JobName,
  78. Tasks: []models.GrampusTasks{
  79. {
  80. Name: req.JobName,
  81. Command: req.Command,
  82. ResourceSpecId: req.ResourceSpecId,
  83. ImageId: req.ImageId,
  84. ImageUrl: req.ImageUrl,
  85. },
  86. },
  87. })
  88. if err != nil {
  89. log.Error("createJob failed: %v", err.Error())
  90. return err
  91. }
  92. jobID := jobResult.JobInfo.JobID
  93. err = models.CreateCloudbrain(&models.Cloudbrain{
  94. Status: TransTrainJobStatus(jobResult.JobInfo.Status),
  95. UserID: ctx.User.ID,
  96. RepoID: ctx.Repo.Repository.ID,
  97. JobID: jobID,
  98. JobName: req.JobName,
  99. DisplayJobName: req.DisplayJobName,
  100. JobType: string(models.JobTypeTrain),
  101. Type: models.TypeCloudBrainGrampus,
  102. //VersionID: jobResult.VersionID,
  103. //VersionName: jobResult.VersionName,
  104. Uuid: req.Uuid,
  105. DatasetName: req.DatasetName,
  106. CommitID: req.CommitID,
  107. IsLatestVersion: req.IsLatestVersion,
  108. ComputeResource: req.ComputeResource,
  109. //EngineID: req.EngineID,
  110. TrainUrl: req.TrainUrl,
  111. BranchName: req.BranchName,
  112. //Parameters: req.Params,
  113. BootFile: req.BootFile,
  114. DataUrl: req.DataUrl,
  115. //LogUrl: req.LogUrl,
  116. //FlavorCode: req.FlavorCode,
  117. Description: req.Description,
  118. WorkServerNumber: req.WorkServerNumber,
  119. FlavorName: req.FlavorName,
  120. EngineName: req.EngineName,
  121. VersionCount: req.VersionCount,
  122. TotalVersionCount: req.TotalVersionCount,
  123. CreatedUnix: createTime,
  124. UpdatedUnix: createTime,
  125. })
  126. if err != nil {
  127. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, err.Error())
  128. return err
  129. }
  130. var actionType models.ActionType
  131. if req.ComputeResource == models.NPUResource {
  132. actionType = models.ActionCreateTrainTask
  133. } else if req.ComputeResource == models.GPUResource {
  134. actionType = models.ActionCreateGPUTrainTask
  135. }
  136. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, actionType)
  137. return nil
  138. }
  // TransTrainJobStatus converts a job status reported by Grampus into the
// upper-case form stored on Cloudbrain rows.
//
// Only the exact lower-case string "pending" is renamed (to "WAITING").
// Every other value, including differently-cased variants such as "Pending",
// is simply upper-cased.
func TransTrainJobStatus(status string) string {
	if status != "pending" {
		return strings.ToUpper(status)
	}
	return strings.ToUpper("waiting")
}