Browse Source

提交代码

tags/v1.22.9.2^2
ychao_1983 3 years ago
parent
commit
cc8b76eaca
2 changed files with 22 additions and 0 deletions
  1. +5
    -0
      modules/modelarts/modelarts.go
  2. +17
    -0
      routers/repo/modelarts.go

+ 5
- 0
modules/modelarts/modelarts.go View File

@@ -598,6 +598,11 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job
CreatedUnix: createTime,
UpdatedUnix: createTime,
Spec: req.Spec,
ModelName: req.ModelName,
ModelVersion: req.ModelVersion,
LabelName: req.LabelName,
PreTrainModelUrl: req.PreTrainModelUrl,
CkptName: req.CkptName,
})
if createErr != nil {
log.Error("CreateCloudbrain(%s) failed:%v", req.JobName, createErr.Error())


+ 17
- 0
routers/repo/modelarts.go View File

@@ -1656,6 +1656,14 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
})
}

if form.ModelName != "" { //使用预训练模型训练
ckptUrl := "/" + form.PreTrainModelUrl + form.CkptName
param = append(param, models.Parameter{
Label: modelarts.CkptUrl,
Value: "s3:/" + ckptUrl,
})
}

// //save param config
// if isSaveParam == "on" {
// saveparams := append(param, models.Parameter{
@@ -1730,6 +1738,15 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
DatasetName: datasetNames,
Spec: spec,
}

if form.ModelName != "" { //使用预训练模型训练
req.ModelName = form.ModelName
req.LabelName = form.LabelName
req.CkptName = form.CkptName
req.ModelVersion = form.ModelVersion
req.PreTrainModelUrl = form.PreTrainModelUrl

}
userCommand, userImageUrl := getUserCommand(engineID, req)
req.UserCommand = userCommand
req.UserImageUrl = userImageUrl


Loading…
Cancel
Save