Browse Source

刷新训练版本的状态和时长

tags/v1.21.12.1
liuzx 4 years ago
parent
commit
57b1a05ab5
3 changed files with 57 additions and 13 deletions
  1. +5
    -0
      models/cloudbrain.go
  2. +3
    -0
      routers/api/v1/api.go
  3. +49
    -13
      routers/api/v1/repo/modelarts.go

+ 5
- 0
models/cloudbrain.go View File

@@ -1066,6 +1066,11 @@ func GetRepoCloudBrainByJobID(repoID int64, jobID string) (*Cloudbrain, error) {
return getRepoCloudBrain(cb)
}

func GetRepoCloudBrainByJobIDAndVersionName(repoID int64, jobID string, versionName string) (*Cloudbrain, error) {
cb := &Cloudbrain{JobID: jobID, RepoID: repoID, VersionName: versionName}
return getRepoCloudBrain(cb)
}

func GetCloudbrainByJobID(jobID string) (*Cloudbrain, error) {
cb := &Cloudbrain{JobID: jobID}
return getRepoCloudBrain(cb)


+ 3
- 0
routers/api/v1/api.go View File

@@ -875,6 +875,9 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Group("/:jobid", func() {
m.Get("", repo.GetModelArtsTrainJob)
m.Get("/log", repo.TrainJobGetLog)
m.Group("/:version-name", func() {
m.Get("", repo.GetModelArtsTrainJobVersion)
})
})
})
}, reqRepoReader(models.UnitTypeCloudBrain))


+ 49
- 13
routers/api/v1/repo/modelarts.go View File

@@ -6,12 +6,13 @@
package repo

import (
"net/http"
"strconv"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/modelarts"
"net/http"
"strconv"
)

func GetModelArtsNotebook(ctx *context.APIContext) {
@@ -72,9 +73,44 @@ func GetModelArtsTrainJob(ctx *context.APIContext) {
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"JobStatus": job.Status,
"JobDuration": job.Duration,
"JobID": jobID,
"JobStatus": job.Status,
"JobDuration": job.Duration,
})

}

func GetModelArtsTrainJobVersion(ctx *context.APIContext) {
var (
err error
)

jobID := ctx.Params(":jobid")
versionName := ctx.Params(":version-name")
repoID := ctx.Repo.Repository.ID
job, err := models.GetRepoCloudBrainByJobIDAndVersionName(repoID, jobID, versionName)
if err != nil {
ctx.NotFound(err)
return
}
result, err := modelarts.GetTrainJob(jobID, strconv.FormatInt(job.VersionID, 10))
if err != nil {
ctx.NotFound(err)
return
}

job.Status = modelarts.TransTrainJobStatus(result.IntStatus)
job.Duration = result.Duration
job.TrainJobDuration = result.TrainJobDuration
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"JobStatus": job.Status,
"JobDuration": job.Duration,
})

}
@@ -94,7 +130,7 @@ func TrainJobGetLog(ctx *context.APIContext) {
if order != modelarts.OrderDesc && order != modelarts.OrderAsc {
log.Error("order(%s) check failed", order)
ctx.JSON(http.StatusBadRequest, map[string]interface{}{
"err_msg": "order check failed",
"err_msg": "order check failed",
})
return
}
@@ -103,7 +139,7 @@ func TrainJobGetLog(ctx *context.APIContext) {
if err != nil {
log.Error("GetCloudbrainByJobID(%s) failed:%v", jobID, err.Error())
ctx.JSON(http.StatusInternalServerError, map[string]interface{}{
"err_msg": "GetCloudbrainByJobID failed",
"err_msg": "GetCloudbrainByJobID failed",
})
return
}
@@ -112,16 +148,16 @@ func TrainJobGetLog(ctx *context.APIContext) {
if err != nil {
log.Error("GetTrainJobLog(%s) failed:%v", jobID, err.Error())
ctx.JSON(http.StatusInternalServerError, map[string]interface{}{
"err_msg": "GetTrainJobLog failed",
"err_msg": "GetTrainJobLog failed",
})
return
}

ctx.JSON(http.StatusOK, map[string]interface{}{
"JobID": jobID,
"StartLine": result.StartLine,
"EndLine": result.EndLine,
"Content": result.Content,
"Lines": result.Lines,
"JobID": jobID,
"StartLine": result.StartLine,
"EndLine": result.EndLine,
"Content": result.Content,
"Lines": result.Lines,
})
}

Loading…
Cancel
Save