diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index f35601191..5f318a546 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -598,6 +598,11 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job CreatedUnix: createTime, UpdatedUnix: createTime, Spec: req.Spec, + ModelName: req.ModelName, + ModelVersion: req.ModelVersion, + LabelName: req.LabelName, + PreTrainModelUrl: req.PreTrainModelUrl, + CkptName: req.CkptName, }) if createErr != nil { log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, createErr.Error()) diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index cb4b2c1cc..13ae93dcf 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1656,6 +1656,14 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ }) } + if form.ModelName != "" { //使用预训练模型训练 + ckptUrl := "/" + form.PreTrainModelUrl + form.CkptName + param = append(param, models.Parameter{ + Label: modelarts.CkptUrl, + Value: "s3:/" + ckptUrl, + }) + } + // //save param config // if isSaveParam == "on" { // saveparams := append(param, models.Parameter{ @@ -1730,6 +1738,15 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ DatasetName: datasetNames, Spec: spec, } + + if form.ModelName != "" { //使用预训练模型训练 + req.ModelName = form.ModelName + req.LabelName = form.LabelName + req.CkptName = form.CkptName + req.ModelVersion = form.ModelVersion + req.PreTrainModelUrl = form.PreTrainModelUrl + + } userCommand, userImageUrl := getUserCommand(engineID, req) req.UserCommand = userCommand req.UserImageUrl = userImageUrl