From 9df8fa9f24032d1cae07f8f034f5dc233a2f126d Mon Sep 17 00:00:00 2001 From: ychao_1983 Date: Tue, 13 Sep 2022 14:48:34 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modules/modelarts/modelarts.go | 132 ++++++++++++++++++--------------- routers/repo/modelarts.go | 16 ++++ 2 files changed, 87 insertions(+), 61 deletions(-) diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 4539699ad..ead824b60 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -75,35 +75,40 @@ var ( ) type GenerateTrainJobReq struct { - JobName string - DisplayJobName string - Uuid string - Description string - CodeObsPath string - BootFile string - BootFileUrl string - DataUrl string - TrainUrl string - LogUrl string - PoolID string - WorkServerNumber int - EngineID int64 - Parameters []models.Parameter - CommitID string - IsLatestVersion string - Params string - BranchName string - PreVersionId int64 - PreVersionName string - FlavorCode string - FlavorName string - VersionCount int - EngineName string - TotalVersionCount int - UserImageUrl string - UserCommand string - DatasetName string - Spec *models.Specification + JobName string + DisplayJobName string + Uuid string + Description string + CodeObsPath string + BootFile string + BootFileUrl string + DataUrl string + TrainUrl string + LogUrl string + PoolID string + WorkServerNumber int + EngineID int64 + Parameters []models.Parameter + CommitID string + IsLatestVersion string + Params string + BranchName string + PreVersionId int64 + PreVersionName string + FlavorCode string + FlavorName string + VersionCount int + EngineName string + TotalVersionCount int + UserImageUrl string + UserCommand string + DatasetName string + Spec *models.Specification + ModelName string + LabelName string + CkptName string + ModelVersion string + PreTrainingModelUrl string } type GenerateInferenceJobReq struct { @@ -407,38 +412,43 @@ func GenerateTrainJob(ctx *context.Context, req *GenerateTrainJobReq) (err error } jobId := strconv.FormatInt(jobResult.JobID, 10) createErr = models.CreateCloudbrain(&models.Cloudbrain{ - Status: TransTrainJobStatus(jobResult.Status), - UserID: ctx.User.ID, - RepoID: ctx.Repo.Repository.ID, - JobID: jobId, - JobName: req.JobName, - DisplayJobName: req.DisplayJobName, - JobType: string(models.JobTypeTrain), - Type: models.TypeCloudBrainTwo, - VersionID: jobResult.VersionID, - VersionName: jobResult.VersionName, - Uuid: req.Uuid, - DatasetName: req.DatasetName, - CommitID: req.CommitID, - IsLatestVersion: req.IsLatestVersion, - ComputeResource: models.NPUResource, - EngineID: req.EngineID, - TrainUrl: req.TrainUrl, - BranchName: req.BranchName, - Parameters: req.Params, - BootFile: req.BootFile, - DataUrl: req.DataUrl, - LogUrl: req.LogUrl, - FlavorCode: req.Spec.SourceSpecId, - Description: req.Description, - WorkServerNumber: req.WorkServerNumber, - FlavorName: req.FlavorName, - EngineName: req.EngineName, - VersionCount: req.VersionCount, - TotalVersionCount: req.TotalVersionCount, - CreatedUnix: createTime, - UpdatedUnix: createTime, - Spec: req.Spec, + Status: TransTrainJobStatus(jobResult.Status), + UserID: ctx.User.ID, + RepoID: ctx.Repo.Repository.ID, + JobID: jobId, + JobName: req.JobName, + DisplayJobName: req.DisplayJobName, + JobType: string(models.JobTypeTrain), + Type: models.TypeCloudBrainTwo, + VersionID: jobResult.VersionID, + VersionName: jobResult.VersionName, + Uuid: req.Uuid, + DatasetName: req.DatasetName, + CommitID: req.CommitID, + IsLatestVersion: req.IsLatestVersion, + ComputeResource: models.NPUResource, + EngineID: req.EngineID, + TrainUrl: req.TrainUrl, + BranchName: req.BranchName, + Parameters: req.Params, + BootFile: req.BootFile, + DataUrl: req.DataUrl, + LogUrl: req.LogUrl, + FlavorCode: req.Spec.SourceSpecId, + Description: req.Description, + WorkServerNumber: req.WorkServerNumber, + FlavorName: req.FlavorName, + EngineName: req.EngineName, + VersionCount: req.VersionCount, + TotalVersionCount: req.TotalVersionCount, + CreatedUnix: createTime, + UpdatedUnix: createTime, + Spec: req.Spec, + ModelName: req.ModelName, + ModelVersion: req.ModelVersion, + LabelName: req.LabelName, + PreTrainingModelUrl: req.PreTrainingModelUrl, + CkptName: req.CkptName, }) if createErr != nil { diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index b4f6f000e..121c4fd78 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1290,6 +1290,13 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) Value: string(jsondatas), }) } + if form.ModelName != "" { //使用预训练模型训练 + ckptUrl := "/" + form.PreTrainModelUrl + form.CkptName + param = append(param, models.Parameter{ + Label: modelarts.CkptUrl, + Value: "s3:/" + ckptUrl, + }) + } //save param config // if isSaveParam == "on" { @@ -1358,6 +1365,15 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm) DatasetName: datasetNames, Spec: spec, } + if form.ModelName != "" { //使用预训练模型训练 + req.ModelName = form.ModelName + req.LabelName = form.LabelName + req.CkptName = form.CkptName + req.ModelVersion = form.ModelVersion + req.PreTrainingModelUrl = form.PreTrainModelUrl + + } + userCommand, userImageUrl := getUserCommand(engineID, req) req.UserCommand = userCommand req.UserImageUrl = userImageUrl