You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 6.9 kB


  1. package repo
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "path"
  7. "strings"
  8. "code.gitea.io/gitea/models"
  9. "code.gitea.io/gitea/modules/context"
  10. "code.gitea.io/gitea/modules/log"
  11. "code.gitea.io/gitea/modules/setting"
  12. "code.gitea.io/gitea/modules/storage"
  13. uuid "github.com/satori/go.uuid"
  14. )
  15. const (
  16. Model_prefix = "aimodels/"
  17. tplModelManageIndex = "repo/modelmanage/index"
  18. )
  19. func SaveModelByParameters(jobId string, name string, version string, label string, description string, userId int64) error {
  20. aiTask, err := models.GetCloudbrainByJobID(jobId)
  21. if err != nil {
  22. log.Info("query task error." + err.Error())
  23. return err
  24. }
  25. uuid := uuid.NewV4()
  26. id := uuid.String()
  27. modelPath := id
  28. parent := id
  29. var modelSize int64
  30. cloudType := models.TypeCloudBrainTwo
  31. log.Info("find task name:" + aiTask.JobName)
  32. aimodels := models.QueryModelByName(name, userId)
  33. if len(aimodels) > 0 {
  34. for _, model := range aimodels {
  35. if model.ID == model.Parent {
  36. parent = model.ID
  37. }
  38. }
  39. }
  40. cloudType = aiTask.Type
  41. //download model zip //train type
  42. if cloudType == models.TypeCloudBrainTwo {
  43. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  44. if err != nil {
  45. log.Info("download model from CloudBrainTwo faild." + err.Error())
  46. return err
  47. }
  48. }
  49. model := &models.AiModelManage{
  50. ID: id,
  51. Version: version,
  52. Label: label,
  53. Name: name,
  54. Description: description,
  55. Parent: parent,
  56. Type: cloudType,
  57. Path: modelPath,
  58. Size: modelSize,
  59. AttachmentId: aiTask.Uuid,
  60. RepoId: aiTask.RepoID,
  61. UserId: userId,
  62. }
  63. models.SaveModelToDb(model)
  64. log.Info("save model end.")
  65. return nil
  66. }
  67. func SaveModel(ctx *context.Context) {
  68. log.Info("save model start.")
  69. JobId := ctx.Query("JobId")
  70. name := ctx.Query("Name")
  71. version := ctx.Query("Version")
  72. label := ctx.Query("Label")
  73. description := ctx.Query("Description")
  74. err := SaveModelByParameters(JobId, name, version, label, description, ctx.User.ID)
  75. if err != nil {
  76. log.Info("save model error." + err.Error())
  77. ctx.Error(500, fmt.Sprintf("save model error. %v", err))
  78. return
  79. }
  80. log.Info("save model end.")
  81. }
  82. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  83. dataActualPath := setting.Bucket + "/" + Model_prefix +
  84. models.AttachmentRelativePath(modelUUID) +
  85. "/"
  86. modelDbResult, err := storage.GetObsListObject(jobName, parentDir)
  87. if err != nil {
  88. log.Info("get TrainJobListModel failed:", err)
  89. return "", 0, err
  90. }
  91. if len(modelDbResult) == 0 {
  92. return "", 0, errors.New("cannot create model, as model is empty.")
  93. }
  94. var size int64
  95. prefix := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/") + "/"
  96. for _, modelFile := range modelDbResult {
  97. destKeyNamePrefix := Model_prefix + models.AttachmentRelativePath(modelUUID) + "/"
  98. log.Info("copy file, bucket=" + setting.Bucket + ", src keyname=" + prefix + modelFile.FileName)
  99. log.Info("Dest key name=" + destKeyNamePrefix + modelFile.FileName)
  100. err := storage.ObsCopyFile(setting.Bucket, prefix+modelFile.FileName, setting.Bucket, destKeyNamePrefix+modelFile.FileName)
  101. if err != nil {
  102. log.Info("copy failed.")
  103. }
  104. size += modelFile.Size
  105. }
  106. return dataActualPath, size, nil
  107. }
  108. func DeleteModel(ctx *context.Context) {
  109. log.Info("delete model start.")
  110. id := ctx.Query("ID")
  111. err := DeleteModelByID(id)
  112. if err != nil {
  113. ctx.JSON(500, err.Error())
  114. } else {
  115. ctx.JSON(200, map[string]string{
  116. "result_code": "0",
  117. })
  118. }
  119. }
  120. func DeleteModelByID(id string) error {
  121. log.Info("delete model start. id=" + id)
  122. model, err := models.QueryModelById(id)
  123. if err == nil {
  124. log.Info("bucket=" + setting.Bucket + " path=" + model.Path)
  125. if strings.HasPrefix(model.Path, setting.Bucket+"/"+Model_prefix) {
  126. err := storage.ObsRemoveObject(setting.Bucket, model.Path[len(setting.Bucket)+1:])
  127. if err != nil {
  128. log.Info("Failed to delete model. id=" + id)
  129. return err
  130. }
  131. }
  132. return models.DeleteModelById(id)
  133. }
  134. return err
  135. }
  136. func DownloadModel(ctx *context.Context) {
  137. log.Info("download model start.")
  138. }
  139. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  140. return models.QueryModel(&models.AiModelQueryOptions{
  141. ListOptions: models.ListOptions{
  142. Page: page,
  143. PageSize: setting.UI.IssuePagingNum,
  144. },
  145. RepoID: repoId,
  146. Type: -1,
  147. })
  148. }
  149. func DownloadMultiModelFile(ctx *context.Context) {
  150. log.Info("DownloadMultiModelFile start.")
  151. id := ctx.Query("ID")
  152. log.Info("id=" + id)
  153. }
  154. func DownloadSingleModelFile(ctx *context.Context) {
  155. log.Info("DownloadSingleModelFile start.")
  156. path := ctx.Query("path")
  157. url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path[len(setting.Bucket)+1:])
  158. if err != nil {
  159. log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
  160. ctx.ServerError("GetObsCreateSignedUrl", err)
  161. return
  162. }
  163. http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
  164. }
  165. func ShowSingleModel(ctx *context.Context) {
  166. log.Info("Show single ModelInfo start.")
  167. id := ctx.Query("ID")
  168. task, err := models.QueryModelById(id)
  169. if err != nil {
  170. log.Error("no such model!", ctx.Data["msgID"])
  171. ctx.ServerError("no such model:", err)
  172. return
  173. }
  174. models, err := storage.GetObsListObjectByBucketAndPrefix(setting.Bucket, task.Path[len(setting.Bucket)+1:])
  175. if err != nil {
  176. log.Info("get TrainJobListModel failed:", err)
  177. ctx.ServerError("GetObsListObject:", err)
  178. return
  179. }
  180. ctx.Data["Dirs"] = models
  181. ctx.Data["task"] = task
  182. ctx.Data["ID"] = id
  183. }
  184. func ShowModelPageInfo(ctx *context.Context) {
  185. log.Info("ShowModelInfo start.")
  186. page := ctx.QueryInt("page")
  187. if page <= 0 {
  188. page = 1
  189. }
  190. repoId := ctx.Repo.Repository.ID
  191. Type := -1
  192. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  193. ListOptions: models.ListOptions{
  194. Page: page,
  195. PageSize: setting.UI.IssuePagingNum,
  196. },
  197. RepoID: repoId,
  198. Type: Type,
  199. })
  200. if err != nil {
  201. ctx.ServerError("Cloudbrain", err)
  202. return
  203. }
  204. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  205. pager.SetDefaultParams(ctx)
  206. ctx.Data["Page"] = pager
  207. ctx.Data["PageIsCloudBrain"] = true
  208. ctx.Data["Tasks"] = modelResult
  209. ctx.HTML(200, tplModelManageIndex)
  210. }
  211. func ModifyModel(id string, description string) error {
  212. err := models.ModifyModelDescription(id, description)
  213. if err == nil {
  214. log.Info("modify success.")
  215. } else {
  216. log.Info("Failed to modify.id=" + id + " desc=" + description + " error:" + err.Error())
  217. }
  218. return err
  219. }
  220. func ModifyModelInfo(ctx *context.Context) {
  221. log.Info("delete model start.")
  222. id := ctx.Query("ID")
  223. description := ctx.Query("Description")
  224. err := ModifyModel(id, description)
  225. if err == nil {
  226. ctx.HTML(200, "success")
  227. } else {
  228. ctx.HTML(500, "Failed.")
  229. }
  230. }