You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelarts.go 25 kB

4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
3 years ago
3 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
4 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago
4 years ago
4 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
4 years ago
3 years ago
4 years ago
3 years ago

  1. package modelarts
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "path"
  7. "strconv"
  8. "code.gitea.io/gitea/modules/timeutil"
  9. "code.gitea.io/gitea/models"
  10. "code.gitea.io/gitea/modules/context"
  11. "code.gitea.io/gitea/modules/log"
  12. "code.gitea.io/gitea/modules/notification"
  13. "code.gitea.io/gitea/modules/setting"
  14. "code.gitea.io/gitea/modules/storage"
  15. )
  // Package-level constants shared by the notebook, train-job and
  // inference-job helpers in this package.
  const (
  	// Notebook settings.

  	// storageTypeOBS is the storage type sent when creating a notebook.
  	storageTypeOBS = "obs"
  	// autoStopDuration is the notebook auto-stop duration (4 * 60 * 60);
  	// presumably seconds, given the Ms sibling below - confirm against the API.
  	autoStopDuration = 4 * 60 * 60
  	// autoStopDurationMs is the same duration multiplied by 1000.
  	autoStopDurationMs = autoStopDuration * 1000
  	// MORDELART_USER_IMAGE_ENGINE_ID is the engine id stored for jobs that run
  	// on a user-supplied image; a negative engine id selects the user-image
  	// code path in the job creation helpers.
  	MORDELART_USER_IMAGE_ENGINE_ID = -1
  	DataSetMountPath = "/home/ma-user/work"
  	NotebookEnv = "Python3"
  	NotebookType = "Ascend"
  	FlavorInfo = "Ascend: 1*Ascend 910 CPU: 24 核 96GiB (modelarts.kat1.xlarge)"

  	// Train-job settings.

  	// ResourcePools = "{\"resource_pool\":[{\"id\":\"pool1328035d\", \"value\":\"专属资源池\"}]}"
  	// Engines = "{\"engine\":[{\"id\":1, \"value\":\"Ascend-Powered-Engine\"}]}"
  	// EngineVersions = "{\"version\":[{\"id\":118,\"value\":\"MindSpore-1.0.0-c75-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":119,\"value\":\"MindSpore-1.1.1-c76-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":120,\"value\":\"MindSpore-1.1.1-c76-tr5-python3.7-euleros2.8-aarch64\"}," +
  	// "{\"id\":117,\"value\":\"TF-1.15-c75-python3.7-euleros2.8-aarch64\"}" +
  	// "]}"
  	// TrainJobFlavorInfo = "{\"flavor\":[{\"code\":\"modelarts.bm.910.arm.public.2\",\"value\":\"Ascend : 2 * Ascend 910 CPU:48 核 512GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.8\",\"value\":\"Ascend : 8 * Ascend 910 CPU:192 核 2048GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.4\",\"value\":\"Ascend : 4 * Ascend 910 CPU:96 核 1024GiB\"}," +
  	// "{\"code\":\"modelarts.bm.910.arm.public.1\",\"value\":\"Ascend : 1 * Ascend 910 CPU:24 核 256GiB\"}" +
  	// "]}"

  	// Path segments used for job artefacts.
  	CodePath = "/code/"
  	OutputPath = "/output/"
  	ResultPath = "/result/"
  	LogPath = "/log/"
  	JobPath = "/job/"

  	// Query ordering.
  	OrderDesc = "desc" // query downward
  	OrderAsc = "asc" // query upward
  	Lines = 500

  	// Parameter names.
  	TrainUrl = "train_url"
  	DataUrl = "data_url"
  	MultiDataUrl = "multi_data_url"
  	ResultUrl = "result_url"
  	CkptUrl = "ckpt_url"
  	DeviceTarget = "device_target"
  	Ascend = "Ascend"

  	// Paging and versioning.
  	PerPage = 10
  	IsLatestVersion = "1"
  	NotLatestVersion = "0"
  	VersionCount = 1
  	SortByCreateTime = "create_time"
  	ConfigTypeCustom = "custom"
  	TotalVersionCount = 1
  )
  62. var (
  63. poolInfos *models.PoolInfos
  64. FlavorInfos *models.FlavorInfos
  65. ImageInfos *models.ImageInfosModelArts
  66. TrainFlavorInfos *Flavor
  67. SpecialPools *models.SpecialPools
  68. )
  69. type GenerateTrainJobReq struct {
  70. JobName string
  71. DisplayJobName string
  72. Uuid string
  73. Description string
  74. CodeObsPath string
  75. BootFile string
  76. BootFileUrl string
  77. DataUrl string
  78. TrainUrl string
  79. FlavorCode string
  80. LogUrl string
  81. PoolID string
  82. WorkServerNumber int
  83. EngineID int64
  84. Parameters []models.Parameter
  85. CommitID string
  86. IsLatestVersion string
  87. Params string
  88. BranchName string
  89. PreVersionId int64
  90. PreVersionName string
  91. FlavorName string
  92. VersionCount int
  93. EngineName string
  94. TotalVersionCount int
  95. UserImageUrl string
  96. UserCommand string
  97. DatasetName string
  98. }
  99. type GenerateInferenceJobReq struct {
  100. JobName string
  101. DisplayJobName string
  102. Uuid string
  103. Description string
  104. CodeObsPath string
  105. BootFile string
  106. BootFileUrl string
  107. DataUrl string
  108. TrainUrl string
  109. FlavorCode string
  110. LogUrl string
  111. PoolID string
  112. WorkServerNumber int
  113. EngineID int64
  114. Parameters []models.Parameter
  115. CommitID string
  116. Params string
  117. BranchName string
  118. FlavorName string
  119. EngineName string
  120. LabelName string
  121. IsLatestVersion string
  122. VersionCount int
  123. TotalVersionCount int
  124. ModelName string
  125. ModelVersion string
  126. CkptName string
  127. ResultUrl string
  128. }
  // VersionInfo is the JSON shape {"version":[{"id":..,"value":..,"url":..}]}.
  type VersionInfo struct {
  	// Version lists the entries found under the "version" key.
  	Version []struct {
  		ID int `json:"id"`
  		Value string `json:"value"`
  		Url string `json:"url"`
  	} `json:"version"`
  }
  // Flavor is the JSON shape {"flavor":[{"code":..,"value":..}]}.
  type Flavor struct {
  	// Info lists the entries found under the "flavor" key.
  	Info []struct {
  		Code string `json:"code"`
  		Value string `json:"value"`
  	} `json:"flavor"`
  }
  // Engine is the JSON shape {"engine":[{"id":..,"value":..}]}.
  type Engine struct {
  	// Info lists the entries found under the "engine" key.
  	Info []struct {
  		ID int `json:"id"`
  		Value string `json:"value"`
  	} `json:"engine"`
  }
  // ResourcePool is the JSON shape {"resource_pool":[{"id":..,"value":..}]}.
  type ResourcePool struct {
  	// Info lists the entries found under the "resource_pool" key.
  	Info []struct {
  		ID string `json:"id"`
  		Value string `json:"value"`
  	} `json:"resource_pool"`
  }
  154. // type Parameter struct {
  155. // Label string `json:"label"`
  156. // Value string `json:"value"`
  157. // }
  158. // type Parameters struct {
  159. // Parameter []Parameter `json:"parameter"`
  160. // }
  // Parameters is the JSON shape {"parameter":[{"label":..,"value":..}]}.
  type Parameters struct {
  	// Parameter lists the label/value pairs found under the "parameter" key.
  	Parameter []struct {
  		Label string `json:"label"`
  		Value string `json:"value"`
  	} `json:"parameter"`
  }
  167. func GenerateTask(ctx *context.Context, jobName, uuid, description, flavor string) error {
  168. var dataActualPath string
  169. if uuid != "" {
  170. dataActualPath = setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + "/"
  171. } else {
  172. userPath := setting.UserBasePath + ctx.User.Name + "/"
  173. isExist, err := storage.ObsHasObject(userPath)
  174. if err != nil {
  175. log.Error("ObsHasObject failed:%v", err.Error(), ctx.Data["MsgID"])
  176. return err
  177. }
  178. if !isExist {
  179. if err = storage.ObsCreateObject(userPath); err != nil {
  180. log.Error("ObsCreateObject failed:%v", err.Error(), ctx.Data["MsgID"])
  181. return err
  182. }
  183. }
  184. dataActualPath = setting.Bucket + "/" + userPath
  185. }
  186. if poolInfos == nil {
  187. json.Unmarshal([]byte(setting.PoolInfos), &poolInfos)
  188. }
  189. createTime := timeutil.TimeStampNow()
  190. jobResult, err := CreateJob(models.CreateNotebookParams{
  191. JobName: jobName,
  192. Description: description,
  193. ProfileID: setting.ProfileID,
  194. Flavor: flavor,
  195. Pool: models.Pool{
  196. ID: poolInfos.PoolInfo[0].PoolId,
  197. Name: poolInfos.PoolInfo[0].PoolName,
  198. Type: poolInfos.PoolInfo[0].PoolType,
  199. },
  200. Spec: models.Spec{
  201. Storage: models.Storage{
  202. Type: storageTypeOBS,
  203. Location: models.Location{
  204. Path: dataActualPath,
  205. },
  206. },
  207. AutoStop: models.AutoStop{
  208. Enable: true,
  209. Duration: autoStopDuration,
  210. },
  211. },
  212. })
  213. if err != nil {
  214. log.Error("CreateJob failed: %v", err.Error())
  215. return err
  216. }
  217. err = models.CreateCloudbrain(&models.Cloudbrain{
  218. Status: string(models.JobWaiting),
  219. UserID: ctx.User.ID,
  220. RepoID: ctx.Repo.Repository.ID,
  221. JobID: jobResult.ID,
  222. JobName: jobName,
  223. JobType: string(models.JobTypeDebug),
  224. Type: models.TypeCloudBrainTwo,
  225. Uuid: uuid,
  226. ComputeResource: models.NPUResource,
  227. CreatedUnix: createTime,
  228. UpdatedUnix: createTime,
  229. })
  230. if err != nil {
  231. return err
  232. }
  233. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobResult.ID, jobName, models.ActionCreateDebugNPUTask)
  234. return nil
  235. }
  236. func GenerateNotebook2(ctx *context.Context, displayJobName, jobName, uuid, description, flavor, imageId string) error {
  237. if poolInfos == nil {
  238. json.Unmarshal([]byte(setting.PoolInfos), &poolInfos)
  239. }
  240. imageName, err := GetNotebookImageName(imageId)
  241. if err != nil {
  242. log.Error("GetNotebookImageName failed: %v", err.Error())
  243. return err
  244. }
  245. createTime := timeutil.TimeStampNow()
  246. jobResult, err := createNotebook2(models.CreateNotebook2Params{
  247. JobName: jobName,
  248. Description: description,
  249. Flavor: flavor,
  250. Duration: autoStopDurationMs,
  251. ImageID: imageId,
  252. PoolID: poolInfos.PoolInfo[0].PoolId,
  253. Feature: models.NotebookFeature,
  254. Volume: models.VolumeReq{
  255. Capacity: setting.Capacity,
  256. Category: models.EVSCategory,
  257. Ownership: models.ManagedOwnership,
  258. },
  259. WorkspaceID: "0",
  260. })
  261. if err != nil {
  262. log.Error("createNotebook2 failed: %v", err.Error())
  263. return err
  264. }
  265. err = models.CreateCloudbrain(&models.Cloudbrain{
  266. Status: jobResult.Status,
  267. UserID: ctx.User.ID,
  268. RepoID: ctx.Repo.Repository.ID,
  269. JobID: jobResult.ID,
  270. JobName: jobName,
  271. FlavorCode: flavor,
  272. DisplayJobName: displayJobName,
  273. JobType: string(models.JobTypeDebug),
  274. Type: models.TypeCloudBrainTwo,
  275. Uuid: uuid,
  276. ComputeResource: models.NPUResource,
  277. Image: imageName,
  278. Description: description,
  279. CreatedUnix: createTime,
  280. UpdatedUnix: createTime,
  281. })
  282. if err != nil {
  283. return err
  284. }
  285. task, err := models.GetCloudbrainByName(jobName)
  286. if err != nil {
  287. log.Error("GetCloudbrainByName failed: %v", err.Error())
  288. return err
  289. }
  290. stringId := strconv.FormatInt(task.ID, 10)
  291. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, stringId, displayJobName, models.ActionCreateDebugNPUTask)
  292. return nil
  293. }
  294. func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error) {
  295. createTime := timeutil.TimeStampNow()
  296. var jobResult *models.CreateTrainJobResult
  297. var createErr error
  298. if req.EngineID < 0 {
  299. jobResult, createErr = createTrainJobUserImage(models.CreateUserImageTrainJobParams{
  300. JobName: req.JobName,
  301. Description: req.Description,
  302. Config: models.UserImageConfig{
  303. WorkServerNum: req.WorkServerNumber,
  304. AppUrl: req.CodeObsPath,
  305. BootFileUrl: req.BootFileUrl,
  306. DataUrl: req.DataUrl,
  307. TrainUrl: req.TrainUrl,
  308. LogUrl: req.LogUrl,
  309. PoolID: req.PoolID,
  310. CreateVersion: true,
  311. Flavor: models.Flavor{
  312. Code: req.FlavorCode,
  313. },
  314. Parameter: req.Parameters,
  315. UserImageUrl: req.UserImageUrl,
  316. UserCommand: req.UserCommand,
  317. },
  318. })
  319. } else {
  320. jobResult, createErr = createTrainJob(models.CreateTrainJobParams{
  321. JobName: req.JobName,
  322. Description: req.Description,
  323. Config: models.Config{
  324. WorkServerNum: req.WorkServerNumber,
  325. AppUrl: req.CodeObsPath,
  326. BootFileUrl: req.BootFileUrl,
  327. DataUrl: req.DataUrl,
  328. EngineID: req.EngineID,
  329. TrainUrl: req.TrainUrl,
  330. LogUrl: req.LogUrl,
  331. PoolID: req.PoolID,
  332. CreateVersion: true,
  333. Flavor: models.Flavor{
  334. Code: req.FlavorCode,
  335. },
  336. Parameter: req.Parameters,
  337. },
  338. })
  339. }
  340. if createErr != nil {
  341. log.Error("CreateJob failed: %v", createErr.Error())
  342. return createErr
  343. }
  344. jobId := strconv.FormatInt(jobResult.JobID, 10)
  345. createErr = models.CreateCloudbrain(&models.Cloudbrain{
  346. Status: TransTrainJobStatus(jobResult.Status),
  347. UserID: ctx.User.ID,
  348. RepoID: ctx.Repo.Repository.ID,
  349. JobID: jobId,
  350. JobName: req.JobName,
  351. DisplayJobName: req.DisplayJobName,
  352. JobType: string(models.JobTypeTrain),
  353. Type: models.TypeCloudBrainTwo,
  354. VersionID: jobResult.VersionID,
  355. VersionName: jobResult.VersionName,
  356. Uuid: req.Uuid,
  357. DatasetName: req.DatasetName,
  358. CommitID: req.CommitID,
  359. IsLatestVersion: req.IsLatestVersion,
  360. ComputeResource: models.NPUResource,
  361. EngineID: req.EngineID,
  362. TrainUrl: req.TrainUrl,
  363. BranchName: req.BranchName,
  364. Parameters: req.Params,
  365. BootFile: req.BootFile,
  366. DataUrl: req.DataUrl,
  367. LogUrl: req.LogUrl,
  368. FlavorCode: req.FlavorCode,
  369. Description: req.Description,
  370. WorkServerNumber: req.WorkServerNumber,
  371. FlavorName: req.FlavorName,
  372. EngineName: req.EngineName,
  373. VersionCount: req.VersionCount,
  374. TotalVersionCount: req.TotalVersionCount,
  375. CreatedUnix: createTime,
  376. UpdatedUnix: createTime,
  377. })
  378. if createErr != nil {
  379. log.Error("CreateCloudbrain(%s) failed:%v", req.DisplayJobName, createErr.Error())
  380. return createErr
  381. }
  382. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobId, req.DisplayJobName, models.ActionCreateTrainTask)
  383. return nil
  384. }
  385. func GenerateModelConvertTrainJob(req *GenerateTrainJobReq) (*models.CreateTrainJobResult, error) {
  386. return createTrainJobUserImage(models.CreateUserImageTrainJobParams{
  387. JobName: req.JobName,
  388. Description: req.Description,
  389. Config: models.UserImageConfig{
  390. WorkServerNum: req.WorkServerNumber,
  391. AppUrl: req.CodeObsPath,
  392. BootFileUrl: req.BootFileUrl,
  393. DataUrl: req.DataUrl,
  394. TrainUrl: req.TrainUrl,
  395. LogUrl: req.LogUrl,
  396. PoolID: req.PoolID,
  397. CreateVersion: true,
  398. Flavor: models.Flavor{
  399. Code: req.FlavorCode,
  400. },
  401. Parameter: req.Parameters,
  402. UserImageUrl: req.UserImageUrl,
  403. UserCommand: req.UserCommand,
  404. },
  405. })
  406. }
  407. func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, jobId string) (err error) {
  408. createTime := timeutil.TimeStampNow()
  409. var jobResult *models.CreateTrainJobResult
  410. var createErr error
  411. log.Info(" req.EngineID =" + fmt.Sprint(req.EngineID))
  412. if req.EngineID < 0 {
  413. jobResult, createErr = createTrainJobVersionUserImage(models.CreateTrainJobVersionUserImageParams{
  414. Description: req.Description,
  415. Config: models.TrainJobVersionUserImageConfig{
  416. WorkServerNum: req.WorkServerNumber,
  417. AppUrl: req.CodeObsPath,
  418. BootFileUrl: req.BootFileUrl,
  419. DataUrl: req.DataUrl,
  420. TrainUrl: req.TrainUrl,
  421. LogUrl: req.LogUrl,
  422. PoolID: req.PoolID,
  423. Flavor: models.Flavor{
  424. Code: req.FlavorCode,
  425. },
  426. Parameter: req.Parameters,
  427. PreVersionId: req.PreVersionId,
  428. UserImageUrl: req.UserImageUrl,
  429. UserCommand: req.UserCommand,
  430. },
  431. }, jobId)
  432. } else {
  433. jobResult, createErr = createTrainJobVersion(models.CreateTrainJobVersionParams{
  434. Description: req.Description,
  435. Config: models.TrainJobVersionConfig{
  436. WorkServerNum: req.WorkServerNumber,
  437. AppUrl: req.CodeObsPath,
  438. BootFileUrl: req.BootFileUrl,
  439. DataUrl: req.DataUrl,
  440. EngineID: req.EngineID,
  441. TrainUrl: req.TrainUrl,
  442. LogUrl: req.LogUrl,
  443. PoolID: req.PoolID,
  444. Flavor: models.Flavor{
  445. Code: req.FlavorCode,
  446. },
  447. Parameter: req.Parameters,
  448. PreVersionId: req.PreVersionId,
  449. },
  450. }, jobId)
  451. }
  452. if createErr != nil {
  453. log.Error("CreateJob failed: %v", createErr.Error())
  454. return createErr
  455. }
  456. var jobTypes []string
  457. jobTypes = append(jobTypes, string(models.JobTypeTrain))
  458. repo := ctx.Repo.Repository
  459. VersionTaskList, VersionListCount, createErr := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
  460. RepoID: repo.ID,
  461. Type: models.TypeCloudBrainTwo,
  462. JobTypes: jobTypes,
  463. JobID: strconv.FormatInt(jobResult.JobID, 10),
  464. })
  465. if createErr != nil {
  466. ctx.ServerError("Cloudbrain", createErr)
  467. return createErr
  468. }
  469. //将当前版本的isLatestVersion设置为"1"和任务数量更新,任务数量包括当前版本数VersionCount和历史创建的总版本数TotalVersionCount
  470. createErr = models.CreateCloudbrain(&models.Cloudbrain{
  471. Status: TransTrainJobStatus(jobResult.Status),
  472. UserID: ctx.User.ID,
  473. RepoID: ctx.Repo.Repository.ID,
  474. JobID: strconv.FormatInt(jobResult.JobID, 10),
  475. JobName: req.JobName,
  476. DisplayJobName: req.DisplayJobName,
  477. JobType: string(models.JobTypeTrain),
  478. Type: models.TypeCloudBrainTwo,
  479. VersionID: jobResult.VersionID,
  480. VersionName: jobResult.VersionName,
  481. Uuid: req.Uuid,
  482. DatasetName: req.DatasetName,
  483. CommitID: req.CommitID,
  484. IsLatestVersion: req.IsLatestVersion,
  485. PreVersionName: req.PreVersionName,
  486. ComputeResource: models.NPUResource,
  487. EngineID: req.EngineID,
  488. TrainUrl: req.TrainUrl,
  489. BranchName: req.BranchName,
  490. Parameters: req.Params,
  491. BootFile: req.BootFile,
  492. DataUrl: req.DataUrl,
  493. LogUrl: req.LogUrl,
  494. PreVersionId: req.PreVersionId,
  495. FlavorCode: req.FlavorCode,
  496. Description: req.Description,
  497. WorkServerNumber: req.WorkServerNumber,
  498. FlavorName: req.FlavorName,
  499. EngineName: req.EngineName,
  500. TotalVersionCount: VersionTaskList[0].TotalVersionCount + 1,
  501. VersionCount: VersionListCount + 1,
  502. CreatedUnix: createTime,
  503. UpdatedUnix: createTime,
  504. })
  505. if createErr != nil {
  506. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, createErr.Error())
  507. return createErr
  508. }
  509. //将训练任务的上一版本的isLatestVersion设置为"0"
  510. createErr = models.SetVersionCountAndLatestVersion(strconv.FormatInt(jobResult.JobID, 10), VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
  511. if createErr != nil {
  512. ctx.ServerError("Update IsLatestVersion failed", createErr)
  513. return createErr
  514. }
  515. return createErr
  516. }
  517. func GenerateTrainJobVersionByUserImage(ctx *context.Context, req *GenerateTrainJobReq, jobId string) (err error) {
  518. createTime := timeutil.TimeStampNow()
  519. jobResult, err := createTrainJobUserImage(models.CreateUserImageTrainJobParams{
  520. JobName: req.JobName,
  521. Description: req.Description,
  522. Config: models.UserImageConfig{
  523. WorkServerNum: req.WorkServerNumber,
  524. AppUrl: req.CodeObsPath,
  525. BootFileUrl: req.BootFileUrl,
  526. DataUrl: req.DataUrl,
  527. TrainUrl: req.TrainUrl,
  528. LogUrl: req.LogUrl,
  529. PoolID: req.PoolID,
  530. CreateVersion: true,
  531. Flavor: models.Flavor{
  532. Code: req.FlavorCode,
  533. },
  534. Parameter: req.Parameters,
  535. UserImageUrl: req.UserImageUrl,
  536. UserCommand: req.UserCommand,
  537. },
  538. })
  539. if err != nil {
  540. log.Error("CreateJob failed: %v", err.Error())
  541. return err
  542. }
  543. var jobTypes []string
  544. jobTypes = append(jobTypes, string(models.JobTypeTrain))
  545. repo := ctx.Repo.Repository
  546. VersionTaskList, VersionListCount, err := models.CloudbrainsVersionList(&models.CloudbrainsOptions{
  547. RepoID: repo.ID,
  548. Type: models.TypeCloudBrainTwo,
  549. JobTypes: jobTypes,
  550. JobID: strconv.FormatInt(jobResult.JobID, 10),
  551. })
  552. if err != nil {
  553. ctx.ServerError("Cloudbrain", err)
  554. return err
  555. }
  556. //将当前版本的isLatestVersion设置为"1"和任务数量更新,任务数量包括当前版本数VersionCount和历史创建的总版本数TotalVersionCount
  557. err = models.CreateCloudbrain(&models.Cloudbrain{
  558. Status: TransTrainJobStatus(jobResult.Status),
  559. UserID: ctx.User.ID,
  560. RepoID: ctx.Repo.Repository.ID,
  561. JobID: strconv.FormatInt(jobResult.JobID, 10),
  562. JobName: req.JobName,
  563. DisplayJobName: req.DisplayJobName,
  564. JobType: string(models.JobTypeTrain),
  565. Type: models.TypeCloudBrainTwo,
  566. VersionID: jobResult.VersionID,
  567. VersionName: jobResult.VersionName,
  568. Uuid: req.Uuid,
  569. DatasetName: req.DatasetName,
  570. CommitID: req.CommitID,
  571. IsLatestVersion: req.IsLatestVersion,
  572. PreVersionName: req.PreVersionName,
  573. ComputeResource: models.NPUResource,
  574. EngineID: MORDELART_USER_IMAGE_ENGINE_ID,
  575. Image: req.UserImageUrl,
  576. TrainUrl: req.TrainUrl,
  577. BranchName: req.BranchName,
  578. Parameters: req.Params,
  579. BootFile: req.BootFile,
  580. DataUrl: req.DataUrl,
  581. LogUrl: req.LogUrl,
  582. PreVersionId: req.PreVersionId,
  583. FlavorCode: req.FlavorCode,
  584. Description: req.Description,
  585. WorkServerNumber: req.WorkServerNumber,
  586. FlavorName: req.FlavorName,
  587. EngineName: req.EngineName,
  588. TotalVersionCount: VersionTaskList[0].TotalVersionCount + 1,
  589. VersionCount: VersionListCount + 1,
  590. CreatedUnix: createTime,
  591. UpdatedUnix: createTime,
  592. })
  593. if err != nil {
  594. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  595. return err
  596. }
  597. //将训练任务的上一版本的isLatestVersion设置为"0"
  598. err = models.SetVersionCountAndLatestVersion(strconv.FormatInt(jobResult.JobID, 10), VersionTaskList[0].VersionName, VersionCount, NotLatestVersion, TotalVersionCount)
  599. if err != nil {
  600. ctx.ServerError("Update IsLatestVersion failed", err)
  601. return err
  602. }
  603. return err
  604. }
  // TransTrainJobStatus maps a numeric train-job status code to its name.
  // Codes 0 through 21 have names; any other value is returned as its decimal
  // string.
  func TransTrainJobStatus(status int) string {
  	// The slice index is the status code.
  	names := []string{
  		"UNKNOWN",
  		"INIT",
  		"IMAGE_CREATING",
  		"IMAGE_FAILED",
  		"SUBMIT_TRYING",
  		"SUBMIT_FAILED",
  		"DELETE_FAILED",
  		"WAITING",
  		"RUNNING",
  		"KILLING",
  		"COMPLETED",
  		"FAILED",
  		"KILLED",
  		"CANCELED",
  		"LOST",
  		"SCALING",
  		"SUBMIT_MODEL_FAILED",
  		"DEPLOY_SERVICE_FAILED",
  		"CHECK_INIT",
  		"CHECK_RUNNING",
  		"CHECK_RUNNING_COMPLETED",
  		"CHECK_FAILED",
  	}
  	if status < 0 || status >= len(names) {
  		return strconv.Itoa(status)
  	}
  	return names[status]
  }
  // GetOutputPathByCount returns the version directory name for
  // TotalVersionCount: the letter "V" followed by the count zero-padded to at
  // least four characters (for example 1 -> "V0001").
  func GetOutputPathByCount(TotalVersionCount int) (VersionOutputPath string) {
  	return fmt.Sprintf("V%04d", TotalVersionCount)
  }
  660. func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) {
  661. createTime := timeutil.TimeStampNow()
  662. jobResult, err := createInferenceJob(models.CreateInferenceJobParams{
  663. JobName: req.JobName,
  664. Description: req.Description,
  665. InfConfig: models.InfConfig{
  666. WorkServerNum: req.WorkServerNumber,
  667. AppUrl: req.CodeObsPath,
  668. BootFileUrl: req.BootFileUrl,
  669. DataUrl: req.DataUrl,
  670. EngineID: req.EngineID,
  671. // TrainUrl: req.TrainUrl,
  672. LogUrl: req.LogUrl,
  673. PoolID: req.PoolID,
  674. CreateVersion: true,
  675. Flavor: models.Flavor{
  676. Code: req.FlavorCode,
  677. },
  678. Parameter: req.Parameters,
  679. },
  680. })
  681. if err != nil {
  682. log.Error("CreateJob failed: %v", err.Error())
  683. return err
  684. }
  685. attach, err := models.GetAttachmentByUUID(req.Uuid)
  686. if err != nil {
  687. log.Error("GetAttachmentByUUID(%s) failed:%v", strconv.FormatInt(jobResult.JobID, 10), err.Error())
  688. return err
  689. }
  690. jobID := strconv.FormatInt(jobResult.JobID, 10)
  691. err = models.CreateCloudbrain(&models.Cloudbrain{
  692. Status: TransTrainJobStatus(jobResult.Status),
  693. UserID: ctx.User.ID,
  694. RepoID: ctx.Repo.Repository.ID,
  695. JobID: jobID,
  696. JobName: req.JobName,
  697. DisplayJobName: req.DisplayJobName,
  698. JobType: string(models.JobTypeInference),
  699. Type: models.TypeCloudBrainTwo,
  700. VersionID: jobResult.VersionID,
  701. VersionName: jobResult.VersionName,
  702. Uuid: req.Uuid,
  703. DatasetName: attach.Name,
  704. CommitID: req.CommitID,
  705. EngineID: req.EngineID,
  706. TrainUrl: req.TrainUrl,
  707. BranchName: req.BranchName,
  708. Parameters: req.Params,
  709. BootFile: req.BootFile,
  710. DataUrl: req.DataUrl,
  711. LogUrl: req.LogUrl,
  712. FlavorCode: req.FlavorCode,
  713. Description: req.Description,
  714. WorkServerNumber: req.WorkServerNumber,
  715. FlavorName: req.FlavorName,
  716. EngineName: req.EngineName,
  717. LabelName: req.LabelName,
  718. IsLatestVersion: req.IsLatestVersion,
  719. ComputeResource: models.NPUResource,
  720. VersionCount: req.VersionCount,
  721. TotalVersionCount: req.TotalVersionCount,
  722. ModelName: req.ModelName,
  723. ModelVersion: req.ModelVersion,
  724. CkptName: req.CkptName,
  725. ResultUrl: req.ResultUrl,
  726. CreatedUnix: createTime,
  727. UpdatedUnix: createTime,
  728. })
  729. if err != nil {
  730. log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error())
  731. return err
  732. }
  733. notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, jobID, req.DisplayJobName, models.ActionCreateInferenceTask)
  734. return nil
  735. }
  736. func GetNotebookImageName(imageId string) (string, error) {
  737. var validImage = false
  738. var imageName = ""
  739. if ImageInfos == nil {
  740. json.Unmarshal([]byte(setting.ImageInfos), &ImageInfos)
  741. }
  742. for _, imageInfo := range ImageInfos.ImageInfo {
  743. if imageInfo.Id == imageId {
  744. validImage = true
  745. imageName = imageInfo.Value
  746. }
  747. }
  748. if !validImage {
  749. log.Error("the image id(%s) is invalid", imageId)
  750. return imageName, errors.New("the image id is invalid")
  751. }
  752. return imageName, nil
  753. }
  754. func InitSpecialPool() {
  755. if SpecialPools == nil && setting.ModelArtsSpecialPools != "" {
  756. json.Unmarshal([]byte(setting.ModelArtsSpecialPools), &SpecialPools)
  757. }
  758. }