You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_convert.go 9.0 kB


  1. package repo
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "code.gitea.io/gitea/models"
  6. "code.gitea.io/gitea/modules/cloudbrain"
  7. "code.gitea.io/gitea/modules/context"
  8. "code.gitea.io/gitea/modules/log"
  9. "code.gitea.io/gitea/modules/setting"
  10. uuid "github.com/satori/go.uuid"
  11. )
  12. const (
  13. tplModelManageConvertIndex = "repo/modelmanage/convertIndex"
  14. tplModelConvertInfo = "repo/modelmanage/convertshowinfo"
  15. PYTORCH_ENGINE = 0
  16. TENSORFLOW_ENGINE = 1
  17. MINDSPORE_ENGIN = 2
  18. ModelMountPath = "/model"
  19. CodeMountPath = "/code"
  20. DataSetMountPath = "/dataset"
  21. LogFile = "log.txt"
  22. DefaultBranchName = "master"
  23. SubTaskName = "task1"
  24. GpuQueue = "openidgx"
  25. Success = "S000"
  26. GPU_PYTORCH_IMAGE = "dockerhub.pcl.ac.cn:5000/user-images/openi:tensorRT_7_zouap"
  27. )
  28. var (
  29. TrainResourceSpecs *models.ResourceSpecs
  30. )
  31. func SaveModelConvert(ctx *context.Context) {
  32. log.Info("save model convert start.")
  33. if !ctx.Repo.CanWrite(models.UnitTypeModelManage) {
  34. ctx.JSON(403, ctx.Tr("repo.model_noright"))
  35. return
  36. }
  37. name := ctx.Query("name")
  38. desc := ctx.Query("desc")
  39. modelId := ctx.Query("modelId")
  40. modelPath := ctx.Query("ModelFile")
  41. SrcEngine := ctx.QueryInt("SrcEngine")
  42. InputShape := ctx.Query("inputshape")
  43. InputDataFormat := ctx.Query("inputdataformat")
  44. DestFormat := ctx.QueryInt("DestFormat")
  45. NetOutputFormat := ctx.QueryInt("NetOutputFormat")
  46. task, err := models.QueryModelById(modelId)
  47. if err != nil {
  48. log.Error("no such model!", err.Error())
  49. ctx.ServerError("no such model:", err)
  50. return
  51. }
  52. uuid := uuid.NewV4()
  53. id := uuid.String()
  54. modelConvert := &models.AiModelConvert{
  55. ID: id,
  56. Name: name,
  57. Description: desc,
  58. Status: string(models.JobWaiting),
  59. SrcEngine: SrcEngine,
  60. RepoId: ctx.Repo.Repository.ID,
  61. ModelName: task.Name,
  62. ModelVersion: task.Version,
  63. ModelId: modelId,
  64. ModelPath: modelPath,
  65. DestFormat: DestFormat,
  66. NetOutputFormat: NetOutputFormat,
  67. InputShape: InputShape,
  68. InputDataFormat: InputDataFormat,
  69. UserId: ctx.User.ID,
  70. }
  71. models.SaveModelConvert(modelConvert)
  72. err = createTrainJob(modelConvert, ctx, task.Path)
  73. if err == nil {
  74. ctx.JSON(200, map[string]string{
  75. "result_code": "0",
  76. })
  77. } else {
  78. ctx.JSON(200, map[string]string{
  79. "result_code": "1",
  80. })
  81. }
  82. }
  83. func createTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, modelRelativePath string) error {
  84. repo, _ := models.GetRepositoryByID(ctx.Repo.Repository.ID)
  85. if modelConvert.SrcEngine == PYTORCH_ENGINE {
  86. codePath := setting.JobPath + modelConvert.ID + CodeMountPath
  87. downloadCode(repo, codePath, DefaultBranchName)
  88. uploadCodeToMinio(codePath+"/", modelConvert.ID, CodeMountPath+"/")
  89. log.Info("minio code path=" + setting.CBCodePathPrefix + modelConvert.ID)
  90. minioCodePath := setting.Attachment.Minio.RealPath + setting.Attachment.Minio.Bucket + "/" + setting.CBCodePathPrefix + modelConvert.ID + "/code"
  91. log.Info("Volume codePath=" + minioCodePath)
  92. modelPath := setting.JobPath + modelConvert.ID + ModelMountPath + "/"
  93. log.Info("modelPath=" + modelPath)
  94. mkModelPath(modelPath)
  95. uploadCodeToMinio(modelPath, modelConvert.ID, ModelMountPath+"/")
  96. command := getModelConvertCommand(modelConvert.ID, modelConvert.ModelPath)
  97. log.Info("command=" + command)
  98. dataActualPath := setting.Attachment.Minio.RealPath + modelRelativePath
  99. log.Info("dataActualPath=" + dataActualPath)
  100. if TrainResourceSpecs == nil {
  101. json.Unmarshal([]byte(setting.TrainResourceSpecs), &TrainResourceSpecs)
  102. }
  103. resourceSpec := TrainResourceSpecs.ResourceSpec[1]
  104. jobResult, err := cloudbrain.CreateJob(modelConvert.ID, models.CreateJobParams{
  105. JobName: modelConvert.ID,
  106. RetryCount: 1,
  107. GpuType: GpuQueue,
  108. Image: GPU_PYTORCH_IMAGE,
  109. TaskRoles: []models.TaskRole{
  110. {
  111. Name: SubTaskName,
  112. TaskNumber: 1,
  113. MinSucceededTaskCount: 1,
  114. MinFailedTaskCount: 1,
  115. CPUNumber: resourceSpec.CpuNum,
  116. GPUNumber: resourceSpec.GpuNum,
  117. MemoryMB: resourceSpec.MemMiB,
  118. ShmMB: resourceSpec.ShareMemMiB,
  119. Command: command,
  120. NeedIBDevice: false,
  121. IsMainRole: false,
  122. UseNNI: false,
  123. },
  124. },
  125. Volumes: []models.Volume{
  126. {
  127. HostPath: models.StHostPath{
  128. Path: minioCodePath,
  129. MountPath: CodeMountPath,
  130. ReadOnly: false,
  131. },
  132. },
  133. {
  134. HostPath: models.StHostPath{
  135. Path: dataActualPath,
  136. MountPath: DataSetMountPath,
  137. ReadOnly: true,
  138. },
  139. },
  140. {
  141. HostPath: models.StHostPath{
  142. Path: modelPath,
  143. MountPath: ModelMountPath,
  144. ReadOnly: false,
  145. },
  146. },
  147. },
  148. })
  149. if err != nil {
  150. log.Error("CreateJob failed:", err.Error(), ctx.Data["MsgID"])
  151. return err
  152. }
  153. if jobResult.Code != Success {
  154. log.Error("CreateJob(%s) failed:%s", modelConvert.ID, jobResult.Msg, ctx.Data["MsgID"])
  155. return errors.New(jobResult.Msg)
  156. }
  157. var jobID = jobResult.Payload["jobId"].(string)
  158. log.Info("jobId=" + jobID)
  159. models.UpdateModelConvertCBTI(modelConvert.ID, jobID)
  160. }
  161. return nil
  162. }
  163. func getModelConvertCommand(name string, modelFile string) string {
  164. var command string
  165. bootFile := "convert_pytorch.py"
  166. command += "python3 /code/" + bootFile + " --model " + modelFile + " > " + ModelMountPath + "/" + name + "-" + LogFile
  167. return command
  168. }
  169. func DeleteModelConvert(ctx *context.Context) {
  170. log.Info("delete model convert start.")
  171. id := ctx.Params(":id")
  172. err := models.DeleteModelConvertById(id)
  173. if err != nil {
  174. ctx.JSON(500, err.Error())
  175. } else {
  176. ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model")
  177. }
  178. }
  179. func StopModelConvert(ctx *context.Context) {
  180. id := ctx.Params(":id")
  181. log.Info("stop model convert start.id=" + id)
  182. }
  183. func ShowModelConvertInfo(ctx *context.Context) {
  184. ctx.Data["ID"] = ctx.Query("ID")
  185. ctx.Data["isModelManage"] = true
  186. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  187. job, err := models.QueryModelConvertById(ctx.Query("ID"))
  188. if err == nil {
  189. ctx.Data["task"] = job
  190. }
  191. result, err := cloudbrain.GetJob(job.CloudBrainTaskId)
  192. if err != nil {
  193. log.Info("error:" + err.Error())
  194. ctx.Data["error"] = err.Error()
  195. return
  196. }
  197. if result != nil {
  198. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  199. ctx.Data["result"] = jobRes
  200. taskRoles := jobRes.TaskRoles
  201. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  202. ctx.Data["taskRes"] = taskRes
  203. ctx.Data["ExitDiagnostics"] = taskRes.TaskStatuses[0].ExitDiagnostics
  204. job.Status = jobRes.JobStatus.State
  205. if jobRes.JobStatus.State != string(models.JobWaiting) && jobRes.JobStatus.State != string(models.JobFailed) {
  206. job.ContainerIp = taskRes.TaskStatuses[0].ContainerIP
  207. job.ContainerID = taskRes.TaskStatuses[0].ContainerID
  208. job.Status = taskRes.TaskStatuses[0].State
  209. }
  210. if jobRes.JobStatus.State != string(models.JobWaiting) {
  211. models.ModelComputeAndSetDuration(job, jobRes)
  212. err = models.UpdateModelConvert(job)
  213. if err != nil {
  214. log.Error("UpdateModelConvert failed:", err)
  215. }
  216. }
  217. }
  218. ctx.HTML(200, tplModelConvertInfo)
  219. }
  220. func ConvertModelTemplate(ctx *context.Context) {
  221. ctx.Data["isModelManage"] = true
  222. ctx.Data["MODEL_COUNT"] = 0
  223. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  224. ctx.Data["TRAIN_COUNT"] = 0
  225. ShowModelConvertPageInfo(ctx)
  226. ctx.HTML(200, tplModelManageConvertIndex)
  227. }
  228. func ShowModelConvertPageInfo(ctx *context.Context) {
  229. log.Info("ShowModelConvertInfo start.")
  230. if !isQueryRight(ctx) {
  231. log.Info("no right.")
  232. ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
  233. return
  234. }
  235. page := ctx.QueryInt("page")
  236. if page <= 0 {
  237. page = 1
  238. }
  239. pageSize := ctx.QueryInt("pageSize")
  240. if pageSize <= 0 {
  241. pageSize = setting.UI.IssuePagingNum
  242. }
  243. repoId := ctx.Repo.Repository.ID
  244. modelResult, count, err := models.QueryModelConvert(&models.AiModelQueryOptions{
  245. ListOptions: models.ListOptions{
  246. Page: page,
  247. PageSize: pageSize,
  248. },
  249. RepoID: repoId,
  250. })
  251. if err != nil {
  252. log.Info("query db error." + err.Error())
  253. ctx.ServerError("Cloudbrain", err)
  254. return
  255. }
  256. userIds := make([]int64, len(modelResult))
  257. for i, model := range modelResult {
  258. model.IsCanOper = isOper(ctx, model.UserId)
  259. model.IsCanDelete = isCanDelete(ctx, model.UserId)
  260. userIds[i] = model.UserId
  261. }
  262. userNameMap := queryUserName(userIds)
  263. for _, model := range modelResult {
  264. value := userNameMap[model.UserId]
  265. if value != nil {
  266. model.UserName = value.Name
  267. model.UserRelAvatarLink = value.RelAvatarLink()
  268. }
  269. }
  270. pager := context.NewPagination(int(count), page, pageSize, 5)
  271. ctx.Data["Page"] = pager
  272. ctx.Data["Tasks"] = modelResult
  273. }