You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_convert.go 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. package repo
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "os"
  10. "strings"
  11. "code.gitea.io/gitea/models"
  12. "code.gitea.io/gitea/modules/cloudbrain"
  13. "code.gitea.io/gitea/modules/context"
  14. "code.gitea.io/gitea/modules/git"
  15. "code.gitea.io/gitea/modules/log"
  16. "code.gitea.io/gitea/modules/modelarts"
  17. "code.gitea.io/gitea/modules/setting"
  18. "code.gitea.io/gitea/modules/storage"
  19. "code.gitea.io/gitea/modules/timeutil"
  20. uuid "github.com/satori/go.uuid"
  21. )
  // Page templates rendered by the model-conversion handlers.
const (
	tplModelManageConvertIndex = "repo/modelmanage/convertIndex"
	tplModelConvertInfo        = "repo/modelmanage/convertshowinfo"
)

// Source-framework identifiers compared against AiModelConvert.SrcEngine
// (iota yields 0, 1, 2). MINDSPORE_ENGIN keeps its historical spelling because
// it is an exported name.
const (
	PYTORCH_ENGINE = iota
	TENSORFLOW_ENGINE
	MINDSPORE_ENGIN
)

// Mount points and log file name used inside the GPU conversion container.
const (
	ModelMountPath   = "/model"
	CodeMountPath    = "/code"
	DataSetMountPath = "/dataset"
	LogFile          = "log.txt"
)

// Where the conversion scripts are cloned from, and their entry points.
const (
	ConvertRepoPath       = "https://git.openi.org.cn/zouap/npu_test"
	DefaultBranchName     = "master"
	PytorchBootFile       = "convert_pytorch.py"
	MindsporeBootFile     = "convert_mindspore.py"
	TensorFlowNpuBootFile = "convert_tensorflow.py"
	TensorFlowGpuBootFile = "convert_tensorflow_gpu.py"
)

// Cloudbrain (GPU) job settings; Success is the result code CreateJob reports
// on success.
const (
	SubTaskName          = "task1"
	GpuQueue             = "openidgx"
	Success              = "S000"
	GPU_PYTORCH_IMAGE    = "dockerhub.pcl.ac.cn:5000/user-images/openi:tensorRT_7_zouap"
	GPU_TENSORFLOW_IMAGE = "dockerhub.pcl.ac.cn:5000/user-images/openi:tf2onnx"
)

// ModelArts (NPU) job settings.
const (
	NPU_MINDSPORE_IMAGE_ID  = 122
	NPU_TENSORFLOW_IMAGE_ID = 121
	NPU_FlavorCode          = "modelarts.bm.910.arm.public.1"
	NPU_PoolID              = "pool7908321a"
)

