You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.go 40 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. package cloudbrainTask
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net/http"
  9. "os"
  10. "path"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "code.gitea.io/gitea/modules/urfs_client/urchin"
  15. "code.gitea.io/gitea/modules/timeutil"
  16. "code.gitea.io/gitea/modules/notification"
  17. "code.gitea.io/gitea/modules/obs"
  18. "code.gitea.io/gitea/modules/git"
  19. "code.gitea.io/gitea/modules/storage"
  20. "github.com/unknwon/com"
  21. "code.gitea.io/gitea/models"
  22. "code.gitea.io/gitea/modules/cloudbrain"
  23. "code.gitea.io/gitea/modules/context"
  24. "code.gitea.io/gitea/modules/grampus"
  25. "code.gitea.io/gitea/modules/log"
  26. "code.gitea.io/gitea/modules/modelarts"
  27. "code.gitea.io/gitea/modules/redis/redis_key"
  28. "code.gitea.io/gitea/modules/redis/redis_lock"
  29. "code.gitea.io/gitea/modules/setting"
  30. api "code.gitea.io/gitea/modules/structs"
  31. "code.gitea.io/gitea/modules/util"
  32. "code.gitea.io/gitea/services/cloudbrain/resource"
  33. "code.gitea.io/gitea/services/reward/point/account"
  34. )
  35. var jobNamePattern = regexp.MustCompile(`^[a-z0-9][a-z0-9-_]{1,34}[a-z0-9-]$`)
  36. const TaskTypeCloudbrainOne = 0
  37. const TaskTypeModelArts = 1
  38. const TaskTypeGrampusGPU = 2
  39. const TaskTypeGrampusNPU = 3
  40. func CloudbrainOneTrainJobCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  41. displayJobName := option.DisplayJobName
  42. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  43. image := strings.TrimSpace(option.Image)
  44. uuids := option.Attachment
  45. jobType := string(models.JobTypeTrain)
  46. codePath := setting.JobPath + jobName + cloudbrain.CodeMountPath
  47. branchName := option.BranchName
  48. repo := ctx.Repo.Repository
  49. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), jobType, displayJobName))
  50. defer lock.UnLock()
  51. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  52. if err != nil {
  53. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  54. return
  55. }
  56. command, err := getTrainJobCommand(option)
  57. if err != nil {
  58. log.Error("getTrainJobCommand failed: %v", err)
  59. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  60. return
  61. }
  62. errStr := loadCodeAndMakeModelPath(repo, codePath, branchName, jobName, cloudbrain.ModelMountPath)
  63. if errStr != "" {
  64. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  65. return
  66. }
  67. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  68. req := cloudbrain.GenerateCloudBrainTaskReq{
  69. Ctx: ctx,
  70. DisplayJobName: displayJobName,
  71. JobName: jobName,
  72. Image: image,
  73. Command: command,
  74. Uuids: uuids,
  75. DatasetNames: datasetNames,
  76. DatasetInfos: datasetInfos,
  77. CodePath: storage.GetMinioPath(jobName, cloudbrain.CodeMountPath+"/"),
  78. ModelPath: storage.GetMinioPath(jobName, cloudbrain.ModelMountPath+"/"),
  79. BenchmarkPath: storage.GetMinioPath(jobName, cloudbrain.BenchMarkMountPath+"/"),
  80. Snn4ImageNetPath: storage.GetMinioPath(jobName, cloudbrain.Snn4imagenetMountPath+"/"),
  81. BrainScorePath: storage.GetMinioPath(jobName, cloudbrain.BrainScoreMountPath+"/"),
  82. JobType: jobType,
  83. Description: option.Description,
  84. BranchName: branchName,
  85. BootFile: option.BootFile,
  86. Params: option.Params,
  87. CommitID: commitID,
  88. BenchmarkTypeID: 0,
  89. BenchmarkChildTypeID: 0,
  90. ResultPath: storage.GetMinioPath(jobName, cloudbrain.ResultPath+"/"),
  91. Spec: spec,
  92. }
  93. if option.ModelName != "" { //使用预训练模型训练
  94. req.ModelName = option.ModelName
  95. req.LabelName = option.LabelName
  96. req.CkptName = option.CkptName
  97. req.ModelVersion = option.ModelVersion
  98. req.PreTrainModelPath = setting.Attachment.Minio.RealPath + option.PreTrainModelUrl
  99. req.PreTrainModelUrl = option.PreTrainModelUrl
  100. }
  101. jobId, err := cloudbrain.GenerateTask(req)
  102. if err != nil {
  103. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  104. return
  105. }
  106. ctx.JSON(http.StatusOK, models.BaseMessageApi{
  107. Code: 0,
  108. Message: jobId,
  109. })
  110. }
  111. func ModelArtsTrainJobNpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  112. VersionOutputPath := modelarts.GetOutputPathByCount(modelarts.TotalVersionCount)
  113. displayJobName := option.DisplayJobName
  114. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  115. uuid := option.Attachment
  116. description := option.Description
  117. workServerNumber := option.WorkServerNumber
  118. engineID, _ := strconv.Atoi(option.ImageID)
  119. bootFile := strings.TrimSpace(option.BootFile)
  120. params := option.Params
  121. repo := ctx.Repo.Repository
  122. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  123. codeObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.CodePath + VersionOutputPath + "/"
  124. outputObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.OutputPath + VersionOutputPath + "/"
  125. logObsPath := "/" + setting.Bucket + modelarts.JobPath + jobName + modelarts.LogPath + VersionOutputPath + "/"
  126. branchName := option.BranchName
  127. isLatestVersion := modelarts.IsLatestVersion
  128. VersionCount := modelarts.VersionCountOne
  129. EngineName := option.Image
  130. errStr := checkMultiNode(ctx.User.ID, option.WorkServerNumber)
  131. if errStr != "" {
  132. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr(errStr)))
  133. return
  134. }
  135. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  136. defer lock.UnLock()
  137. spec, _, _, err := checkParameters(ctx, option, lock, repo)
  138. if err != nil {
  139. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  140. return
  141. }
  142. //todo: del the codeLocalPath
  143. _, err = ioutil.ReadDir(codeLocalPath)
  144. if err == nil {
  145. os.RemoveAll(codeLocalPath)
  146. }
  147. gitRepo, _ := git.OpenRepository(repo.RepoPath())
  148. commitID, _ := gitRepo.GetBranchCommitID(branchName)
  149. if err := downloadCode(repo, codeLocalPath, branchName); err != nil {
  150. log.Error("downloadCode failed, server timed out: %s (%v)", repo.FullName(), err)
  151. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  152. return
  153. }
  154. //todo: upload code (send to file_server todo this work?)
  155. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath + VersionOutputPath + "/"); err != nil {
  156. log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err)
  157. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_output"))
  158. return
  159. }
  160. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.LogPath + VersionOutputPath + "/"); err != nil {
  161. log.Error("Failed to obsMkdir_log: %s (%v)", repo.FullName(), err)
  162. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to obsMkdir_log"))
  163. return
  164. }
  165. parentDir := VersionOutputPath + "/"
  166. if err := uploadCodeToObs(codeLocalPath, jobName, parentDir); err != nil {
  167. // if err := uploadCodeToObs(codeLocalPath, jobName, parentDir); err != nil {
  168. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  169. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  170. return
  171. }
  172. var parameters models.Parameters
  173. param := make([]models.Parameter, 0)
  174. existDeviceTarget := false
  175. if len(params) != 0 {
  176. err := json.Unmarshal([]byte(params), &parameters)
  177. if err != nil {
  178. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  179. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("运行参数错误"))
  180. return
  181. }
  182. for _, parameter := range parameters.Parameter {
  183. if parameter.Label == modelarts.DeviceTarget {
  184. existDeviceTarget = true
  185. }
  186. if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
  187. param = append(param, models.Parameter{
  188. Label: parameter.Label,
  189. Value: parameter.Value,
  190. })
  191. }
  192. }
  193. }
  194. if !existDeviceTarget {
  195. param = append(param, models.Parameter{
  196. Label: modelarts.DeviceTarget,
  197. Value: modelarts.Ascend,
  198. })
  199. }
  200. datasUrlList, dataUrl, datasetNames, isMultiDataset, err := getDatasUrlListByUUIDS(uuid)
  201. if err != nil {
  202. log.Error("Failed to getDatasUrlListByUUIDS: %v", err)
  203. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Failed to getDatasUrlListByUUIDS:"+err.Error()))
  204. return
  205. }
  206. dataPath := dataUrl
  207. jsondatas, err := json.Marshal(datasUrlList)
  208. if err != nil {
  209. log.Error("Failed to Marshal: %v", err)
  210. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("json error:"+err.Error()))
  211. return
  212. }
  213. if isMultiDataset {
  214. param = append(param, models.Parameter{
  215. Label: modelarts.MultiDataUrl,
  216. Value: string(jsondatas),
  217. })
  218. }
  219. if option.ModelName != "" { //使用预训练模型训练
  220. ckptUrl := "/" + option.PreTrainModelUrl + option.CkptName
  221. param = append(param, models.Parameter{
  222. Label: modelarts.CkptUrl,
  223. Value: "s3:/" + ckptUrl,
  224. })
  225. }
  226. req := &modelarts.GenerateTrainJobReq{
  227. JobName: jobName,
  228. DisplayJobName: displayJobName,
  229. DataUrl: dataPath,
  230. Description: description,
  231. CodeObsPath: codeObsPath,
  232. BootFileUrl: codeObsPath + bootFile,
  233. BootFile: bootFile,
  234. TrainUrl: outputObsPath,
  235. WorkServerNumber: workServerNumber,
  236. EngineID: int64(engineID),
  237. LogUrl: logObsPath,
  238. PoolID: getPoolId(),
  239. Uuid: uuid,
  240. Parameters: param,
  241. CommitID: commitID,
  242. IsLatestVersion: isLatestVersion,
  243. BranchName: branchName,
  244. Params: option.Params,
  245. EngineName: EngineName,
  246. VersionCount: VersionCount,
  247. TotalVersionCount: modelarts.TotalVersionCount,
  248. DatasetName: datasetNames,
  249. Spec: spec,
  250. }
  251. if option.ModelName != "" { //使用预训练模型训练
  252. req.ModelName = option.ModelName
  253. req.LabelName = option.LabelName
  254. req.CkptName = option.CkptName
  255. req.ModelVersion = option.ModelVersion
  256. req.PreTrainModelUrl = option.PreTrainModelUrl
  257. }
  258. userCommand, userImageUrl := getUserCommand(engineID, req)
  259. req.UserCommand = userCommand
  260. req.UserImageUrl = userImageUrl
  261. //将params转换Parameters.Parameter,出错时返回给前端
  262. var Parameters modelarts.Parameters
  263. if err := json.Unmarshal([]byte(params), &Parameters); err != nil {
  264. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("json.Unmarshal failed:"+err.Error()))
  265. return
  266. }
  267. jobId, err := modelarts.GenerateTrainJob(ctx, req)
  268. if err != nil {
  269. log.Error("GenerateTrainJob failed:%v", err.Error())
  270. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  271. return
  272. }
  273. ctx.JSON(http.StatusOK, models.BaseMessageApi{
  274. Code: 0,
  275. Message: jobId,
  276. })
  277. }
  278. func GrampusTrainJobGpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  279. displayJobName := option.DisplayJobName
  280. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  281. uuid := option.Attachment
  282. description := option.Description
  283. bootFile := strings.TrimSpace(option.BootFile)
  284. params := option.Params
  285. repo := ctx.Repo.Repository
  286. codeLocalPath := setting.JobPath + jobName + cloudbrain.CodeMountPath + "/"
  287. codeMinioPath := setting.CBCodePathPrefix + jobName + cloudbrain.CodeMountPath + "/"
  288. branchName := option.BranchName
  289. image := strings.TrimSpace(option.Image)
  290. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  291. defer lock.UnLock()
  292. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  293. if err != nil {
  294. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  295. return
  296. }
  297. //prepare code and out path
  298. _, err = ioutil.ReadDir(codeLocalPath)
  299. if err == nil {
  300. os.RemoveAll(codeLocalPath)
  301. }
  302. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  303. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  304. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  305. }
  306. //todo: upload code (send to file_server todo this work?)
  307. //upload code
  308. if err := uploadCodeToMinio(codeLocalPath+"/", jobName, cloudbrain.CodeMountPath+"/"); err != nil {
  309. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  310. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  311. return
  312. }
  313. modelPath := setting.JobPath + jobName + cloudbrain.ModelMountPath + "/"
  314. if err := mkModelPath(modelPath); err != nil {
  315. log.Error("Failed to mkModelPath: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  316. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  317. return
  318. }
  319. //init model readme
  320. if err := uploadCodeToMinio(modelPath, jobName, cloudbrain.ModelMountPath+"/"); err != nil {
  321. log.Error("Failed to uploadCodeToMinio: %s (%v)", repo.FullName(), err, ctx.Data["MsgID"])
  322. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  323. return
  324. }
  325. var datasetRemotePath, allFileName string
  326. for _, datasetInfo := range datasetInfos {
  327. if datasetRemotePath == "" {
  328. datasetRemotePath = datasetInfo.DataLocalPath
  329. allFileName = datasetInfo.FullName
  330. } else {
  331. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath
  332. allFileName = allFileName + ";" + datasetInfo.FullName
  333. }
  334. }
  335. //prepare command
  336. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  337. command, err := generateCommand(repo.Name, grampus.ProcessorTypeGPU, codeMinioPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CBCodePathPrefix+jobName+cloudbrain.ModelMountPath+"/", allFileName, preTrainModelPath, option.CkptName, "")
  338. if err != nil {
  339. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  340. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  341. return
  342. }
  343. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  344. req := &grampus.GenerateTrainJobReq{
  345. JobName: jobName,
  346. DisplayJobName: displayJobName,
  347. ComputeResource: models.GPUResource,
  348. ProcessType: grampus.ProcessorTypeGPU,
  349. Command: command,
  350. ImageUrl: image,
  351. Description: description,
  352. BootFile: bootFile,
  353. Uuid: uuid,
  354. CommitID: commitID,
  355. BranchName: branchName,
  356. Params: option.Params,
  357. EngineName: image,
  358. DatasetNames: datasetNames,
  359. DatasetInfos: datasetInfos,
  360. IsLatestVersion: modelarts.IsLatestVersion,
  361. VersionCount: modelarts.VersionCountOne,
  362. WorkServerNumber: 1,
  363. Spec: spec,
  364. }
  365. if option.ModelName != "" { //使用预训练模型训练
  366. req.ModelName = option.ModelName
  367. req.LabelName = option.LabelName
  368. req.CkptName = option.CkptName
  369. req.ModelVersion = option.ModelVersion
  370. req.PreTrainModelUrl = option.PreTrainModelUrl
  371. }
  372. jobId, err := grampus.GenerateTrainJob(ctx, req)
  373. if err != nil {
  374. log.Error("GenerateTrainJob failed:%v", err.Error(), ctx.Data["MsgID"])
  375. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  376. return
  377. }
  378. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  379. }
  380. func checkParameters(ctx *context.Context, option api.CreateTrainJobOption, lock *redis_lock.DistributeLock, repo *models.Repository) (*models.Specification, map[string]models.DatasetInfo, string, error) {
  381. isOk, err := lock.Lock(models.CloudbrainKeyDuration)
  382. if !isOk {
  383. log.Error("lock processed failed:%v", err, ctx.Data["MsgID"])
  384. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_samejob_err"))
  385. }
  386. if !jobNamePattern.MatchString(option.DisplayJobName) {
  387. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_jobname_err"))
  388. }
  389. bootFileExist, err := ctx.Repo.FileExists(option.BootFile, option.BranchName)
  390. if err != nil || !bootFileExist {
  391. log.Error("Get bootfile error:", err, ctx.Data["MsgID"])
  392. return nil, nil, "", fmt.Errorf(ctx.Tr("repo.cloudbrain_bootfile_err"))
  393. }
  394. computeResource := models.GPUResource
  395. if isNpuTask(option) {
  396. computeResource = models.NPUResource
  397. }
  398. //check count limit
  399. taskType := option.Type
  400. if isC2NetTask(option) {
  401. taskType = 2
  402. }
  403. count, err := GetNotFinalStatusTaskCount(ctx.User.ID, taskType, string(models.JobTypeTrain), computeResource)
  404. if err != nil {
  405. log.Error("GetCountByUserID failed:%v", err, ctx.Data["MsgID"])
  406. return nil, nil, "", fmt.Errorf("system error")
  407. } else {
  408. if count >= 1 {
  409. log.Error("the user already has running or waiting task", ctx.Data["MsgID"])
  410. return nil, nil, "", fmt.Errorf("you have already a running or waiting task, can not create more.")
  411. }
  412. }
  413. //check param
  414. if err := paramCheckCreateTrainJob(option.BootFile, option.BranchName); err != nil {
  415. log.Error("paramCheckCreateTrainJob failed:(%v)", err, ctx.Data["MsgID"])
  416. return nil, nil, "", err
  417. }
  418. //check whether the task name in the project is duplicated
  419. tasks, err := models.GetCloudbrainsByDisplayJobName(repo.ID, string(models.JobTypeTrain), option.DisplayJobName)
  420. if err == nil {
  421. if len(tasks) != 0 {
  422. log.Error("the job name did already exist", ctx.Data["MsgID"])
  423. return nil, nil, "", fmt.Errorf("The job name did already exist.")
  424. }
  425. } else {
  426. if !models.IsErrJobNotExist(err) {
  427. log.Error("system error, %v", err, ctx.Data["MsgID"])
  428. return nil, nil, "", fmt.Errorf("system error")
  429. }
  430. }
  431. //check specification
  432. computeType := models.GPU
  433. if isNpuTask(option) {
  434. computeType = models.NPU
  435. }
  436. cluster := models.OpenICluster
  437. if isC2NetTask(option) {
  438. cluster = models.C2NetCluster
  439. }
  440. aiCenterCode := ""
  441. if option.Type == TaskTypeCloudbrainOne {
  442. aiCenterCode = models.AICenterOfCloudBrainOne
  443. } else if option.Type == TaskTypeModelArts {
  444. aiCenterCode = models.AICenterOfCloudBrainTwo
  445. }
  446. spec, err := resource.GetAndCheckSpec(ctx.User.ID, option.SpecId, models.FindSpecsOptions{
  447. JobType: models.JobTypeTrain,
  448. ComputeResource: computeType,
  449. Cluster: cluster,
  450. AiCenterCode: aiCenterCode,
  451. })
  452. if err != nil || spec == nil {
  453. return nil, nil, "", fmt.Errorf("Resource specification is not available.")
  454. }
  455. if !account.IsPointBalanceEnough(ctx.User.ID, spec.UnitPrice) {
  456. log.Error("point balance is not enough,userId=%d specId=%d", ctx.User.ID, spec.ID)
  457. return nil, nil, "", fmt.Errorf(ctx.Tr("points.insufficient_points_balance"))
  458. }
  459. //check dataset
  460. var datasetInfos map[string]models.DatasetInfo
  461. var datasetNames string
  462. if option.Type != TaskTypeModelArts {
  463. datasetInfos, datasetNames, err = models.GetDatasetInfo(option.Attachment, computeType)
  464. if err != nil {
  465. log.Error("GetDatasetInfo failed: %v", err, ctx.Data["MsgID"])
  466. return nil, nil, "", fmt.Errorf(ctx.Tr("cloudbrain.error.dataset_select"))
  467. }
  468. }
  469. return spec, datasetInfos, datasetNames, err
  470. }
  471. func isNpuTask(option api.CreateTrainJobOption) bool {
  472. return option.Type == TaskTypeModelArts || option.Type == TaskTypeGrampusNPU
  473. }
  474. func isC2NetTask(option api.CreateTrainJobOption) bool {
  475. return option.Type == TaskTypeGrampusGPU || option.Type == TaskTypeGrampusNPU
  476. }
  477. func GrampusTrainJobNpuCreate(ctx *context.Context, option api.CreateTrainJobOption) {
  478. displayJobName := option.DisplayJobName
  479. jobName := util.ConvertDisplayJobNameToJobName(displayJobName)
  480. uuid := option.Attachment
  481. description := option.Description
  482. bootFile := strings.TrimSpace(option.BootFile)
  483. params := option.Params
  484. repo := ctx.Repo.Repository
  485. codeLocalPath := setting.JobPath + jobName + modelarts.CodePath
  486. codeObsPath := grampus.JobPath + jobName + modelarts.CodePath
  487. branchName := option.BranchName
  488. isLatestVersion := modelarts.IsLatestVersion
  489. versionCount := modelarts.VersionCountOne
  490. engineName := option.Image
  491. lock := redis_lock.NewDistributeLock(redis_key.CloudbrainBindingJobNameKey(fmt.Sprint(repo.ID), string(models.JobTypeTrain), displayJobName))
  492. defer lock.UnLock()
  493. spec, datasetInfos, datasetNames, err := checkParameters(ctx, option, lock, repo)
  494. if err != nil {
  495. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  496. return
  497. }
  498. //prepare code and out path
  499. _, err = ioutil.ReadDir(codeLocalPath)
  500. if err == nil {
  501. os.RemoveAll(codeLocalPath)
  502. }
  503. if err := downloadZipCode(ctx, codeLocalPath, branchName); err != nil {
  504. log.Error("downloadZipCode failed, server timed out: %s (%v)", repo.FullName(), err)
  505. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  506. return
  507. }
  508. //todo: upload code (send to file_server todo this work?)
  509. if err := obsMkdir(setting.CodePathPrefix + jobName + modelarts.OutputPath); err != nil {
  510. log.Error("Failed to obsMkdir_output: %s (%v)", repo.FullName(), err)
  511. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  512. return
  513. }
  514. if err := uploadCodeToObs(codeLocalPath, jobName, ""); err != nil {
  515. log.Error("Failed to uploadCodeToObs: %s (%v)", repo.FullName(), err)
  516. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(ctx.Tr("cloudbrain.load_code_failed")))
  517. return
  518. }
  519. var datasetRemotePath, allFileName string
  520. for _, datasetInfo := range datasetInfos {
  521. if datasetRemotePath == "" {
  522. datasetRemotePath = datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  523. allFileName = datasetInfo.FullName
  524. } else {
  525. datasetRemotePath = datasetRemotePath + ";" + datasetInfo.DataLocalPath + "'" + datasetInfo.FullName + "'"
  526. allFileName = allFileName + ";" + datasetInfo.FullName
  527. }
  528. }
  529. //prepare command
  530. preTrainModelPath := getPreTrainModelPath(option.PreTrainModelUrl, option.CkptName)
  531. command, err := generateCommand(repo.Name, grampus.ProcessorTypeNPU, codeObsPath+cloudbrain.DefaultBranchName+".zip", datasetRemotePath, bootFile, params, setting.CodePathPrefix+jobName+modelarts.OutputPath, allFileName, preTrainModelPath, option.CkptName, grampus.GetNpuModelRemoteObsUrl(jobName))
  532. if err != nil {
  533. log.Error("Failed to generateCommand: %s (%v)", displayJobName, err, ctx.Data["MsgID"])
  534. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi("Create task failed, internal error"))
  535. return
  536. }
  537. commitID, _ := ctx.Repo.GitRepo.GetBranchCommitID(branchName)
  538. req := &grampus.GenerateTrainJobReq{
  539. JobName: jobName,
  540. DisplayJobName: displayJobName,
  541. ComputeResource: models.NPUResource,
  542. ProcessType: grampus.ProcessorTypeNPU,
  543. Command: command,
  544. ImageId: option.ImageID,
  545. Description: description,
  546. CodeObsPath: codeObsPath,
  547. BootFileUrl: codeObsPath + bootFile,
  548. BootFile: bootFile,
  549. WorkServerNumber: option.WorkServerNumber,
  550. Uuid: uuid,
  551. CommitID: commitID,
  552. IsLatestVersion: isLatestVersion,
  553. BranchName: branchName,
  554. Params: option.Params,
  555. EngineName: engineName,
  556. VersionCount: versionCount,
  557. TotalVersionCount: modelarts.TotalVersionCount,
  558. DatasetNames: datasetNames,
  559. DatasetInfos: datasetInfos,
  560. Spec: spec,
  561. CodeName: strings.ToLower(repo.Name),
  562. }
  563. if option.ModelName != "" { //使用预训练模型训练
  564. req.ModelName = option.ModelName
  565. req.LabelName = option.LabelName
  566. req.CkptName = option.CkptName
  567. req.ModelVersion = option.ModelVersion
  568. req.PreTrainModelUrl = option.PreTrainModelUrl
  569. req.PreTrainModelPath = preTrainModelPath
  570. }
  571. jobId, err := grampus.GenerateTrainJob(ctx, req)
  572. if err != nil {
  573. log.Error("GenerateTrainJob failed:%v", err.Error())
  574. ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
  575. return
  576. }
  577. ctx.JSON(http.StatusOK, models.BaseMessageApi{Code: 0, Message: jobId})
  578. }
  579. func obsMkdir(dir string) error {
  580. input := &obs.PutObjectInput{}
  581. input.Bucket = setting.Bucket
  582. input.Key = dir
  583. _, err := storage.ObsCli.PutObject(input)
  584. if err != nil {
  585. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  586. return err
  587. }
  588. return nil
  589. }
  590. func uploadCodeToObs(codePath, jobName, parentDir string) error {
  591. files, err := readDir(codePath)
  592. if err != nil {
  593. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  594. return err
  595. }
  596. for _, file := range files {
  597. if file.IsDir() {
  598. input := &obs.PutObjectInput{}
  599. input.Bucket = setting.Bucket
  600. input.Key = parentDir + file.Name() + "/"
  601. _, err = storage.ObsCli.PutObject(input)
  602. if err != nil {
  603. log.Error("PutObject(%s) failed: %s", input.Key, err.Error())
  604. return err
  605. }
  606. if err = uploadCodeToObs(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  607. log.Error("uploadCodeToObs(%s) failed: %s", file.Name(), err.Error())
  608. return err
  609. }
  610. } else {
  611. input := &obs.PutFileInput{}
  612. input.Bucket = setting.Bucket
  613. input.Key = setting.CodePathPrefix + jobName + "/code/" + parentDir + file.Name()
  614. input.SourceFile = codePath + file.Name()
  615. _, err = storage.ObsCli.PutFile(input)
  616. if err != nil {
  617. log.Error("PutFile(%s) failed: %s", input.SourceFile, err.Error())
  618. return err
  619. }
  620. }
  621. }
  622. return nil
  623. }
  624. func paramCheckCreateTrainJob(bootFile string, branchName string) error {
  625. if !strings.HasSuffix(strings.TrimSpace(bootFile), ".py") {
  626. log.Error("the boot file(%s) must be a python file", bootFile)
  627. return errors.New("启动文件必须是python文件")
  628. }
  629. if branchName == "" {
  630. log.Error("the branch must not be null!", branchName)
  631. return errors.New("代码分支不能为空!")
  632. }
  633. return nil
  634. }
  635. func downloadZipCode(ctx *context.Context, codePath, branchName string) error {
  636. archiveType := git.ZIP
  637. archivePath := codePath
  638. if !com.IsDir(archivePath) {
  639. if err := os.MkdirAll(archivePath, os.ModePerm); err != nil {
  640. log.Error("MkdirAll failed:" + err.Error())
  641. return err
  642. }
  643. }
  644. // Get corresponding commit.
  645. var (
  646. commit *git.Commit
  647. err error
  648. )
  649. gitRepo := ctx.Repo.GitRepo
  650. if err != nil {
  651. log.Error("OpenRepository failed:" + err.Error())
  652. return err
  653. }
  654. if gitRepo.IsBranchExist(branchName) {
  655. commit, err = gitRepo.GetBranchCommit(branchName)
  656. if err != nil {
  657. log.Error("GetBranchCommit failed:" + err.Error())
  658. return err
  659. }
  660. } else {
  661. log.Error("the branch is not exist: " + branchName)
  662. return fmt.Errorf("The branch does not exist.")
  663. }
  664. archivePath = path.Join(archivePath, grampus.CodeArchiveName)
  665. if !com.IsFile(archivePath) {
  666. if err := commit.CreateArchive(archivePath, git.CreateArchiveOpts{
  667. Format: archiveType,
  668. Prefix: setting.Repository.PrefixArchiveFiles,
  669. }); err != nil {
  670. log.Error("CreateArchive failed:" + err.Error())
  671. return err
  672. }
  673. }
  674. return nil
  675. }
  676. func uploadCodeToMinio(codePath, jobName, parentDir string) error {
  677. files, err := readDir(codePath)
  678. if err != nil {
  679. log.Error("readDir(%s) failed: %s", codePath, err.Error())
  680. return err
  681. }
  682. for _, file := range files {
  683. if file.IsDir() {
  684. if err = uploadCodeToMinio(codePath+file.Name()+"/", jobName, parentDir+file.Name()+"/"); err != nil {
  685. log.Error("uploadCodeToMinio(%s) failed: %s", file.Name(), err.Error())
  686. return err
  687. }
  688. } else {
  689. destObject := setting.CBCodePathPrefix + jobName + parentDir + file.Name()
  690. sourceFile := codePath + file.Name()
  691. err = storage.Attachments.UploadObject(destObject, sourceFile)
  692. if err != nil {
  693. log.Error("UploadObject(%s) failed: %s", file.Name(), err.Error())
  694. return err
  695. }
  696. }
  697. }
  698. return nil
  699. }
  // readDir lists all entries of the directory dirname.
  //
  // The entries are returned in the order the directory yields them; they are
  // not sorted. An io.EOF from the listing is mapped to (nil, nil); any other
  // error is returned as is.
  func readDir(dirname string) ([]os.FileInfo, error) {
  	dir, err := os.Open(dirname)
  	if err != nil {
  		return nil, err
  	}
  	defer dir.Close()

  	entries, err := dir.Readdir(0)
  	switch err {
  	case nil:
  		return entries, nil
  	case io.EOF:
  		// todo: can not upload empty folder
  		return nil, nil
  	default:
  		return nil, err
  	}
  }
  717. func mkModelPath(modelPath string) error {
  718. return mkPathAndReadMeFile(modelPath, "You can put the files into this directory and download the files by the web page.")
  719. }
  720. func mkPathAndReadMeFile(path string, text string) error {
  721. err := os.MkdirAll(path, os.ModePerm)
  722. if err != nil {
  723. log.Error("MkdirAll(%s) failed:%v", path, err)
  724. return err
  725. }
  726. fileName := path + "README"
  727. f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
  728. if err != nil {
  729. log.Error("OpenFile failed", err.Error())
  730. return err
  731. }
  732. defer f.Close()
  733. _, err = f.WriteString(text)
  734. if err != nil {
  735. log.Error("WriteString failed", err.Error())
  736. return err
  737. }
  738. return nil
  739. }
  // getPreTrainModelPath drops everything up to and including the first "/" of
  // pretrainModelDir and appends fileName to the remainder.
  //
  // It returns "" when pretrainModelDir contains no "/" or starts with one.
  // (Presumably the dropped leading segment is a bucket name — verify against
  // the callers.)
  func getPreTrainModelPath(pretrainModelDir string, fileName string) string {
  	slash := strings.Index(pretrainModelDir, "/")
  	if slash <= 0 {
  		return ""
  	}
  	return pretrainModelDir[slash+1:] + fileName
  }
  749. func generateCommand(repoName, processorType, codeRemotePath, dataRemotePath, bootFile, paramSrc, outputRemotePath, datasetName, pretrainModelPath, pretrainModelFileName, modelRemoteObsUrl string) (string, error) {
  750. var command string
  751. //prepare
  752. workDir := grampus.NpuWorkDir
  753. if processorType == grampus.ProcessorTypeNPU {
  754. command += "pwd;cd " + workDir + grampus.CommandPrepareScriptNpu
  755. } else if processorType == grampus.ProcessorTypeGPU {
  756. workDir = grampus.GpuWorkDir
  757. command += "pwd;cd " + workDir + fmt.Sprintf(grampus.CommandPrepareScriptGpu, setting.Grampus.SyncScriptProject, setting.Grampus.SyncScriptProject)
  758. }
  759. //download code & dataset
  760. if processorType == grampus.ProcessorTypeNPU {
  761. //no need to download code & dataset by internet
  762. } else if processorType == grampus.ProcessorTypeGPU {
  763. commandDownload := "./downloader_for_minio " + setting.Grampus.Env + " " + codeRemotePath + " " + grampus.CodeArchiveName + " '" + dataRemotePath + "' '" + datasetName + "'"
  764. commandDownload = processPretrainModelParameter(pretrainModelPath, pretrainModelFileName, commandDownload)
  765. command += commandDownload
  766. }
  767. //unzip code & dataset
  768. if processorType == grampus.ProcessorTypeNPU {
  769. //no need to process
  770. } else if processorType == grampus.ProcessorTypeGPU {
  771. unZipDatasetCommand := generateDatasetUnzipCommand(datasetName)
  772. commandUnzip := "cd " + workDir + "code;unzip -q master.zip;rm -f master.zip;echo \"start to unzip dataset\";cd " + workDir + "dataset;" + unZipDatasetCommand
  773. command += commandUnzip
  774. }
  775. command += "echo \"unzip finished;start to exec code;\";"
  776. // set export
  777. var commandExport string
  778. if processorType == grampus.ProcessorTypeNPU {
  779. commandExport = "export bucket=" + setting.Bucket + " && export remote_path=" + outputRemotePath + ";"
  780. } else if processorType == grampus.ProcessorTypeGPU {
  781. commandExport = "export env=" + setting.Grampus.Env + " && export remote_path=" + outputRemotePath + ";"
  782. }
  783. command += commandExport
  784. //exec code
  785. var parameters models.Parameters
  786. var paramCode string
  787. if len(paramSrc) != 0 {
  788. err := json.Unmarshal([]byte(paramSrc), &parameters)
  789. if err != nil {
  790. log.Error("Failed to Unmarshal params: %s (%v)", paramSrc, err)
  791. return command, err
  792. }
  793. for _, parameter := range parameters.Parameter {
  794. paramCode += " --" + parameter.Label + "=" + parameter.Value
  795. }
  796. }
  797. var commandCode string
  798. if processorType == grampus.ProcessorTypeNPU {
  799. paramCode += " --model_url=" + modelRemoteObsUrl
  800. commandCode = "/bin/bash /home/work/run_train_for_openi.sh /home/work/openi.py " + grampus.NpuLocalLogUrl + paramCode + ";"
  801. } else if processorType == grampus.ProcessorTypeGPU {
  802. if pretrainModelFileName != "" {
  803. paramCode += " --ckpt_url" + "=" + workDir + "pretrainmodel/" + pretrainModelFileName
  804. }
  805. commandCode = "cd " + workDir + "code/" + strings.ToLower(repoName) + ";python " + bootFile + paramCode + ";"
  806. }
  807. command += commandCode
  808. //get exec result
  809. commandGetRes := "result=$?;"
  810. command += commandGetRes
  811. //upload models
  812. if processorType == grampus.ProcessorTypeNPU {
  813. // no need to upload
  814. } else if processorType == grampus.ProcessorTypeGPU {
  815. commandUpload := "cd " + workDir + setting.Grampus.SyncScriptProject + "/;./uploader_for_gpu " + setting.Grampus.Env + " " + outputRemotePath + " " + workDir + "output/;"
  816. command += commandUpload
  817. }
  818. //check exec result
  819. commandCheckRes := "bash -c \"[[ $result -eq 0 ]] && exit 0 || exit -1\""
  820. command += commandCheckRes
  821. return command, nil
  822. }
  // processPretrainModelParameter finishes the download command.
  //
  // When pretrainModelPath is not empty, the single-quoted model path and file
  // name are appended as two extra arguments. In every case the returned
  // command is terminated with ";".
  func processPretrainModelParameter(pretrainModelPath string, pretrainModelFileName string, commandDownload string) string {
  	if pretrainModelPath == "" {
  		return commandDownload + ";"
  	}
  	return commandDownload + " '" + pretrainModelPath + "' '" + pretrainModelFileName + "'" + ";"
  }
  // generateDatasetUnzipCommand builds the shell snippet that unpacks the
  // dataset archives named in datasetName (several names are separated by ";").
  //
  // Single dataset: a ".tar.gz" archive is extracted with
  // "tar --strip-components=1 -zxvf", anything else with "unzip -q".
  // Multiple datasets: ".tar.gz" archives are extracted with "tar -zxvf"; every
  // other archive is unzipped into a directory named after the archive with its
  // ".zip" suffix removed.
  // Every generated step ends with ";".
  func generateDatasetUnzipCommand(datasetName string) string {
  	const tarGzSuffix = ".tar.gz"

  	names := strings.Split(datasetName, ";")
  	if len(names) == 1 {
  		// single dataset
  		if strings.HasSuffix(names[0], tarGzSuffix) {
  			return "tar --strip-components=1 -zxvf '" + datasetName + "';"
  		}
  		return "unzip -q '" + datasetName + "';"
  	}

  	// multiple datasets
  	var cmd strings.Builder
  	for _, name := range names {
  		if strings.HasSuffix(name, tarGzSuffix) {
  			cmd.WriteString("tar -zxvf '" + name + "';")
  			continue
  		}
  		cmd.WriteString("unzip -q '" + name + "' -d './" + strings.TrimSuffix(name, ".zip") + "';")
  	}
  	return cmd.String()
  }
  850. func getPoolId() string {
  851. var resourcePools modelarts.ResourcePool
  852. json.Unmarshal([]byte(setting.ResourcePools), &resourcePools)
  853. return resourcePools.Info[0].ID
  854. }
  855. func PrepareSpec4Show(task *models.Cloudbrain) {
  856. s, err := resource.GetCloudbrainSpec(task.ID)
  857. if err != nil {
  858. log.Info("error:" + err.Error())
  859. return
  860. }
  861. task.Spec = s
  862. }
  863. func IsTaskNotStop(task *models.Cloudbrain) bool {
  864. statuses := CloudbrainOneNotFinalStatuses
  865. if task.Type == models.TypeCloudBrainTwo || task.Type == models.TypeCDCenter {
  866. statuses = CloudbrainTwoNotFinalStatuses
  867. } else {
  868. statuses = GrampusNotFinalStatuses
  869. }
  870. for _, status := range statuses {
  871. if task.Status == status {
  872. return true
  873. }
  874. }
  875. return false
  876. }
  877. func SyncTaskStatus(task *models.Cloudbrain) error {
  878. if task.Type == models.TypeCloudBrainOne {
  879. result, err := cloudbrain.GetJob(task.JobID)
  880. if err != nil {
  881. log.Info("error:" + err.Error())
  882. return fmt.Errorf("repo.cloudbrain_query_fail")
  883. }
  884. if result != nil {
  885. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  886. taskRoles := jobRes.TaskRoles
  887. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  888. oldStatus := task.Status
  889. task.Status = taskRes.TaskStatuses[0].State
  890. task.ContainerID = taskRes.TaskStatuses[0].ContainerID
  891. models.ParseAndSetDurationFromCloudBrainOne(jobRes, task)
  892. if task.DeletedAt.IsZero() { //normal record
  893. if oldStatus != task.Status {
  894. notification.NotifyChangeCloudbrainStatus(task, oldStatus)
  895. }
  896. err = models.UpdateJob(task)
  897. if err != nil {
  898. return fmt.Errorf("repo.cloudbrain_query_fail")
  899. }
  900. }
  901. } else {
  902. log.Info("error:" + err.Error())
  903. return fmt.Errorf("repo.cloudbrain_query_fail")
  904. }
  905. } else if task.Type == models.TypeCloudBrainTwo || task.Type == models.TypeCDCenter {
  906. err := modelarts.HandleTrainJobInfo(task)
  907. if err != nil {
  908. return fmt.Errorf("repo.cloudbrain_query_fail")
  909. }
  910. } else if task.Type == models.TypeC2Net {
  911. result, err := grampus.GetJob(task.JobID)
  912. if err != nil {
  913. log.Error("GetJob failed:" + err.Error())
  914. return fmt.Errorf("repo.cloudbrain_query_fail")
  915. }
  916. if result != nil {
  917. if len(result.JobInfo.Tasks[0].CenterID) == 1 && len(result.JobInfo.Tasks[0].CenterName) == 1 {
  918. task.AiCenter = result.JobInfo.Tasks[0].CenterID[0] + "+" + result.JobInfo.Tasks[0].CenterName[0]
  919. }
  920. oldStatus := task.Status
  921. task.Status = grampus.TransTrainJobStatus(result.JobInfo.Status)
  922. if task.Status != result.JobInfo.Status || result.JobInfo.Status == models.GrampusStatusRunning {
  923. task.Duration = result.JobInfo.RunSec
  924. if task.Duration < 0 {
  925. task.Duration = 0
  926. }
  927. task.TrainJobDuration = models.ConvertDurationToStr(task.Duration)
  928. if task.StartTime == 0 && result.JobInfo.StartedAt > 0 {
  929. task.StartTime = timeutil.TimeStamp(result.JobInfo.StartedAt)
  930. }
  931. if task.EndTime == 0 && models.IsTrainJobTerminal(task.Status) && task.StartTime > 0 {
  932. task.EndTime = task.StartTime.Add(task.Duration)
  933. }
  934. task.CorrectCreateUnix()
  935. if oldStatus != task.Status {
  936. notification.NotifyChangeCloudbrainStatus(task, oldStatus)
  937. if models.IsTrainJobTerminal(task.Status) && task.ComputeResource == models.NPUResource {
  938. if len(result.JobInfo.Tasks[0].CenterID) == 1 {
  939. urchin.GetBackNpuModel(task.ID, grampus.GetRemoteEndPoint(result.JobInfo.Tasks[0].CenterID[0]), grampus.BucketRemote, grampus.GetNpuModelObjectKey(task.JobName), grampus.GetCenterProxy(setting.Grampus.LocalCenterID))
  940. }
  941. }
  942. }
  943. err = models.UpdateJob(task)
  944. if err != nil {
  945. log.Error("UpdateJob failed:" + err.Error())
  946. return fmt.Errorf("repo.cloudbrain_query_fail")
  947. }
  948. }
  949. }
  950. }
  951. return nil
  952. }
  953. func getTrainJobCommand(option api.CreateTrainJobOption) (string, error) {
  954. var command string
  955. bootFile := strings.TrimSpace(option.BootFile)
  956. params := option.Params
  957. if !strings.HasSuffix(bootFile, ".py") {
  958. log.Error("bootFile(%s) format error", bootFile)
  959. return command, errors.New("bootFile format error")
  960. }
  961. var parameters models.Parameters
  962. var param string
  963. if len(params) != 0 {
  964. err := json.Unmarshal([]byte(params), &parameters)
  965. if err != nil {
  966. log.Error("Failed to Unmarshal params: %s (%v)", params, err)
  967. return command, err
  968. }
  969. for _, parameter := range parameters.Parameter {
  970. param += " --" + parameter.Label + "=" + parameter.Value
  971. }
  972. }
  973. if option.CkptName != "" {
  974. param += " --ckpt_url" + "=" + "/pretrainmodel/" + option.CkptName
  975. }
  976. command += "python /code/" + bootFile + param + " > " + cloudbrain.ModelMountPath + "/" + option.DisplayJobName + "-" + cloudbrain.LogFile
  977. return command, nil
  978. }
  979. func checkMultiNode(userId int64, serverNum int) string {
  980. if serverNum == 1 {
  981. return ""
  982. }
  983. modelarts.InitMultiNode()
  984. var isServerNumValid = false
  985. if modelarts.MultiNodeConfig != nil {
  986. for _, info := range modelarts.MultiNodeConfig.Info {
  987. if isInOrg, _ := models.IsOrganizationMemberByOrgName(info.Org, userId); isInOrg {
  988. if isInNodes(info.Node, serverNum) {
  989. isServerNumValid = true
  990. break
  991. }
  992. }
  993. }
  994. }
  995. if isServerNumValid {
  996. return ""
  997. } else {
  998. return "repo.modelarts.no_node_right"
  999. }
  1000. }
  // isInNodes reports whether num occurs in nodes.
  func isInNodes(nodes []int, num int) bool {
  	for i := 0; i < len(nodes); i++ {
  		if nodes[i] == num {
  			return true
  		}
  	}
  	return false
  }
  1009. func getUserCommand(engineId int, req *modelarts.GenerateTrainJobReq) (string, string) {
  1010. userImageUrl := ""
  1011. userCommand := ""
  1012. if engineId < 0 {
  1013. tmpCodeObsPath := strings.Trim(req.CodeObsPath, "/")
  1014. tmpCodeObsPaths := strings.Split(tmpCodeObsPath, "/")
  1015. lastCodeDir := "code"
  1016. if len(tmpCodeObsPaths) > 0 {
  1017. lastCodeDir = tmpCodeObsPaths[len(tmpCodeObsPaths)-1]
  1018. }
  1019. userCommand = "/bin/bash /home/work/run_train.sh 's3://" + req.CodeObsPath + "' '" + lastCodeDir + "/" + req.BootFile + "' '/tmp/log/train.log' --'data_url'='s3://" + req.DataUrl + "' --'train_url'='s3://" + req.TrainUrl + "'"
  1020. var versionInfos modelarts.VersionInfo
  1021. if err := json.Unmarshal([]byte(setting.EngineVersions), &versionInfos); err != nil {
  1022. log.Info("json parse err." + err.Error())
  1023. } else {
  1024. for _, engine := range versionInfos.Version {
  1025. if engine.ID == engineId {
  1026. userImageUrl = engine.Url
  1027. break
  1028. }
  1029. }
  1030. }
  1031. for _, param := range req.Parameters {
  1032. userCommand += " --'" + param.Label + "'='" + param.Value + "'"
  1033. }
  1034. return userCommand, userImageUrl
  1035. }
  1036. return userCommand, userImageUrl
  1037. }