You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. package repo
  2. import (
  3. "archive/zip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "os"
  9. "path"
  10. "path/filepath"
  11. "strings"
  12. "code.gitea.io/gitea/models"
  13. "code.gitea.io/gitea/modules/context"
  14. "code.gitea.io/gitea/modules/log"
  15. "code.gitea.io/gitea/modules/setting"
  16. "code.gitea.io/gitea/modules/storage"
  17. uuid "github.com/satori/go.uuid"
  18. )
  19. func SaveModelByParameters(trainTaskId string, name string, version string, label string, description string, userId int64) {
  20. aiTask, err := models.GetCloudbrainByJobID(trainTaskId)
  21. if err != nil {
  22. log.Info("query task error." + err.Error())
  23. //ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  24. return
  25. }
  26. uuid := uuid.NewV4()
  27. id := uuid.String()
  28. modelPath := id
  29. parent := id
  30. var modelSize int64
  31. cloudType := models.TypeCloudBrainTwo
  32. log.Info("find task name:" + aiTask.JobName)
  33. aimodels := models.QueryModelByName(name, userId)
  34. if len(aimodels) > 0 {
  35. for _, model := range aimodels {
  36. if model.ID == model.Parent {
  37. parent = model.ID
  38. }
  39. }
  40. }
  41. cloudType = aiTask.Type
  42. //download model zip //train type
  43. if cloudType == models.TypeCloudBrainTwo {
  44. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  45. if err == nil {
  46. } else {
  47. log.Info("download model from CloudBrainTwo faild." + err.Error())
  48. //ctx.Error(500, fmt.Sprintf("%v", err))
  49. return
  50. }
  51. }
  52. model := &models.AiModelManage{
  53. ID: id,
  54. Version: version,
  55. Label: label,
  56. Name: name,
  57. Description: description,
  58. Parent: parent,
  59. Type: cloudType,
  60. Path: modelPath,
  61. Size: modelSize,
  62. AttachmentId: aiTask.Uuid,
  63. RepoId: aiTask.RepoID,
  64. UserId: userId,
  65. }
  66. models.SaveModelToDb(model)
  67. log.Info("save model end.")
  68. }
  69. func SaveModel(ctx *context.Context) {
  70. log.Info("save model start.")
  71. trainTaskId := ctx.QueryInt64("TrainTask")
  72. name := ctx.Query("Name")
  73. version := ctx.Query("Version")
  74. label := ctx.Query("Label")
  75. description := ctx.Query("Description")
  76. aiTasks, _, err := models.Cloudbrains(&models.CloudbrainsOptions{
  77. JobID: trainTaskId,
  78. })
  79. if err != nil {
  80. log.Info("query task error." + err.Error())
  81. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  82. return
  83. }
  84. uuid := uuid.NewV4()
  85. id := uuid.String()
  86. modelPath := id
  87. parent := id
  88. var modelSize int64
  89. cloudType := models.TypeCloudBrainTwo
  90. if len(aiTasks) != 1 {
  91. log.Info("query task error. len=" + fmt.Sprint(len(aiTasks)))
  92. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  93. return
  94. }
  95. aiTask := aiTasks[0]
  96. log.Info("find task name:" + aiTask.JobName)
  97. aimodels := models.QueryModelByName(name, ctx.User.ID)
  98. if len(aimodels) > 0 {
  99. for _, model := range aimodels {
  100. if model.ID == model.Parent {
  101. parent = model.ID
  102. }
  103. }
  104. }
  105. cloudType = aiTask.Cloudbrain.Type
  106. //download model zip //train type
  107. if cloudType == models.TypeCloudBrainTrainJob {
  108. modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "")
  109. if err == nil {
  110. } else {
  111. log.Info("download model from CloudBrainTwo faild." + err.Error())
  112. ctx.Error(500, fmt.Sprintf("%v", err))
  113. return
  114. }
  115. }
  116. model := &models.AiModelManage{
  117. ID: id,
  118. Version: version,
  119. Label: label,
  120. Name: name,
  121. Description: description,
  122. Parent: parent,
  123. Type: cloudType,
  124. Path: modelPath,
  125. Size: modelSize,
  126. AttachmentId: aiTask.Uuid,
  127. RepoId: aiTask.RepoID,
  128. UserId: ctx.User.ID,
  129. }
  130. models.SaveModelToDb(model)
  131. log.Info("save model end.")
  132. }
  133. func downloadModelFromCloudBrainTwo(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  134. dataActualPath := setting.Bucket + "/" +
  135. "aimodels/" +
  136. models.AttachmentRelativePath(modelUUID) +
  137. "/"
  138. modelDbResult, err := storage.GetObsListObject(jobName, parentDir)
  139. if err != nil {
  140. log.Info("get TrainJobListModel failed:", err)
  141. return "", 0, err
  142. }
  143. if len(modelDbResult) == 0 {
  144. return "", 0, errors.New("cannot create model, as model is empty.")
  145. }
  146. prefix := strings.TrimPrefix(path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, parentDir), "/")
  147. for _, modelFile := range modelDbResult {
  148. destKeyNamePrefix := "aimodels/" + models.AttachmentRelativePath(modelUUID) + "/"
  149. log.Info("copy file, bucket=" + setting.Bucket + ", src keyname=" + prefix + modelFile.FileName)
  150. log.Info("Dest key name=" + destKeyNamePrefix + modelFile.FileName)
  151. err := storage.ObsCopyFile(setting.Bucket, prefix+modelFile.FileName, setting.Bucket, destKeyNamePrefix+modelFile.FileName)
  152. if err != nil {
  153. log.Info("copy failed.")
  154. }
  155. }
  156. return dataActualPath, 0, nil
  157. }
  158. func DeleteModel(ctx *context.Context) {
  159. log.Info("delete model start.")
  160. id := ctx.Query("ID")
  161. err := DeleteModelByID(id)
  162. if err != nil {
  163. ctx.JSON(500, err.Error())
  164. } else {
  165. ctx.JSON(200, map[string]string{
  166. "result_code": "0",
  167. })
  168. }
  169. }
  170. func DeleteModelByID(id string) error {
  171. log.Info("delete model start. id=" + id)
  172. return models.DeleteModelById(id)
  173. }
  174. func DownloadModel(ctx *context.Context) {
  175. log.Info("download model start.")
  176. }
  177. func QueryModelByParameters(repoId int64, page int) ([]*models.AiModelManage, int64, error) {
  178. return models.QueryModel(&models.AiModelQueryOptions{
  179. ListOptions: models.ListOptions{
  180. Page: page,
  181. PageSize: setting.UI.IssuePagingNum,
  182. },
  183. RepoID: repoId,
  184. Type: -1,
  185. })
  186. }
  187. func ShowModelInfo(ctx *context.Context) {
  188. log.Info("ShowModelInfo start.")
  189. page := ctx.QueryInt("page")
  190. if page <= 0 {
  191. page = 1
  192. }
  193. repoId := ctx.QueryInt64("repoId")
  194. Type := -1
  195. modelResult, count, err := models.QueryModel(&models.AiModelQueryOptions{
  196. ListOptions: models.ListOptions{
  197. Page: page,
  198. PageSize: setting.UI.IssuePagingNum,
  199. },
  200. RepoID: repoId,
  201. Type: Type,
  202. })
  203. if err != nil {
  204. ctx.ServerError("Cloudbrain", err)
  205. return
  206. }
  207. pager := context.NewPagination(int(count), setting.UI.IssuePagingNum, page, 5)
  208. pager.SetDefaultParams(ctx)
  209. ctx.Data["Page"] = pager
  210. ctx.Data["PageIsCloudBrain"] = true
  211. ctx.Data["Tasks"] = modelResult
  212. ctx.HTML(200, "")
  213. }
  214. func downloadModelFromCloudBrainOne(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  215. modelActualPath := setting.Attachment.Minio.RealPath +
  216. setting.Attachment.Minio.Bucket + "/" +
  217. "aimodels/" +
  218. models.AttachmentRelativePath(modelUUID) +
  219. "/"
  220. os.MkdirAll(modelActualPath, 0755)
  221. zipFile := modelActualPath + "model.zip"
  222. modelDir := setting.JobPath + jobName + "/model/"
  223. dir, _ := ioutil.ReadDir(modelDir)
  224. if len(dir) == 0 {
  225. return "", 0, errors.New("cannot create model, as model is empty.")
  226. }
  227. err := zipDir(modelDir, zipFile)
  228. if err != nil {
  229. return "", 0, err
  230. }
  231. fi, err := os.Stat(zipFile)
  232. if err == nil {
  233. return modelActualPath, fi.Size(), nil
  234. } else {
  235. return "", 0, err
  236. }
  237. }
  // zipDir writes every non-directory entry found under dir (recursively) into a
