You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_convert.go 8.8 kB


  1. package repo
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "code.gitea.io/gitea/models"
  6. "code.gitea.io/gitea/modules/cloudbrain"
  7. "code.gitea.io/gitea/modules/context"
  8. "code.gitea.io/gitea/modules/log"
  9. "code.gitea.io/gitea/modules/setting"
  10. uuid "github.com/satori/go.uuid"
  11. )
  12. const (
  13. tplModelManageConvertIndex = "repo/modelmanage/convertIndex"
  14. tplModelConvertInfo = "repo/modelmanage/convertshowinfo"
  15. PYTORCH_ENGINE = 0
  16. TENSORFLOW_ENGINE = 1
  17. MINDSPORE_ENGIN = 2
  18. ModelMountPath = "/model"
  19. CodeMountPath = "/code"
  20. DataSetMountPath = "/dataset"
  21. LogFile = "log.txt"
  22. DefaultBranchName = "master"
  23. SubTaskName = "task1"
  24. GpuQueue = "openidebug"
  25. Success = "S000"
  26. GPU_PYTORCH_IMAGE = "dockerhub.pcl.ac.cn:5000/user-images/openi:tensorRT_7_zouap"
  27. )
  28. var (
  29. TrainResourceSpecs *models.ResourceSpecs
  30. )
  31. func SaveModelConvert(ctx *context.Context) {
  32. log.Info("save model convert start.")
  33. if !ctx.Repo.CanWrite(models.UnitTypeModelManage) {
  34. ctx.JSON(403, ctx.Tr("repo.model_noright"))
  35. return
  36. }
  37. name := ctx.Query("name")
  38. desc := ctx.Query("desc")
  39. modelId := ctx.Query("modelId")
  40. modelPath := ctx.Query("ModelFile")
  41. SrcEngine := ctx.QueryInt("SrcEngine")
  42. InputShape := ctx.Query("inputshape")
  43. InputDataFormat := ctx.Query("inputdataformat")
  44. DestFormat := ctx.QueryInt("DestFormat")
  45. NetOutputFormat := ctx.QueryInt("NetOutputFormat")
  46. task, err := models.QueryModelById(modelId)
  47. if err != nil {
  48. log.Error("no such model!", err.Error())
  49. ctx.ServerError("no such model:", err)
  50. return
  51. }
  52. uuid := uuid.NewV4()
  53. id := uuid.String()
  54. modelConvert := &models.AiModelConvert{
  55. ID: id,
  56. Name: name,
  57. Description: desc,
  58. Status: string(models.JobWaiting),
  59. SrcEngine: SrcEngine,
  60. RepoId: ctx.Repo.Repository.ID,
  61. ModelName: task.Name,
  62. ModelVersion: task.Version,
  63. ModelId: modelId,
  64. ModelPath: modelPath,
  65. DestFormat: DestFormat,
  66. NetOutputFormat: NetOutputFormat,
  67. InputShape: InputShape,
  68. InputDataFormat: InputDataFormat,
  69. UserId: ctx.User.ID,
  70. }
  71. models.SaveModelConvert(modelConvert)
  72. err = createTrainJob(modelConvert, ctx, task.Path)
  73. if err == nil {
  74. ctx.JSON(200, map[string]string{
  75. "result_code": "0",
  76. })
  77. } else {
  78. ctx.JSON(200, map[string]string{
  79. "result_code": "1",
  80. })
  81. }
  82. }
  83. func createTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, modelRelativePath string) error {
  84. repo, _ := models.GetRepositoryByID(ctx.Repo.Repository.ID)
  85. if modelConvert.SrcEngine == PYTORCH_ENGINE {
  86. codePath := setting.JobPath + modelConvert.ID + CodeMountPath
  87. log.Info("Volume codePath=" + codePath)
  88. downloadCode(repo, codePath, DefaultBranchName)
  89. uploadCodeToMinio(codePath+"/", modelConvert.ID, CodeMountPath+"/")
  90. log.Info("minio code path=" + setting.CBCodePathPrefix + modelConvert.ID)
  91. modelPath := setting.JobPath + modelConvert.ID + ModelMountPath + "/"
  92. log.Info("modelPath=" + modelPath)
  93. mkModelPath(modelPath)
  94. uploadCodeToMinio(modelPath, modelConvert.ID, ModelMountPath+"/")
  95. command := getModelConvertCommand(modelConvert.ID, modelConvert.ModelPath)
  96. log.Info("command=" + command)
  97. dataActualPath := setting.Attachment.Minio.RealPath + modelRelativePath
  98. if TrainResourceSpecs == nil {
  99. json.Unmarshal([]byte(setting.TrainResourceSpecs), &TrainResourceSpecs)
  100. }
  101. resourceSpec := TrainResourceSpecs.ResourceSpec[1]
  102. jobResult, err := cloudbrain.CreateJob(modelConvert.ID, models.CreateJobParams{
  103. JobName: modelConvert.ID,
  104. RetryCount: 1,
  105. GpuType: GpuQueue,
  106. Image: GPU_PYTORCH_IMAGE,
  107. TaskRoles: []models.TaskRole{
  108. {
  109. Name: SubTaskName,
  110. TaskNumber: 1,
  111. MinSucceededTaskCount: 1,
  112. MinFailedTaskCount: 1,
  113. CPUNumber: resourceSpec.CpuNum,
  114. GPUNumber: resourceSpec.GpuNum,
  115. MemoryMB: resourceSpec.MemMiB,
  116. ShmMB: resourceSpec.ShareMemMiB,
  117. Command: command,
  118. NeedIBDevice: false,
  119. IsMainRole: false,
  120. UseNNI: false,
  121. },
  122. },
  123. Volumes: []models.Volume{
  124. {
  125. HostPath: models.StHostPath{
  126. Path: codePath,
  127. MountPath: CodeMountPath,
  128. ReadOnly: false,
  129. },
  130. },
  131. {
  132. HostPath: models.StHostPath{
  133. Path: dataActualPath,
  134. MountPath: DataSetMountPath,
  135. ReadOnly: true,
  136. },
  137. },
  138. {
  139. HostPath: models.StHostPath{
  140. Path: modelPath,
  141. MountPath: ModelMountPath,
  142. ReadOnly: false,
  143. },
  144. },
  145. },
  146. })
  147. if err != nil {
  148. log.Error("CreateJob failed:", err.Error(), ctx.Data["MsgID"])
  149. return err
  150. }
  151. if jobResult.Code != Success {
  152. log.Error("CreateJob(%s) failed:%s", modelConvert.ID, jobResult.Msg, ctx.Data["MsgID"])
  153. return errors.New(jobResult.Msg)
  154. }
  155. var jobID = jobResult.Payload["jobId"].(string)
  156. log.Info("jobId=" + jobID)
  157. models.UpdateModelConvertCBTI(modelConvert.ID, jobID)
  158. }
  159. return nil
  160. }
  161. func getModelConvertCommand(name string, modelFile string) string {
  162. var command string
  163. bootFile := "convert_pytorch.py"
  164. command += "python3 /code/" + bootFile + " --model " + modelFile + " > " + ModelMountPath + "/" + name + "-" + LogFile
  165. return command
  166. }
  167. func DeleteModelConvert(ctx *context.Context) {
  168. log.Info("delete model convert start.")
  169. id := ctx.Params(":id")
  170. err := models.DeleteModelConvertById(id)
  171. if err != nil {
  172. ctx.JSON(500, err.Error())
  173. } else {
  174. ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model")
  175. }
  176. }
  177. func StopModelConvert(ctx *context.Context) {
  178. id := ctx.Params(":id")
  179. log.Info("stop model convert start.id=" + id)
  180. }
  181. func ShowModelConvertInfo(ctx *context.Context) {
  182. ctx.Data["ID"] = ctx.Query("ID")
  183. ctx.Data["isModelManage"] = true
  184. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  185. job, err := models.QueryModelConvertById(ctx.Query("ID"))
  186. if err == nil {
  187. ctx.Data["task"] = job
  188. }
  189. result, err := cloudbrain.GetJob(job.CloudBrainTaskId)
  190. if err != nil {
  191. log.Info("error:" + err.Error())
  192. ctx.Data["error"] = err.Error()
  193. return
  194. }
  195. if result != nil {
  196. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  197. ctx.Data["result"] = jobRes
  198. taskRoles := jobRes.TaskRoles
  199. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  200. ctx.Data["taskRes"] = taskRes
  201. ctx.Data["ExitDiagnostics"] = taskRes.TaskStatuses[0].ExitDiagnostics
  202. job.Status = jobRes.JobStatus.State
  203. if jobRes.JobStatus.State != string(models.JobWaiting) && jobRes.JobStatus.State != string(models.JobFailed) {
  204. job.ContainerIp = taskRes.TaskStatuses[0].ContainerIP
  205. job.ContainerID = taskRes.TaskStatuses[0].ContainerID
  206. job.Status = taskRes.TaskStatuses[0].State
  207. }
  208. if jobRes.JobStatus.State != string(models.JobWaiting) {
  209. models.ModelComputeAndSetDuration(job, jobRes)
  210. err = models.UpdateModelConvert(job)
  211. if err != nil {
  212. log.Error("UpdateModelConvert failed:", err)
  213. }
  214. }
  215. }
  216. ctx.HTML(200, tplModelConvertInfo)
  217. }
  218. func ConvertModelTemplate(ctx *context.Context) {
  219. ctx.Data["isModelManage"] = true
  220. ctx.Data["MODEL_COUNT"] = 0
  221. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  222. ctx.Data["TRAIN_COUNT"] = 0
  223. ShowModelConvertPageInfo(ctx)
  224. ctx.HTML(200, tplModelManageConvertIndex)
  225. }
  226. func ShowModelConvertPageInfo(ctx *context.Context) {
  227. log.Info("ShowModelConvertInfo start.")
  228. if !isQueryRight(ctx) {
  229. log.Info("no right.")
  230. ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
  231. return
  232. }
  233. page := ctx.QueryInt("page")
  234. if page <= 0 {
  235. page = 1
  236. }
  237. pageSize := ctx.QueryInt("pageSize")
  238. if pageSize <= 0 {
  239. pageSize = setting.UI.IssuePagingNum
  240. }
  241. repoId := ctx.Repo.Repository.ID
  242. modelResult, count, err := models.QueryModelConvert(&models.AiModelQueryOptions{
  243. ListOptions: models.ListOptions{
  244. Page: page,
  245. PageSize: pageSize,
  246. },
  247. RepoID: repoId,
  248. })
  249. if err != nil {
  250. log.Info("query db error." + err.Error())
  251. ctx.ServerError("Cloudbrain", err)
  252. return
  253. }
  254. userIds := make([]int64, len(modelResult))
  255. for i, model := range modelResult {
  256. model.IsCanOper = isOper(ctx, model.UserId)
  257. model.IsCanDelete = isCanDelete(ctx, model.UserId)
  258. userIds[i] = model.UserId
  259. }
  260. userNameMap := queryUserName(userIds)
  261. for _, model := range modelResult {
  262. value := userNameMap[model.UserId]
  263. if value != nil {
  264. model.UserName = value.Name
  265. model.UserRelAvatarLink = value.RelAvatarLink()
  266. }
  267. }
  268. pager := context.NewPagination(int(count), page, pageSize, 5)
  269. ctx.Data["Page"] = pager
  270. ctx.Data["Tasks"] = modelResult
  271. }