Browse Source

Merge branch 'liuzx_trainjob' of https://git.openi.org.cn/OpenI/aiforge into liuzx_trainjob

tags/v1.21.11.1
zhoupzh 4 years ago
parent
commit
cefd602dc8
4 changed files with 45 additions and 17 deletions
  1. +3
    -3
      modules/modelarts/modelarts.go
  2. +40
    -12
      routers/repo/cloudbrain.go
  3. +1
    -1
      routers/routes/routes.go
  4. +1
    -1
      templates/repo/modelarts/trainjob/index.tmpl

+ 3
- 3
modules/modelarts/modelarts.go View File

@@ -221,7 +221,7 @@ func TransTrainJobStatus(status int) string {
case 0:
return "UNKNOWN"
case 1:
return "INIT"
return "CREATING"
case 2:
return "IMAGE_CREATING"
case 3:
@@ -237,13 +237,13 @@ func TransTrainJobStatus(status int) string {
case 8:
return "RUNNING"
case 9:
return "KILLING"
return "STOPED"
case 10:
return "COMPLETED"
case 11:
return "FAILED"
case 12:
return "KILLED"
return "STOPED"
case 13:
return "CANCELED"
case 14:


+ 40
- 12
routers/repo/cloudbrain.go View File

@@ -202,7 +202,7 @@ func CloudBrainCreate(ctx *context.Context, form auth.CreateCloudBrainForm) {
gpuQueue := setting.JobType
codePath := setting.JobPath + jobName + cloudbrain.CodeMountPath
resourceSpecId := form.ResourceSpecId
if !jobNamePattern.MatchString(jobName) {
ctx.RenderWithErr(ctx.Tr("repo.cloudbrain_jobname_err"), tplModelArtsNew, &form)
return
@@ -474,7 +474,7 @@ func CloudBrainDel(ctx *context.Context) {
return
}

if task.Status != string(models.JobStopped) && task.Status != string(models.JobFailed){
if task.Status != string(models.JobStopped) && task.Status != string(models.JobFailed) {
log.Error("the job(%s) has not been stopped", task.JobName, ctx.Data["msgID"])
ctx.ServerError("the job has not been stopped", errors.New("the job has not been stopped"))
return
@@ -584,19 +584,47 @@ func CloudBrainDownloadModel(ctx *context.Context) {
http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}

// func TrainJobloadModel(ctx *context.Context) {
// parentDir := ctx.Query("parentDir")
// fileName := ctx.Query("fileName")
// jobName := ctx.Query("jobName")
// filePath := "jobs/" + jobName + "/model/" + parentDir
// url, err := storage.Attachments.PresignedGetURL(filePath, fileName)
// if err != nil {
// log.Error("PresignedGetURL failed: %v", err.Error(), ctx.Data["msgID"])
// ctx.ServerError("PresignedGetURL", err)
// return
// }

// http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
// }

func TrainJobloadModel(ctx *context.Context) {
parentDir := ctx.Query("parentDir")
fileName := ctx.Query("fileName")
jobName := ctx.Query("jobName")
filePath := "jobs/" + jobName + "/model/" + parentDir
url, err := storage.Attachments.PresignedGetURL(filePath, fileName)
uuid := ctx.Query("uuid")
fileName := ctx.Query("file_name")

body, err := storage.ObsDownload(uuid, fileName)
if err != nil {
log.Error("PresignedGetURL failed: %v", err.Error(), ctx.Data["msgID"])
ctx.ServerError("PresignedGetURL", err)
return
log.Info("download error.")
} else {
defer body.Close()
ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+fileName)
ctx.Resp.Header().Set("Content-Type", "application/octet-stream")
p := make([]byte, 1024)
var readErr error
var readCount int
// 读取对象内容
for {
readCount, readErr = body.Read(p)
if readCount > 0 {
ctx.Resp.Write(p[:readCount])
//fmt.Printf("%s", p[:readCount])
}
if readErr != nil {
break
}
}
}

http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently)
}

func GetRate(ctx *context.Context) {


+ 1
- 1
routers/routes/routes.go View File

@@ -991,11 +991,11 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Post("/stop", reqRepoCloudBrainWriter, repo.TrainJobStop)
m.Post("/del", reqRepoCloudBrainWriter, repo.TrainJobDel)
m.Get("/log", reqRepoCloudBrainReader, repo.TrainJobGetLog)
m.Get("/download_model", reqRepoCloudBrainReader, repo.TrainJobloadModel)
})
m.Get("/create", reqRepoCloudBrainReader, repo.TrainJobNew)
m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsTrainJobForm{}), repo.TrainJobCreate)
m.Get("/para-config-list", reqRepoCloudBrainReader, repo.TrainJobGetConfigList)
m.Get("/download_model", reqRepoCloudBrainReader, repo.TrainJobloadModel)
})
}, context.RepoRef())



+ 1
- 1
templates/repo/modelarts/trainjob/index.tmpl View File

@@ -360,7 +360,7 @@
</div>
<div class="ui compact buttons" style="margin-right:10px;">
<!-- 模型下载 -->
<a class="ui basic blue button" href="{{$.Link}}/{{.JobID}}/models" target="_blank">
<a class="ui basic blue button" href="{{$.Link}}/{{.JobID}}/download_model" target="_blank">
模型下载
</a>


Loading…
Cancel
Save