You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ai_model_manage.go 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package repo
  2. import (
  3. "archive/zip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "os"
  9. "path/filepath"
  10. "code.gitea.io/gitea/models"
  11. "code.gitea.io/gitea/modules/context"
  12. "code.gitea.io/gitea/modules/log"
  13. "code.gitea.io/gitea/modules/setting"
  14. uuid "github.com/satori/go.uuid"
  15. )
  16. func SaveModel(ctx *context.Context) {
  17. log.Info("save model start.")
  18. trainTaskId := ctx.QueryInt64("TrainTask")
  19. name := ctx.Query("Name")
  20. version := ctx.Query("Version")
  21. label := ctx.Query("Label")
  22. description := ctx.Query("Description")
  23. aiTasks, _, err := models.Cloudbrains(&models.CloudbrainsOptions{
  24. JobID: trainTaskId,
  25. })
  26. if err != nil {
  27. log.Info("query task error." + err.Error())
  28. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  29. return
  30. }
  31. uuid := uuid.NewV4()
  32. id := uuid.String()
  33. modelPath := id
  34. parent := id
  35. var modelSize int64
  36. cloudType := models.TypeCloudBrainTwo
  37. if len(aiTasks) != 1 {
  38. log.Info("query task error. len=" + fmt.Sprint(len(aiTasks)))
  39. ctx.Error(500, fmt.Sprintf("query cloud brain train task error. %v", err))
  40. return
  41. }
  42. aiTask := aiTasks[0]
  43. log.Info("find task name:" + aiTask.JobName)
  44. aimodels := models.QueryModelByName(name, ctx.User.ID)
  45. if len(aimodels) > 0 {
  46. for _, model := range aimodels {
  47. if model.ID == model.Parent {
  48. parent = model.ID
  49. }
  50. }
  51. }
  52. cloudType = aiTask.Cloudbrain.Type
  53. //download model zip
  54. if cloudType == models.TypeCloudBrainOne {
  55. modelPath, modelSize, err = downloadModelFromCloudBrainOne(id, aiTask.JobName, "")
  56. if err != nil {
  57. log.Info("download model from CloudBrainOne faild." + err.Error())
  58. ctx.Error(500, fmt.Sprintf("%v", err))
  59. return
  60. }
  61. } else if cloudType == models.TypeCloudBrainTwo {
  62. modelPath, err = downloadModelFromCloudBrainTwo(id)
  63. if err == nil {
  64. } else {
  65. log.Info("download model from CloudBrainTwo faild." + err.Error())
  66. ctx.Error(500, fmt.Sprintf("%v", err))
  67. return
  68. }
  69. }
  70. model := &models.AiModelManage{
  71. ID: id,
  72. Version: version,
  73. Label: label,
  74. Name: name,
  75. Description: description,
  76. Parent: parent,
  77. Type: cloudType,
  78. Path: modelPath,
  79. Size: modelSize,
  80. AttachmentId: aiTask.Uuid,
  81. RepoId: aiTask.RepoID,
  82. UserId: ctx.User.ID,
  83. }
  84. models.SaveModelToDb(model)
  85. log.Info("save model end.")
  86. }
  87. func downloadModelFromCloudBrainOne(modelUUID string, jobName string, parentDir string) (string, int64, error) {
  88. modelActualPath := setting.Attachment.Minio.RealPath +
  89. setting.Attachment.Minio.Bucket + "/" +
  90. "aimodels/" +
  91. models.AttachmentRelativePath(modelUUID) +
  92. "/"
  93. os.MkdirAll(modelActualPath, 0755)
  94. zipFile := modelActualPath + "model.zip"
  95. modelDir := setting.JobPath + jobName + "/model/"
  96. dir, _ := ioutil.ReadDir(modelDir)
  97. if len(dir) == 0 {
  98. return "", 0, errors.New("cannot create model, as model is empty.")
  99. }
  100. err := zipDir(modelDir, zipFile)
  101. if err != nil {
  102. return "", 0, err
  103. }
  104. fi, err := os.Stat(zipFile)
  105. if err == nil {
  106. return modelActualPath, fi.Size(), nil
  107. } else {
  108. return "", 0, err
  109. }
  110. }
  111. func zipDir(dir, zipFile string) error {
  112. fz, err := os.Create(zipFile)
  113. if err != nil {
  114. log.Info("Create zip file failed: %s\n", err.Error())
  115. return err
  116. }
  117. defer fz.Close()
  118. w := zip.NewWriter(fz)
  119. defer w.Close()
  120. err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
  121. if !info.IsDir() {
  122. fDest, err := w.Create(path[len(dir)+1:])
  123. if err != nil {
  124. log.Info("Create failed: %s\n", err.Error())
  125. return err
  126. }
  127. fSrc, err := os.Open(path)
  128. if err != nil {
  129. log.Info("Open failed: %s\n", err.Error())
  130. return err
  131. }
  132. defer fSrc.Close()
  133. _, err = io.Copy(fDest, fSrc)
  134. if err != nil {
  135. log.Info("Copy failed: %s\n", err.Error())
  136. return err
  137. }
  138. }
  139. return nil
  140. })
  141. if err != nil {
  142. return err
  143. }
  144. return nil
  145. }
  146. func downloadModelFromCloudBrainTwo(modelUUID string) (string, error) {
  147. dataActualPath := setting.Bucket + "/" +
  148. "aimodels/" +
  149. models.AttachmentRelativePath(modelUUID) +
  150. "/"
  151. return dataActualPath, nil
  152. }
  153. func DeleteModel(ctx *context.Context) {
  154. log.Info("delete model start.")
  155. id := ctx.Query("ID")
  156. err := models.DeleteModelById(id)
  157. if err != nil {
  158. ctx.JSON(500, err.Error())
  159. } else {
  160. ctx.JSON(200, map[string]string{
  161. "result_code": "0",
  162. })
  163. }
  164. }
  165. func DownloadModel(ctx *context.Context) {
  166. log.Info("download model start.")
  167. }
  168. func ShowModelInfo(ctx *context.Context) {
  169. log.Info("ShowModelInfo.")
  170. }