diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index dde17c059..d66aa698d 100755 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -880,6 +880,8 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/log", repo.TrainJobGetLog) m.Post("/del_version", repo.DelTrainJobVersion) m.Post("/stop_version", repo.StopTrainJobVersion) + m.Get("/model_list", repo.ModelList) + m.Get("/model_download", repo.ModelDownload) // m.Group("/:version-name", func() { // m.Get("", repo.GetModelArtsTrainJobVersion) // }) diff --git a/routers/api/v1/repo/modelarts.go b/routers/api/v1/repo/modelarts.go index aea529da2..ba4015d3b 100755 --- a/routers/api/v1/repo/modelarts.go +++ b/routers/api/v1/repo/modelarts.go @@ -8,12 +8,14 @@ package repo import ( "net/http" "strconv" + "strings" "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/modelarts" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/storage" ) func GetModelArtsNotebook(ctx *context.APIContext) { @@ -298,3 +300,61 @@ func StopTrainJobVersion(ctx *context.APIContext) { "StatusOK": 0, }) } + +func ModelList(ctx *context.APIContext) { + var ( + err error + ) + + var jobID = ctx.Params(":jobid") + var versionName = ctx.Query("version_name") + parentDir := ctx.Query("parentDir") + dirArray := strings.Split(parentDir, "/") + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + if err != nil { + log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) + return + } + VersionOutputPath := "V" + strconv.Itoa(task.TotalVersionCount) + models, err := storage.GetObsListObjectVersion(task.JobName, parentDir, VersionOutputPath) + if err != nil { + log.Info("get TrainJobListModel failed:", err) + ctx.ServerError("GetObsListObject:", err) + return + } + + ctx.JSON(http.StatusOK, map[string]interface{}{ + "JobID": jobID, + "VersionName": versionName, + "StatusOK": 0, + "Path": dirArray, + "Dirs": models, + "task": task, + "PageIsCloudBrain": true, + }) +} + +func ModelDownload(ctx *context.APIContext) { + var ( + err error + ) + + var jobID = ctx.Params(":jobid") + versionName := ctx.Query("version_name") + parentDir := ctx.Query("parent_dir") + fileName := ctx.Query("file_name") + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + if err != nil { + log.Error("GetCloudbrainByJobID(%s) failed:%v", task.JobName, err.Error()) + return + } + VersionOutputPath := "V" + strconv.Itoa(task.TotalVersionCount) + + url, err := storage.GetObsCreateVersionSignedUrl(task.JobName, parentDir, fileName, VersionOutputPath) + if err != nil { + log.Error("GetObsCreateSignedUrl failed: %v", err.Error(), ctx.Data["msgID"]) + ctx.ServerError("GetObsCreateSignedUrl", err) + return + } + http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) +}