// REPO_ID is not referenced by the code in this file.
// NOTE(review): purpose not visible here — confirm with other users of the package.
const REPO_ID = 33267
  49. var (
  50. TrainResourceSpecs *models.ResourceSpecs
  51. )
  52. func SaveModelConvert(ctx *context.Context) {
  53. log.Info("save model convert start.")
  54. if !ctx.Repo.CanWrite(models.UnitTypeModelManage) {
  55. ctx.JSON(403, ctx.Tr("repo.model_noright"))
  56. return
  57. }
  58. name := ctx.Query("name")
  59. desc := ctx.Query("desc")
  60. modelId := ctx.Query("modelId")
  61. modelPath := ctx.Query("ModelFile")
  62. SrcEngine := ctx.QueryInt("SrcEngine")
  63. InputShape := ctx.Query("inputshape")
  64. InputDataFormat := ctx.Query("inputdataformat")
  65. DestFormat := ctx.QueryInt("DestFormat")
  66. NetOutputFormat := ctx.QueryInt("NetOutputFormat")
  67. task, err := models.QueryModelById(modelId)
  68. if err != nil {
  69. log.Error("no such model!", err.Error())
  70. ctx.ServerError("no such model:", err)
  71. return
  72. }
  73. uuid := uuid.NewV4()
  74. id := uuid.String()
  75. modelConvert := &models.AiModelConvert{
  76. ID: id,
  77. Name: name,
  78. Description: desc,
  79. Status: string(models.JobWaiting),
  80. SrcEngine: SrcEngine,
  81. RepoId: ctx.Repo.Repository.ID,
  82. ModelName: task.Name,
  83. ModelVersion: task.Version,
  84. ModelId: modelId,
  85. ModelPath: modelPath,
  86. DestFormat: DestFormat,
  87. NetOutputFormat: NetOutputFormat,
  88. InputShape: InputShape,
  89. InputDataFormat: InputDataFormat,
  90. UserId: ctx.User.ID,
  91. }
  92. models.SaveModelConvert(modelConvert)
  93. if modelConvert.IsGpuTrainTask() {
  94. log.Info("create gpu train job.")
  95. err = createGpuTrainJob(modelConvert, ctx, task)
  96. } else {
  97. //create npu job
  98. log.Info("create npu train job.")
  99. createNpuTrainJob(modelConvert, ctx, task.Path)
  100. }
  101. if err == nil {
  102. ctx.JSON(200, map[string]string{
  103. "result_code": "0",
  104. })
  105. } else {
  106. ctx.JSON(200, map[string]string{
  107. "result_code": "1",
  108. })
  109. }
  110. }
  111. func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, modelRelativePath string) {
  112. VersionOutputPath := "V0001"
  113. codeLocalPath := setting.JobPath + modelConvert.ID + modelarts.CodePath
  114. codeObsPath := "/" + setting.Bucket + modelarts.JobPath + modelConvert.ID + modelarts.CodePath
  115. outputObsPath := "/" + setting.Bucket + modelarts.JobPath + modelConvert.ID + modelarts.OutputPath + VersionOutputPath + "/"
  116. logObsPath := "/" + setting.Bucket + modelarts.JobPath + modelConvert.ID + modelarts.LogPath + VersionOutputPath + "/"
  117. dataPath := "/" + modelRelativePath
  118. _, err := ioutil.ReadDir(codeLocalPath)
  119. if err == nil {
  120. os.RemoveAll(codeLocalPath)
  121. }
  122. if err := downloadConvertCode(ConvertRepoPath, codeLocalPath, DefaultBranchName); err != nil {
  123. log.Error("downloadCode failed, server timed out: %s (%v)", ConvertRepoPath, err)
  124. return
  125. }
  126. if err := obsMkdir(setting.CodePathPrefix + modelConvert.ID + modelarts.OutputPath + VersionOutputPath + "/"); err != nil {
  127. log.Error("Failed to obsMkdir_output: %s (%v)", modelConvert.ID+modelarts.OutputPath, err)
  128. return
  129. }
  130. if err := obsMkdir(setting.CodePathPrefix + modelConvert.ID + modelarts.LogPath + VersionOutputPath + "/"); err != nil {
  131. log.Error("Failed to obsMkdir_log: %s (%v)", modelConvert.ID+modelarts.LogPath, err)
  132. return
  133. }
  134. if err := uploadCodeToObs(codeLocalPath, modelConvert.ID, ""); err != nil {
  135. log.Error("Failed to uploadCodeToObs: %s (%v)", modelConvert.ID, err)
  136. return
  137. }
  138. intputshape := strings.Split(modelConvert.InputShape, ",")
  139. n := "256"
  140. c := "1"
  141. h := "28"
  142. w := "28"
  143. if len(intputshape) == 4 {
  144. n = intputshape[0]
  145. c = intputshape[1]
  146. h = intputshape[2]
  147. w = intputshape[3]
  148. }
  149. param := make([]models.Parameter, 0)
  150. modelPara := models.Parameter{
  151. Label: "model",
  152. Value: modelConvert.ModelPath,
  153. }
  154. param = append(param, modelPara)
  155. batchSizePara := models.Parameter{
  156. Label: "n",
  157. Value: fmt.Sprint(n),
  158. }
  159. param = append(param, batchSizePara)
  160. channelSizePara := models.Parameter{
  161. Label: "c",
  162. Value: fmt.Sprint(c),
  163. }
  164. param = append(param, channelSizePara)
  165. heightPara := models.Parameter{
  166. Label: "h",
  167. Value: fmt.Sprint(h),
  168. }
  169. param = append(param, heightPara)
  170. widthPara := models.Parameter{
  171. Label: "w",
  172. Value: fmt.Sprint(w),
  173. }
  174. param = append(param, widthPara)
  175. var engineId int64
  176. engineId = int64(NPU_MINDSPORE_IMAGE_ID)
  177. bootfile := MindsporeBootFile
  178. if modelConvert.SrcEngine == TENSORFLOW_ENGINE {
  179. engineId = int64(NPU_TENSORFLOW_IMAGE_ID)
  180. bootfile = TensorFlowNpuBootFile
  181. }
  182. req := &modelarts.GenerateTrainJobReq{
  183. JobName: modelConvert.ID,
  184. DisplayJobName: modelConvert.Name,
  185. DataUrl: dataPath,
  186. Description: modelConvert.Description,
  187. CodeObsPath: codeObsPath,
  188. BootFileUrl: codeObsPath + bootfile,
  189. BootFile: bootfile,
  190. TrainUrl: outputObsPath,
  191. FlavorCode: NPU_FlavorCode,
  192. WorkServerNumber: 1,
  193. IsLatestVersion: modelarts.IsLatestVersion,
  194. EngineID: engineId,
  195. LogUrl: logObsPath,
  196. PoolID: NPU_PoolID,
  197. Parameters: param,
  198. BranchName: DefaultBranchName,
  199. }
  200. result, err := modelarts.GenerateModelConvertTrainJob(req)
  201. log.Info("jobId=" + fmt.Sprint(result.JobID) + " versionid=" + fmt.Sprint(result.VersionID))
  202. models.UpdateModelConvertModelArts(modelConvert.ID, fmt.Sprint(result.JobID), fmt.Sprint(result.VersionID))
  203. }
  204. func downloadConvertCode(repopath string, codePath, branchName string) error {
  205. //add "file:///" prefix to make the depth valid
  206. if err := git.Clone(repopath, codePath, git.CloneRepoOptions{Branch: branchName, Depth: 1}); err != nil {
  207. log.Error("Failed to clone repository: %s (%v)", repopath, err)
  208. return err
  209. }
  210. log.Info("srcPath=" + repopath + " codePath=" + codePath)
  211. configFile, err := os.OpenFile(codePath+"/.git/config", os.O_RDWR, 0666)
  212. if err != nil {
  213. log.Error("open file(%s) failed:%v", codePath+"/,git/config", err)
  214. return err
  215. }
  216. defer configFile.Close()
  217. pos := int64(0)
  218. reader := bufio.NewReader(configFile)
  219. for {
  220. line, err := reader.ReadString('\n')
  221. if err != nil {
  222. if err == io.EOF {
  223. log.Error("not find the remote-url")
  224. return nil
  225. } else {
  226. log.Error("read error: %v", err)
  227. return err
  228. }
  229. }
  230. if strings.Contains(line, "url") && strings.Contains(line, ".git") {
  231. originUrl := "\turl = " + repopath + "\n"
  232. if len(line) > len(originUrl) {
  233. originUrl += strings.Repeat(" ", len(line)-len(originUrl))
  234. }
  235. bytes := []byte(originUrl)
  236. _, err := configFile.WriteAt(bytes, pos)
  237. if err != nil {
  238. log.Error("WriteAt failed:%v", err)
  239. return err
  240. }
  241. break
  242. }
  243. pos += int64(len(line))
  244. }
  245. return nil
  246. }
  247. func downloadFromObsToLocal(task *models.AiModelManage, localPath string) error {
  248. path := Model_prefix + models.AttachmentRelativePath(task.ID) + "/"
  249. allFile, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, path)
  250. if err == nil {
  251. _, errState := os.Stat(localPath)
  252. if errState != nil {
  253. if err = os.MkdirAll(localPath, os.ModePerm); err != nil {
  254. return err
  255. }
  256. }
  257. for _, oneFile := range allFile {
  258. if oneFile.IsDir {
  259. log.Info(" dir name:" + oneFile.FileName)
  260. } else {
  261. allFileName := localPath + "/" + oneFile.FileName
  262. index := strings.LastIndex(allFileName, "/")
  263. if index != -1 {
  264. parentDir := allFileName[0:index]
  265. if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
  266. log.Info("make dir may be error," + err.Error())
  267. }
  268. }
  269. fDest, err := os.Create(allFileName)
  270. if err != nil {
  271. log.Info("create file error, download file failed: %s\n", err.Error())
  272. return err
  273. }
  274. body, err := storage.ObsDownloadAFile(setting.Bucket, path+oneFile.FileName)
  275. if err != nil {
  276. log.Info("download file failed: %s\n", err.Error())
  277. return err
  278. } else {
  279. defer body.Close()
  280. p := make([]byte, 1024)
  281. var readErr error
  282. var readCount int
  283. // 读取对象内容
  284. for {
  285. readCount, readErr = body.Read(p)
  286. if readCount > 0 {
  287. fDest.Write(p[:readCount])
  288. }
  289. if readErr != nil {
  290. break
  291. }
  292. }
  293. }
  294. }
  295. }
  296. } else {
  297. log.Info("error,msg=" + err.Error())
  298. return err
  299. }
  300. return nil
  301. }
  302. func createGpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, model *models.AiModelManage) error {
  303. modelRelativePath := model.Path
  304. command := ""
  305. IMAGE_URL := GPU_PYTORCH_IMAGE
  306. dataActualPath := setting.Attachment.Minio.RealPath + modelRelativePath
  307. if modelConvert.SrcEngine == PYTORCH_ENGINE {
  308. command = getGpuModelConvertCommand(modelConvert.ID, modelConvert.ModelPath, modelConvert, PytorchBootFile)
  309. } else if modelConvert.SrcEngine == TENSORFLOW_ENGINE {
  310. IMAGE_URL = GPU_TENSORFLOW_IMAGE
  311. command = getGpuModelConvertCommand(modelConvert.ID, modelConvert.ModelPath, modelConvert, TensorFlowGpuBootFile)
  312. //如果模型在OBS上,需要下载到本地,并上传到minio中
  313. if model.Type == models.TypeCloudBrainTwo {
  314. relatetiveModelPath := setting.JobPath + modelConvert.ID + "/dataset"
  315. log.Info("local dataset path:" + relatetiveModelPath)
  316. downloadFromObsToLocal(model, relatetiveModelPath)
  317. uploadCodeToMinio(relatetiveModelPath+"/", modelConvert.ID, "/dataset/")
  318. dataActualPath = setting.Attachment.Minio.RealPath + setting.Attachment.Minio.Bucket + "/" + setting.CBCodePathPrefix + modelConvert.ID + "/dataset"
  319. }
  320. }
  321. log.Info("dataActualPath=" + dataActualPath)
  322. log.Info("command=" + command)
  323. codePath := setting.JobPath + modelConvert.ID + CodeMountPath
  324. downloadConvertCode(ConvertRepoPath, codePath, DefaultBranchName)
  325. uploadCodeToMinio(codePath+"/", modelConvert.ID, CodeMountPath+"/")
  326. minioCodePath := setting.Attachment.Minio.RealPath + setting.Attachment.Minio.Bucket + "/" + setting.CBCodePathPrefix + modelConvert.ID + "/code"
  327. log.Info("minio codePath=" + minioCodePath)
  328. modelPath := setting.JobPath + modelConvert.ID + ModelMountPath + "/"
  329. log.Info("local modelPath=" + modelPath)
  330. mkModelPath(modelPath)
  331. uploadCodeToMinio(modelPath, modelConvert.ID, ModelMountPath+"/")
  332. minioModelPath := setting.Attachment.Minio.RealPath + setting.Attachment.Minio.Bucket + "/" + setting.CBCodePathPrefix + modelConvert.ID + "/model"
  333. log.Info("minio model path=" + minioModelPath)
  334. if TrainResourceSpecs == nil {
  335. json.Unmarshal([]byte(setting.TrainResourceSpecs), &TrainResourceSpecs)
  336. }
  337. resourceSpec := TrainResourceSpecs.ResourceSpec[1]
  338. jobResult, err := cloudbrain.CreateJob(modelConvert.ID, models.CreateJobParams{
  339. JobName: modelConvert.ID,
  340. RetryCount: 1,
  341. GpuType: GpuQueue,
  342. Image: IMAGE_URL,
  343. TaskRoles: []models.TaskRole{
  344. {
  345. Name: SubTaskName,
  346. TaskNumber: 1,
  347. MinSucceededTaskCount: 1,
  348. MinFailedTaskCount: 1,
  349. CPUNumber: resourceSpec.CpuNum,
  350. GPUNumber: resourceSpec.GpuNum,
  351. MemoryMB: resourceSpec.MemMiB,
  352. ShmMB: resourceSpec.ShareMemMiB,
  353. Command: command,
  354. NeedIBDevice: false,
  355. IsMainRole: false,
  356. UseNNI: false,
  357. },
  358. },
  359. Volumes: []models.Volume{
  360. {
  361. HostPath: models.StHostPath{
  362. Path: minioCodePath,
  363. MountPath: CodeMountPath,
  364. ReadOnly: false,
  365. },
  366. },
  367. {
  368. HostPath: models.StHostPath{
  369. Path: dataActualPath,
  370. MountPath: DataSetMountPath,
  371. ReadOnly: true,
  372. },
  373. },
  374. {
  375. HostPath: models.StHostPath{
  376. Path: minioModelPath,
  377. MountPath: ModelMountPath,
  378. ReadOnly: false,
  379. },
  380. },
  381. },
  382. })
  383. if err != nil {
  384. log.Error("CreateJob failed:", err.Error(), ctx.Data["MsgID"])
  385. return err
  386. }
  387. if jobResult.Code != Success {
  388. log.Error("CreateJob(%s) failed:%s", modelConvert.ID, jobResult.Msg, ctx.Data["MsgID"])
  389. return errors.New(jobResult.Msg)
  390. }
  391. var jobID = jobResult.Payload["jobId"].(string)
  392. log.Info("jobId=" + jobID)
  393. models.UpdateModelConvertCBTI(modelConvert.ID, jobID)
  394. return nil
  395. }
  396. func getGpuModelConvertCommand(name string, modelFile string, modelConvert *models.AiModelConvert, bootfile string) string {
  397. var command string
  398. intputshape := strings.Split(modelConvert.InputShape, ",")
  399. n := "256"
  400. c := "1"
  401. h := "28"
  402. w := "28"
  403. if len(intputshape) == 4 {
  404. n = intputshape[0]
  405. c = intputshape[1]
  406. h = intputshape[2]
  407. w = intputshape[3]
  408. }
  409. command += "python3 /code/" + bootfile + " --model " + modelFile + " --n " + n + " --c " + c + " --h " + h + " --w " + w + " > " + ModelMountPath + "/" + name + "-" + LogFile
  410. return command
  411. }
  412. func DeleteModelConvert(ctx *context.Context) {
  413. log.Info("delete model convert start.")
  414. id := ctx.Params(":id")
  415. err := models.DeleteModelConvertById(id)
  416. if err != nil {
  417. ctx.JSON(500, err.Error())
  418. } else {
  419. ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model")
  420. }
  421. }
  422. func StopModelConvert(ctx *context.Context) {
  423. id := ctx.Params(":id")
  424. log.Info("stop model convert start.id=" + id)
  425. job, err := models.QueryModelConvertById(ctx.Query("ID"))
  426. if err != nil {
  427. ctx.ServerError("Not found task.", err)
  428. return
  429. }
  430. if job.IsGpuTrainTask() {
  431. err = cloudbrain.StopJob(job.CloudBrainTaskId)
  432. if err != nil {
  433. log.Error("Stop cloudbrain Job(%s) failed:%v", job.CloudBrainTaskId, err)
  434. }
  435. } else {
  436. _, err = modelarts.StopTrainJob(job.CloudBrainTaskId, job.ModelArtsVersionId)
  437. if err != nil {
  438. log.Error("Stop modelarts Job(%s) failed:%v", job.CloudBrainTaskId, err)
  439. }
  440. }
  441. job.Status = string(models.JobStopped)
  442. if job.EndTime == 0 {
  443. job.EndTime = timeutil.TimeStampNow()
  444. }
  445. models.ModelConvertSetDuration(job)
  446. err = models.UpdateModelConvert(job)
  447. if err != nil {
  448. log.Error("UpdateModelConvert failed:", err)
  449. }
  450. ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model")
  451. }
  452. func ShowModelConvertInfo(ctx *context.Context) {
  453. ctx.Data["ID"] = ctx.Query("ID")
  454. ctx.Data["isModelManage"] = true
  455. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  456. job, err := models.QueryModelConvertById(ctx.Query("ID"))
  457. if err == nil {
  458. ctx.Data["task"] = job
  459. } else {
  460. ctx.ServerError("Not found task.", err)
  461. return
  462. }
  463. ctx.Data["canDownload"] = isOper(ctx, job.UserId)
  464. user, err := models.GetUserByID(job.UserId)
  465. if err == nil {
  466. job.UserName = user.Name
  467. job.UserRelAvatarLink = user.RelAvatarLink()
  468. }
  469. if job.IsGpuTrainTask() {
  470. ctx.Data["npu_display"] = "none"
  471. ctx.Data["gpu_display"] = "block"
  472. result, err := cloudbrain.GetJob(job.CloudBrainTaskId)
  473. if err != nil {
  474. log.Info("error:" + err.Error())
  475. ctx.Data["error"] = err.Error()
  476. return
  477. }
  478. if result != nil {
  479. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  480. ctx.Data["result"] = jobRes
  481. taskRoles := jobRes.TaskRoles
  482. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  483. ctx.Data["taskRes"] = taskRes
  484. ctx.Data["ExitDiagnostics"] = taskRes.TaskStatuses[0].ExitDiagnostics
  485. ctx.Data["AppExitDiagnostics"] = jobRes.JobStatus.AppExitDiagnostics
  486. job.Status = jobRes.JobStatus.State
  487. if jobRes.JobStatus.State != string(models.JobWaiting) && jobRes.JobStatus.State != string(models.JobFailed) {
  488. job.ContainerIp = taskRes.TaskStatuses[0].ContainerIP
  489. job.ContainerID = taskRes.TaskStatuses[0].ContainerID
  490. job.Status = taskRes.TaskStatuses[0].State
  491. }
  492. if jobRes.JobStatus.State != string(models.JobWaiting) {
  493. models.ModelComputeAndSetDuration(job, jobRes)
  494. err = models.UpdateModelConvert(job)
  495. if err != nil {
  496. log.Error("UpdateModelConvert failed:", err)
  497. }
  498. }
  499. }
  500. } else {
  501. ctx.Data["npu_display"] = "block"
  502. ctx.Data["gpu_display"] = "none"
  503. ctx.Data["ExitDiagnostics"] = ""
  504. ctx.Data["AppExitDiagnostics"] = ""
  505. }
  506. ctx.HTML(200, tplModelConvertInfo)
  507. }
  508. func ConvertModelTemplate(ctx *context.Context) {
  509. ctx.Data["isModelManage"] = true
  510. ctx.Data["MODEL_COUNT"] = 0
  511. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  512. ctx.Data["TRAIN_COUNT"] = 0
  513. ShowModelConvertPageInfo(ctx)
  514. ctx.HTML(200, tplModelManageConvertIndex)
  515. }
  516. func ShowModelConvertPageInfo(ctx *context.Context) {
  517. log.Info("ShowModelConvertInfo start.")
  518. if !isQueryRight(ctx) {
  519. log.Info("no right.")
  520. ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
  521. return
  522. }
  523. page := ctx.QueryInt("page")
  524. if page <= 0 {
  525. page = 1
  526. }
  527. pageSize := ctx.QueryInt("pageSize")
  528. if pageSize <= 0 {
  529. pageSize = setting.UI.IssuePagingNum
  530. }
  531. repoId := ctx.Repo.Repository.ID
  532. modelResult, count, err := models.QueryModelConvert(&models.AiModelQueryOptions{
  533. ListOptions: models.ListOptions{
  534. Page: page,
  535. PageSize: pageSize,
  536. },
  537. RepoID: repoId,
  538. })
  539. if err != nil {
  540. log.Info("query db error." + err.Error())
  541. ctx.ServerError("Cloudbrain", err)
  542. return
  543. }
  544. userIds := make([]int64, len(modelResult))
  545. for i, model := range modelResult {
  546. model.IsCanOper = isOper(ctx, model.UserId)
  547. model.IsCanDelete = isCanDelete(ctx, model.UserId)
  548. userIds[i] = model.UserId
  549. }
  550. userNameMap := queryUserName(userIds)
  551. for _, model := range modelResult {
  552. value := userNameMap[model.UserId]
  553. if value != nil {
  554. model.UserName = value.Name
  555. model.UserRelAvatarLink = value.RelAvatarLink()
  556. }
  557. }
  558. pager := context.NewPagination(int(count), page, pageSize, 5)
  559. ctx.Data["Page"] = pager
  560. ctx.Data["Tasks"] = modelResult
  561. }
  562. func ModelConvertDownloadModel(ctx *context.Context) {
  563. log.Info("enter here......")
  564. id := ctx.Params(":id")
  565. job, err := models.QueryModelConvertById(id)
  566. if err != nil {
  567. ctx.ServerError("Not found task.", err)
  568. return
  569. }
  570. AllDownload := ctx.QueryBool("AllDownload")
  571. if AllDownload {
  572. if job.IsGpuTrainTask() {
  573. path := setting.CBCodePathPrefix + job.ID + "/model"
  574. allFile, err := storage.GetAllObjectByBucketAndPrefixMinio(setting.Attachment.Minio.Bucket, path)
  575. if err == nil {
  576. returnFileName := job.Name + ".zip"
  577. MinioDownloadManyFile(path, ctx, returnFileName, allFile)
  578. } else {
  579. log.Info("error,msg=" + err.Error())
  580. ctx.ServerError("no file to download.", err)
  581. }
  582. } else {
  583. }
  584. } else {
  585. if job.IsGpuTrainTask() {
  586. parentDir := ctx.Query("parentDir")
  587. fileName := ctx.Query("fileName")
  588. jobName := ctx.Query("jobName")
  589. filePath := "jobs/" + jobName + "/model/" + parentDir
  590. url, err := storage.Attachments.PresignedGetURL(filePath, fileName)
  591. if err != nil {
  592. log.Error("PresignedGetURL failed: %v", err.Error(), ctx.Data["msgID"])
  593. ctx.ServerError("PresignedGetURL", err)
  594. return
  595. }
  596. ctx.JSON(200, url)
  597. //http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
  598. } else {
  599. }
  600. }
  601. }