Browse Source

Merge pull request 'fix-1710' (#1736) from fix-1710 into V20220328

Reviewed-on: https://git.openi.org.cn/OpenI/aiforge/pulls/1736
Reviewed-by: lewis <747342561@qq.com>
tags/v1.22.3.2^2
lewis 3 years ago
parent
commit
8148bdf4ef
2 changed files with 50 additions and 28 deletions
  1. +2
    -0
      modules/modelarts/modelarts.go
  2. +48
    -28
      routers/repo/modelarts.go

+ 2
- 0
modules/modelarts/modelarts.go View File

@@ -51,6 +51,8 @@ const (
DataUrl = "data_url"
ResultUrl = "result_url"
CkptUrl = "ckpt_url"
DeviceTarget = "device_target"
Ascend = "Ascend"
PerPage = 10
IsLatestVersion = "1"
NotLatestVersion = "0"


+ 48
- 28
routers/repo/modelarts.go View File

@@ -962,17 +962,9 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
return
}

//todo: del local code?

var parameters models.Parameters
param := make([]models.Parameter, 0)
param = append(param, models.Parameter{
Label: modelarts.TrainUrl,
Value: outputObsPath,
}, models.Parameter{
Label: modelarts.DataUrl,
Value: dataPath,
})
existDeviceTarget := false
if len(params) != 0 {
err := json.Unmarshal([]byte(params), &parameters)
if err != nil {
@@ -983,6 +975,9 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
}

for _, parameter := range parameters.Parameter {
if parameter.Label == modelarts.DeviceTarget {
existDeviceTarget = true
}
if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
param = append(param, models.Parameter{
Label: parameter.Label,
@@ -991,9 +986,22 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
}
}
}
if !existDeviceTarget {
param = append(param, models.Parameter{
Label: modelarts.DeviceTarget,
Value: modelarts.Ascend,
})
}

//save param config
if isSaveParam == "on" {
saveparams := append(param, models.Parameter{
Label: modelarts.TrainUrl,
Value: outputObsPath,
}, models.Parameter{
Label: modelarts.DataUrl,
Value: dataPath,
})
if form.ParameterTemplateName == "" {
log.Error("ParameterTemplateName is empty")
trainJobNewDataPrepare(ctx)
@@ -1015,7 +1023,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
EngineID: int64(engineID),
LogUrl: logObsPath,
PoolID: poolID,
Parameter: param,
Parameter: saveparams,
})

if err != nil {
@@ -1041,7 +1049,7 @@ func TrainJobCreate(ctx *context.Context, form auth.CreateModelArtsTrainJobForm)
LogUrl: logObsPath,
PoolID: poolID,
Uuid: uuid,
Parameters: parameters.Parameter,
Parameters: param,
CommitID: commitID,
IsLatestVersion: isLatestVersion,
BranchName: branch_name,
@@ -1177,13 +1185,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ

var parameters models.Parameters
param := make([]models.Parameter, 0)
param = append(param, models.Parameter{
Label: modelarts.TrainUrl,
Value: outputObsPath,
}, models.Parameter{
Label: modelarts.DataUrl,
Value: dataPath,
})
existDeviceTarget := true
if len(params) != 0 {
err := json.Unmarshal([]byte(params), &parameters)
if err != nil {
@@ -1192,8 +1194,10 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
ctx.RenderWithErr("运行参数错误", tplModelArtsTrainJobVersionNew, &form)
return
}

for _, parameter := range parameters.Parameter {
if parameter.Label == modelarts.DeviceTarget {
existDeviceTarget = true
}
if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
param = append(param, models.Parameter{
Label: parameter.Label,
@@ -1202,9 +1206,22 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
}
}
}
if !existDeviceTarget {
param = append(param, models.Parameter{
Label: modelarts.DeviceTarget,
Value: modelarts.Ascend,
})
}

//save param config
if isSaveParam == "on" {
saveparams := append(param, models.Parameter{
Label: modelarts.TrainUrl,
Value: outputObsPath,
}, models.Parameter{
Label: modelarts.DataUrl,
Value: dataPath,
})
if form.ParameterTemplateName == "" {
log.Error("ParameterTemplateName is empty")
versionErrorDataPrepare(ctx, form)
@@ -1226,7 +1243,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
EngineID: int64(engineID),
LogUrl: logObsPath,
PoolID: poolID,
Parameter: parameters.Parameter,
Parameter: saveparams,
})

if err != nil {
@@ -1237,12 +1254,6 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
}
}

if err != nil {
log.Error("getFlavorNameByEngineID(%s) failed:%v", engineID, err.Error())
ctx.RenderWithErr(err.Error(), tplModelArtsTrainJobVersionNew, &form)
return
}

task, err := models.GetCloudbrainByJobIDAndVersionName(jobID, PreVersionName)
if err != nil {
log.Error("GetCloudbrainByJobIDAndVersionName(%s) failed:%v", jobID, err.Error())
@@ -1266,7 +1277,7 @@ func TrainJobCreateVersion(ctx *context.Context, form auth.CreateModelArtsTrainJ
PoolID: poolID,
Uuid: uuid,
Params: form.Params,
Parameters: parameters.Parameter,
Parameters: param,
PreVersionId: task.VersionID,
CommitID: commitID,
BranchName: branch_name,
@@ -1791,7 +1802,6 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
return
}

//todo: del local code?
var parameters models.Parameters
param := make([]models.Parameter, 0)
param = append(param, models.Parameter{
@@ -1801,6 +1811,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
Label: modelarts.CkptUrl,
Value: "s3:/" + ckptUrl,
})
existDeviceTarget := false
if len(params) != 0 {
err := json.Unmarshal([]byte(params), &parameters)
if err != nil {
@@ -1811,6 +1822,9 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
}

for _, parameter := range parameters.Parameter {
if parameter.Label == modelarts.DeviceTarget {
existDeviceTarget = true
}
if parameter.Label != modelarts.TrainUrl && parameter.Label != modelarts.DataUrl {
param = append(param, models.Parameter{
Label: parameter.Label,
@@ -1819,6 +1833,12 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference
}
}
}
if !existDeviceTarget {
param = append(param, models.Parameter{
Label: modelarts.DeviceTarget,
Value: modelarts.Ascend,
})
}

req := &modelarts.GenerateInferenceJobReq{
JobName: jobName,


Loading…
Cancel
Save