You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package repo
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "path"
  7. "strings"
  8. "code.gitea.io/gitea/models"
  9. "code.gitea.io/gitea/modules/context"
  10. "code.gitea.io/gitea/modules/log"
  11. "code.gitea.io/gitea/modules/setting"
  12. "code.gitea.io/gitea/modules/storage"
  13. uuid "github.com/satori/go.uuid"
  14. )
  // Model_prefix is the object-key prefix under which copied model files are
  // stored inside the bucket.
  const Model_prefix = "aimodels/"

  // Template names rendered by the model-management pages.
  const (
  	// tplModelManageIndex is the paginated model list page.
  	tplModelManageIndex = "repo/modelmanage/index"
  	// tplModelManageDownload is the per-model file listing page.
  	tplModelManageDownload = "repo/modelmanage/download"
  )
  20. func SaveModelByParameters(jobId string, name string, version string, label string, description string, userId int64) error {
  21. aiTask, err := models.GetCloudbrainByJobID(jobId)
  22. if err != nil {
  23. log.Info("query task error." + err.Error())
  24. return err
  25. }
  26. uuid := uuid.NewV4()
  27. id := uuid.String()
  28. modelPath := id
  29. parent := id
  30. var modelSize int64
  31. cloudType := models.TypeCloudBrainTwo
  32. log.Info("find task name:" + aiTask.JobName)
  33. aimodels := models.QueryModelByName(name, userId)
  34. if len(aimodels) > 0 {
  35. for _, model := range aimodels {
  36. if model.ID == model.Parent {
  37. parent = model.ID
  38. }
  39. }
  40. }
  41. cloudType = aiTask.Type
  42. //download model zip //train type
  43. if cloudType == models.TypeCloudBrainTwo {
  44. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  45. if err != nil {
  46. log.Info("download model from CloudBrainTwo faild." + err.Error())
  47. return err
  48. }
  49. }
  50. model := &models.AiModelManage{
  51. ID: id,
  52. Version: version,
  53. Label: label,
  54. Name: name,
  55. Description: description,
  56. Parent: parent,
  57. Type: cloudType,
  58. Path: modelPath,
  59. Size: modelSize,
  60. AttachmentId: aiTask.Uuid,
  61. RepoId: aiTask.RepoID,
  62. UserId: userId,
  63. }
  64. models.SaveModelToDb(model)
  65. log.Info("save model end.")
  66. return nil
  67. }
  68. func SaveModel(ctx *context.Context) {
  69. log.Info("save model start.")
  70. JobId := ctx.Query("JobId")
  71. name := ctx.Query("Name")
  72. version := ctx.Query("Version")
  73. label := ctx.Query("Label")
  74. description := ctx.Query("Description")
  75. err := SaveModelByParameters(JobId, name, version, label, description, ctx.User.ID)
  76. if err != nil {
  77. log.Info("save model error." + err.Error())
  78. ctx.Error(500, fmt.Sprintf("save model error. %v", err))
  79. return
  80. }
  81. log.Info("save model end.")
  82. }
  83. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  84. dataActualPath := setting.Bucket + "/" + Model_prefix +
  85. models.AttachmentRelativePath(modelUUID) +
  86. "/"
  87. objectkey := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
  88. modelDbResult, err := storage.GetAllObsListObjectUnderDir(setting.Bucket, objectkey)
  89. if err != nil {
  90. log.Info("get TrainJobListModel failed:", err)
  91. return "", 0, err
  92. }
  93. if len(modelDbResult) == 0 {
  94. return "", 0, errors.New("cannot create model, as model is empty.")
  95. }
  96. var size int64
  97. prefix := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/") + "/"
  98. destKeyNamePrefix := Model_prefix + models.AttachmentRelativePath(modelUUID) + "/"
  99. for _, modelFile := range modelDbResult {
  100. if modelFile.IsDir {
  101. log.Info("copy dir, continue. dir=" + modelFile.FileName)
  102. continue
  103. }
  104. srcKeyName := prefix + modelFile.ParenDir + modelFile.FileName
  105. log.Info("copy file, bucket=" + setting.Bucket + ", src keyname=" + srcKeyName)
  106. destKeyName := destKeyNamePrefix + modelFile.ParenDir + modelFile.FileName
  107. log.Info("Dest key name=" + destKeyName)
  108. err := storage.ObsCopyFile(setting.Bucket, srcKeyName, setting.Bucket, destKeyName)
  109. if err != nil {
  110. log.Info("copy failed.")
  111. }
  112. size += modelFile.Size
  113. }
  114. return dataActualPath, size, nil
  115. }
  116. func DeleteModel(ctx *context.Context) {
  117. log.Info("delete model start.")
  118. id := ctx.Query("ID")
  119. err := DeleteModelByID(id)
  120. if err != nil {
  121. ctx.JSON(500, err.Error())
  122. } else {
  123. ctx.JSON(200, map[string]string{
  124. "result_code": "0",
  125. })
  126. }
  127. }
  128. func DeleteModelByID(id string) error {
  129. log.Info("delete model start. id=" + id)
  130. model, err := models.QueryModelById(id)
  131. if err == nil {
  132. log.Info("bucket=" + setting.Bucket + " path=" + model.Path)
  133. if strings.HasPrefix(model.Path, setting.Bucket+"/"+Model_prefix) {
  134. err := storage.ObsRemoveObject(setting.Bucket, model.Path[len(setting.Bucket)+1:])
  135. if err != nil {
  136. log.Info("Failed to delete model. id=" + id)
  137. return err
  138. }
  139. }
  140. return models.DeleteModelById(id)
  141. }
  142. return err
  143. }
  144. func DownloadModel(ctx *context.Context) {
  145. log.Info("download model start.")
  146. }
  147. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  148. return models.QueryModel(&models.AiModelQueryOptions{
  149. ListOptions: models.ListOptions{
  150. Page: page,
  151. PageSize: setting.UI.IssuePagingNum,
  152. },
  153. RepoID: repoId,
  154. Type: -1,
  155. })
  156. }
  157. func DownloadMultiModelFile(ctx *context.Context) {
  158. log.Info("DownloadMultiModelFile start.")
  159. id := ctx.Query("ID")
  160. log.Info("id=" + id)
  161. }
  162. func DownloadSingleModelFile(ctx *context.Context) {
  163. log.Info("DownloadSingleModelFile start.")
  164. path := ctx.Query("path")
  165. url, err := storage.GetObsCreateSignedUrlByBucketAndKey(setting.Bucket, path[len(setting.Bucket)+1:])
  166. if err != nil {
  167. log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"])
  168. ctx.ServerError("GetObsCreateSignedUrl", err)
  169. return
  170. }
  171. http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
  172. }
  173. func ShowSingleModel(ctx *context.Context) {
  174. id := ctx.Params(":ID")
  175. log.Info("Show single ModelInfo start.id=" + id)
  176. task, err := models.QueryModelById(id)
  177. if err != nil {
  178. log.Error("no such model!", err.Error())
  179. ctx.ServerError("no such model:", err)
  180. return
  181. }
  182. log.Info("bucket=" + setting.Bucket + " key=" + task.Path[len(setting.Bucket)+1:])
  183. models, err := storage.GetObsListObjectByBucketAndPrefix(setting.Bucket, task.Path[len(setting.Bucket)+1:])
  184. if err != nil {
  185. log.Info("get model list failed:", err)
  186. ctx.ServerError("GetObsListObject:", err)
  187. return
  188. } else {
  189. log.Info("get model file,size=" + fmt.Sprint(len(models)))
  190. }
  191. ctx.Data["Dirs"] = models
  192. ctx.Data["task"] = task
  193. ctx.Data["ID"] = id
  194. ctx.HTML(200, tplModelManageDownload)
  195. }
  196. func ShowModelPageInfo(ctx *context.Context) {
  197. log.Info("ShowModelInfo start.")
  198. page := ctx.QueryInt("page")
  199. if page <= 0 {
  200. page = 1
  201. }
  202. repoId := ctx.Repo.Repository.ID
  203. Type := -1
  204. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  205. ListOptions: models.ListOptions{
  206. Page: page,
  207. PageSize: setting.UI.IssuePagingNum,
  208. },
  209. RepoID: repoId,
  210. Type: Type,
  211. })
  212. if err != nil {
  213. ctx.ServerError("Cloudbrain", err)
  214. return
  215. }
  216. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  217. pager.SetDefaultParams(ctx)
  218. ctx.Data["Page"] = pager
  219. ctx.Data["PageIsCloudBrain"] = true
  220. ctx.Data["Tasks"] = modelResult
  221. ctx.HTML(200, tplModelManageIndex)
  222. }
  223. func ModifyModel(id string, description string) error {
  224. err := models.ModifyModelDescription(id, description)
  225. if err == nil {
  226. log.Info("modify success.")
  227. } else {
  228. log.Info("Failed to modify.id=" + id + " desc=" + description + " error:" + err.Error())
  229. }
  230. return err
  231. }
  232. func ModifyModelInfo(ctx *context.Context) {
  233. log.Info("delete model start.")
  234. id := ctx.Query("ID")
  235. description := ctx.Query("Description")
  236. err := ModifyModel(id, description)
  237. if err == nil {
  238. ctx.HTML(200, "success")
  239. } else {
  240. ctx.HTML(500, "Failed.")
  241. }
  242. }