From e964d8e452eed9dae574dbbd89ae402e94b4d130 Mon Sep 17 00:00:00 2001 From: liuzx Date: Wed, 10 Nov 2021 18:58:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BC=95=E6=93=8E=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/cloudbrain.go | 1 + modules/auth/modelarts.go | 1 + modules/modelarts/modelarts.go | 4 ++++ routers/repo/modelarts.go | 27 ++++++++++++++++++++------- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/models/cloudbrain.go b/models/cloudbrain.go index e4d8461fb..a061857c4 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -90,6 +90,7 @@ type Cloudbrain struct { FlavorCode string Description string WorkServerNumber int + FlavorName string User *User `xorm:"-"` Repo *Repository `xorm:"-"` diff --git a/modules/auth/modelarts.go b/modules/auth/modelarts.go index a53661b74..3cd8ac637 100755 --- a/modules/auth/modelarts.go +++ b/modules/auth/modelarts.go @@ -40,6 +40,7 @@ type CreateModelArtsTrainJobForm struct { PrameterDescription string `form:"parameter_description"` BranchName string `form:"branch_name" binding:"Required"` VersionName string `form:"version_name" binding:"Required"` + FlavorName string `form:"flavor_name" binding:"Required"` } func (f *CreateModelArtsTrainJobForm) Validate(ctx *macaron.Context, errs binding.Errors) binding.Errors { diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 7d87ca1b2..88378ab10 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -78,6 +78,7 @@ type GenerateTrainJobReq struct { Params string BranchName string FatherVersionName string + FlavorName string } type GenerateTrainJobVersionReq struct { @@ -98,6 +99,7 @@ type GenerateTrainJobVersionReq struct { PreVersionId int64 CommitID string BranchName string + FlavorName string } type VersionInfo struct { @@ -256,6 +258,7 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error FlavorCode: req.FlavorCode, Description: req.Description, WorkServerNumber: req.WorkServerNumber, + FlavorName: req.FlavorName, }) if err != nil { @@ -322,6 +325,7 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobVersionR FlavorCode: req.FlavorCode, Description: req.Description, WorkServerNumber: req.WorkServerNumber, + FlavorName: req.FlavorName, }) if err != nil { log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, err.Error()) diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 8426f99d8..b5b5b07a9 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -620,12 +620,17 @@ func TrainJobNewVersion(ctx *context.Context) { func trainJobNewVersionDataPrepare(ctx *context.Context) error { ctx.Data["PageIsCloudBrain"] = true var jobID = ctx.Params(":jobid") - var versionName = ctx.Query("versionName") - jobID = "19373" + var versionName = ctx.Query("version_name") + + task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, versionName) + if err != nil { + log.Error("GetCloudbrainByJobIDAndVersionName(%s) failed:%v", jobID, err.Error()) + return err + } t := time.Now() var jobName = cutString(ctx.User.Name, 5) + t.Format("2006010215") + strconv.Itoa(int(t.Unix()))[5:] - ctx.Data["job_name"] = jobName + ctx.Data["job_name"] = task.JobName attachs, err := models.GetModelArtsUserAttachments(ctx.User.ID) if err != nil { @@ -670,10 +675,14 @@ func trainJobNewVersionDataPrepare(ctx *context.Context) error { ctx.ServerError("GetBranches error:", err) return err } - ctx.Data["Branches"] = Branches - ctx.Data["BranchesCount"] = len(Branches) - ctx.Data["jobID"] = jobID - ctx.Data["versionName"] = versionName + ctx.Data["branches"] = Branches + ctx.Data["branch_name"] = task.BranchName + ctx.Data["description"] = task.Description + ctx.Data["boot_file"] = task.BootFile + ctx.Data["dataset_name"] = task.DatasetName + ctx.Data["params"] = task.Parameters + ctx.Data["work_server_number"] = task.WorkServerNumber + ctx.Data["flavor_name"] = task.FlavorName configList, err := getConfigList(modelarts.PerPage, 1, modelarts.SortByCreateTime, "desc", "", modelarts.ConfigTypeCustom) if err != nil { @@ -705,6 +714,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" branch_name := form.BranchName isLatestVersion := modelarts.IsLatestVersion + FlavorName := form.FlavorName if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) @@ -851,6 +861,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) BranchName: branch_name, Params: form.Params, FatherVersionName: modelarts.InitFatherVersionName, + FlavorName: FlavorName, } err = modelarts.GenerateTrainJob(ctx, req) @@ -921,6 +932,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ dataPath := "/" + setting.Bucket + "/" + setting.BasePath + path.Join(uuid[0:1], uuid[1:2]) + "/" + uuid + uuid + "/" branch_name := form.BranchName fatherVersionName := form.VersionName + FlavorName := form.FlavorName if err := paramCheckCreateTrainJob(form); err != nil { log.Error("paramCheckCreateTrainJob failed:(%v)", err) @@ -1070,6 +1082,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ PreVersionId: task.VersionID, CommitID: commitID, BranchName: branch_name, + FlavorName: FlavorName, } err = modelarts.GenerateTrainJobVersion(ctx, req, jobID, fatherVersionName) if err != nil {