You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_convert.go 8.7 kB


  1. package repo
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "code.gitea.io/gitea/models"
  6. "code.gitea.io/gitea/modules/cloudbrain"
  7. "code.gitea.io/gitea/modules/context"
  8. "code.gitea.io/gitea/modules/log"
  9. "code.gitea.io/gitea/modules/setting"
  10. uuid "github.com/satori/go.uuid"
  11. )
  // Templates rendered by the model-conversion pages.
  const (
  	tplModelManageConvertIndex = "repo/modelmanage/convertIndex"
  	tplModelConvertInfo        = "repo/modelmanage/convertshowinfo"
  )

  // Source-engine identifiers stored in AiModelConvert.SrcEngine.
  // MINDSPORE_ENGIN keeps its original (misspelled) name so existing
  // references continue to compile.
  const (
  	PYTORCH_ENGINE    = iota // 0
  	TENSORFLOW_ENGINE        // 1
  	MINDSPORE_ENGIN          // 2
  )

  // Mount points inside the conversion container and the log file name
  // written under ModelMountPath.
  const (
  	ModelMountPath   = "/model"
  	CodeMountPath    = "/code"
  	DataSetMountPath = "/dataset"
  	LogFile          = "log.txt"
  )

  // CloudBrain job parameters used when submitting a conversion job.
  const (
  	DefaultBranchName = "master"
  	SubTaskName       = "task1"
  	GpuQueue          = "openidebug"
  	// Success is the CloudBrain result code compared against CreateJob's Code.
  	Success           = "S000"
  	GPU_PYTORCH_IMAGE = "dockerhub.pcl.ac.cn:5000/user-images/openi:tensorRT_7_zouap"
  )
  28. var (
  29. TrainResourceSpecs *models.ResourceSpecs
  30. )
  31. func SaveModelConvert(ctx *context.Context) {
  32. log.Info("save model convert start.")
  33. if !ctx.Repo.CanWrite(models.UnitTypeModelManage) {
  34. ctx.JSON(403, ctx.Tr("repo.model_noright"))
  35. return
  36. }
  37. name := ctx.Query("name")
  38. desc := ctx.Query("desc")
  39. modelId := ctx.Query("modelId")
  40. modelPath := ctx.Query("ModelFile")
  41. SrcEngine := ctx.QueryInt("SrcEngine")
  42. InputShape := ctx.Query("inputshape")
  43. InputDataFormat := ctx.Query("inputdataformat")
  44. DestFormat := ctx.QueryInt("DestFormat")
  45. NetOutputFormat := ctx.QueryInt("NetOutputFormat")
  46. task, err := models.QueryModelById(modelId)
  47. if err != nil {
  48. log.Error("no such model!", err.Error())
  49. ctx.ServerError("no such model:", err)
  50. return
  51. }
  52. uuid := uuid.NewV4()
  53. id := uuid.String()
  54. modelConvert := &models.AiModelConvert{
  55. ID: id,
  56. Name: name,
  57. Description: desc,
  58. Status: string(models.JobWaiting),
  59. SrcEngine: SrcEngine,
  60. RepoId: ctx.Repo.Repository.ID,
  61. ModelName: task.Name,
  62. ModelVersion: task.Version,
  63. ModelId: modelId,
  64. ModelPath: modelPath,
  65. DestFormat: DestFormat,
  66. NetOutputFormat: NetOutputFormat,
  67. InputShape: InputShape,
  68. InputDataFormat: InputDataFormat,
  69. UserId: ctx.User.ID,
  70. }
  71. models.SaveModelConvert(modelConvert)
  72. err = createTrainJob(modelConvert, ctx, task.Path)
  73. if err == nil {
  74. ctx.JSON(200, map[string]string{
  75. "result_code": "0",
  76. })
  77. } else {
  78. ctx.JSON(200, map[string]string{
  79. "result_code": "1",
  80. })
  81. }
  82. }
  83. func createTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, modelRelativePath string) error {
  84. repo, _ := models.GetRepositoryByID(ctx.Repo.Repository.ID)
  85. if modelConvert.SrcEngine == PYTORCH_ENGINE {
  86. codePath := setting.JobPath + modelConvert.ID + CodeMountPath
  87. downloadCode(repo, codePath, DefaultBranchName)
  88. uploadCodeToMinio(codePath+"/", modelConvert.ID, CodeMountPath+"/")
  89. log.Info("minio code path=" + setting.CBCodePathPrefix + modelConvert.ID)
  90. modelPath := setting.JobPath + modelConvert.ID + ModelMountPath + "/"
  91. mkModelPath(modelPath)
  92. uploadCodeToMinio(modelPath, modelConvert.ID, ModelMountPath+"/")
  93. command := getModelConvertCommand(modelConvert.ID, modelConvert.ModelPath)
  94. log.Info("command=" + command)
  95. dataActualPath := setting.Attachment.Minio.RealPath + modelRelativePath
  96. if TrainResourceSpecs == nil {
  97. json.Unmarshal([]byte(setting.TrainResourceSpecs), &TrainResourceSpecs)
  98. }
  99. resourceSpec := TrainResourceSpecs.ResourceSpec[1]
  100. jobResult, err := cloudbrain.CreateJob(modelConvert.ID, models.CreateJobParams{
  101. JobName: modelConvert.ID,
  102. RetryCount: 1,
  103. GpuType: GpuQueue,
  104. Image: GPU_PYTORCH_IMAGE,
  105. TaskRoles: []models.TaskRole{
  106. {
  107. Name: SubTaskName,
  108. TaskNumber: 1,
  109. MinSucceededTaskCount: 1,
  110. MinFailedTaskCount: 1,
  111. CPUNumber: resourceSpec.CpuNum,
  112. GPUNumber: resourceSpec.GpuNum,
  113. MemoryMB: resourceSpec.MemMiB,
  114. ShmMB: resourceSpec.ShareMemMiB,
  115. Command: command,
  116. NeedIBDevice: false,
  117. IsMainRole: false,
  118. UseNNI: false,
  119. },
  120. },
  121. Volumes: []models.Volume{
  122. {
  123. HostPath: models.StHostPath{
  124. Path: codePath,
  125. MountPath: CodeMountPath,
  126. ReadOnly: false,
  127. },
  128. },
  129. {
  130. HostPath: models.StHostPath{
  131. Path: dataActualPath,
  132. MountPath: DataSetMountPath,
  133. ReadOnly: true,
  134. },
  135. },
  136. {
  137. HostPath: models.StHostPath{
  138. Path: modelPath,
  139. MountPath: ModelMountPath,
  140. ReadOnly: false,
  141. },
  142. },
  143. },
  144. })
  145. if err != nil {
  146. log.Error("CreateJob failed:", err.Error(), ctx.Data["MsgID"])
  147. return err
  148. }
  149. if jobResult.Code != Success {
  150. log.Error("CreateJob(%s) failed:%s", modelConvert.ID, jobResult.Msg, ctx.Data["MsgID"])
  151. return errors.New(jobResult.Msg)
  152. }
  153. var jobID = jobResult.Payload["jobId"].(string)
  154. log.Info("jobId=" + jobID)
  155. models.UpdateModelConvertCBTI(modelConvert.ID, jobID)
  156. }
  157. return nil
  158. }
  159. func getModelConvertCommand(name string, modelFile string) string {
  160. var command string
  161. bootFile := "convert_pytorch.py"
  162. command += "python3 /code/" + bootFile + " --model " + modelFile + " > " + ModelMountPath + "/" + name + "-" + LogFile
  163. return command
  164. }
  165. func DeleteModelConvert(ctx *context.Context) {
  166. log.Info("delete model convert start.")
  167. id := ctx.Params(":id")
  168. err := models.DeleteModelConvertById(id)
  169. if err != nil {
  170. ctx.JSON(500, err.Error())
  171. } else {
  172. ctx.Redirect(setting.AppSubURL + ctx.Repo.RepoLink + "/modelmanage/convert_model")
  173. }
  174. }
  175. func StopModelConvert(ctx *context.Context) {
  176. id := ctx.Params(":id")
  177. log.Info("stop model convert start.id=" + id)
  178. }
  179. func ShowModelConvertInfo(ctx *context.Context) {
  180. ctx.Data["ID"] = ctx.Query("ID")
  181. ctx.Data["isModelManage"] = true
  182. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  183. job, err := models.QueryModelConvertById(ctx.Query("ID"))
  184. if err == nil {
  185. ctx.Data["task"] = job
  186. }
  187. result, err := cloudbrain.GetJob(job.CloudBrainTaskId)
  188. if err != nil {
  189. log.Info("error:" + err.Error())
  190. ctx.Data["error"] = err.Error()
  191. return
  192. }
  193. if result != nil {
  194. jobRes, _ := models.ConvertToJobResultPayload(result.Payload)
  195. ctx.Data["result"] = jobRes
  196. taskRoles := jobRes.TaskRoles
  197. taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
  198. ctx.Data["taskRes"] = taskRes
  199. ctx.Data["ExitDiagnostics"] = taskRes.TaskStatuses[0].ExitDiagnostics
  200. job.Status = jobRes.JobStatus.State
  201. if jobRes.JobStatus.State != string(models.JobWaiting) && jobRes.JobStatus.State != string(models.JobFailed) {
  202. job.ContainerIp = taskRes.TaskStatuses[0].ContainerIP
  203. job.ContainerID = taskRes.TaskStatuses[0].ContainerID
  204. job.Status = taskRes.TaskStatuses[0].State
  205. }
  206. if jobRes.JobStatus.State != string(models.JobWaiting) {
  207. models.ModelComputeAndSetDuration(job, jobRes)
  208. err = models.UpdateModelConvert(job)
  209. if err != nil {
  210. log.Error("UpdateModelConvert failed:", err)
  211. }
  212. }
  213. }
  214. ctx.HTML(200, tplModelConvertInfo)
  215. }
  216. func ConvertModelTemplate(ctx *context.Context) {
  217. ctx.Data["isModelManage"] = true
  218. ctx.Data["MODEL_COUNT"] = 0
  219. ctx.Data["ModelManageAccess"] = ctx.Repo.CanWrite(models.UnitTypeModelManage)
  220. ctx.Data["TRAIN_COUNT"] = 0
  221. ShowModelConvertPageInfo(ctx)
  222. ctx.HTML(200, tplModelManageConvertIndex)
  223. }
  224. func ShowModelConvertPageInfo(ctx *context.Context) {
  225. log.Info("ShowModelConvertInfo start.")
  226. if !isQueryRight(ctx) {
  227. log.Info("no right.")
  228. ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
  229. return
  230. }
  231. page := ctx.QueryInt("page")
  232. if page <= 0 {
  233. page = 1
  234. }
  235. pageSize := ctx.QueryInt("pageSize")
  236. if pageSize <= 0 {
  237. pageSize = setting.UI.IssuePagingNum
  238. }
  239. repoId := ctx.Repo.Repository.ID
  240. modelResult, count, err := models.QueryModelConvert(&models.AiModelQueryOptions{
  241. ListOptions: models.ListOptions{
  242. Page: page,
  243. PageSize: pageSize,
  244. },
  245. RepoID: repoId,
  246. })
  247. if err != nil {
  248. log.Info("query db error." + err.Error())
  249. ctx.ServerError("Cloudbrain", err)
  250. return
  251. }
  252. userIds := make([]int64, len(modelResult))
  253. for i, model := range modelResult {
  254. model.IsCanOper = isOper(ctx, model.UserId)
  255. model.IsCanDelete = isCanDelete(ctx, model.UserId)
  256. userIds[i] = model.UserId
  257. }
  258. userNameMap := queryUserName(userIds)
  259. for _, model := range modelResult {
  260. value := userNameMap[model.UserId]
  261. if value != nil {
  262. model.UserName = value.Name
  263. model.UserRelAvatarLink = value.RelAvatarLink()
  264. }
  265. }
  266. pager := context.NewPagination(int(count), page, pageSize, 5)
  267. ctx.Data["Page"] = pager
  268. ctx.Data["Tasks"] = modelResult
  269. }