// new zip archive at zipFile, truncating any existing file. Entry names are
// paths relative to dir, using forward slashes as the zip format requires, and
// dir may be given with or without a trailing separator. Directories are not
// stored as entries. Errors are returned with context instead of being logged
// here.
func zipDir(dir, zipFile string) (err error) {
	fz, err := os.Create(zipFile)
	if err != nil {
		return fmt.Errorf("create zip file %q: %w", zipFile, err)
	}
	defer func() {
		if cerr := fz.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("close zip file %q: %w", zipFile, cerr)
		}
	}()

	w := zip.NewWriter(fz)
	walkErr := filepath.Walk(dir, func(p string, info os.FileInfo, werr error) error {
		if werr != nil {
			// info may be nil when werr is set, so stop before touching it.
			return fmt.Errorf("walk %q: %w", p, werr)
		}
		if info.IsDir() {
			return nil
		}
		return zipDirAddFile(w, dir, p)
	})
	if walkErr != nil {
		// The archive is already incomplete; the walk error is the one to report.
		_ = w.Close()
		return walkErr
	}
	// Close writes the zip central directory; without it the archive is unreadable.
	if err := w.Close(); err != nil {
		return fmt.Errorf("finalize zip file %q: %w", zipFile, err)
	}
	return nil
}

// zipDirAddFile stores the file at filePath in w under its slash-separated path relative to root.
func zipDirAddFile(w *zip.Writer, root, filePath string) error {
	rel, err := filepath.Rel(root, filePath)
	if err != nil {
		return fmt.Errorf("relative path of %q: %w", filePath, err)
	}
	src, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open %q: %w", filePath, err)
	}
	defer src.Close()
	dst, err := w.Create(filepath.ToSlash(rel))
	if err != nil {
		return fmt.Errorf("create zip entry %q: %w", rel, err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		return fmt.Errorf("copy %q into zip: %w", filePath, err)
	}
	return nil
}