diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 19258e2c5..fa636803b 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -984,6 +984,11 @@ type CreateTrainJobVersionParams struct { Config TrainJobVersionConfig `json:"config"` } +type CreateTrainJobVersionUserImageParams struct { + Description string `json:"job_desc"` + Config TrainJobVersionUserImageConfig `json:"config"` +} + type TrainJobVersionConfig struct { WorkServerNum int `json:"worker_server_num"` AppUrl string `json:"app_url"` //训练作业的代码目录 @@ -996,6 +1001,19 @@ type TrainJobVersionConfig struct { Flavor Flavor `json:"flavor"` PoolID string `json:"pool_id"` PreVersionId int64 `json:"pre_version_id"` +} + +type TrainJobVersionUserImageConfig struct { + WorkServerNum int `json:"worker_server_num"` + AppUrl string `json:"app_url"` //训练作业的代码目录 + BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 + Parameter []Parameter `json:"parameter"` + DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL + TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL + LogUrl string `json:"log_url"` + Flavor Flavor `json:"flavor"` + PoolID string `json:"pool_id"` + PreVersionId int64 `json:"pre_version_id"` UserImageUrl string `json:"user_image_url"` UserCommand string `json:"user_command"` } diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 9c68421bb..1f39d0fac 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -438,9 +438,9 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job var createErr error log.Info(" req.EngineID =" + fmt.Sprint(req.EngineID)) if req.EngineID < 0 { - jobResult, createErr = createTrainJobVersion(models.CreateTrainJobVersionParams{ + jobResult, createErr = createTrainJobVersionUserImage(models.CreateTrainJobVersionUserImageParams{ Description: req.Description, - Config: models.TrainJobVersionConfig{ + Config: models.TrainJobVersionUserImageConfig{ WorkServerNum: req.WorkServerNumber, AppUrl: req.CodeObsPath, BootFileUrl: req.BootFileUrl, diff --git a/modules/modelarts/resty.go b/modules/modelarts/resty.go index a9c868eaa..46c273a8b 100755 --- a/modules/modelarts/resty.go +++ b/modules/modelarts/resty.go @@ -639,6 +639,61 @@ sendjob: return &result, nil } +func createTrainJobVersionUserImage(createJobVersionParams models.CreateTrainJobVersionUserImageParams, jobID string) (*models.CreateTrainJobResult, error) { + checkSetting() + client := getRestyClient() + var result models.CreateTrainJobResult + + retry := 0 + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetAuthToken(TOKEN). + SetBody(createJobVersionParams). + SetResult(&result). + Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob + "/" + jobID + "/versions") + + if err != nil { + return nil, fmt.Errorf("resty create train-job version: %s", err) + } + + req, _ := json.Marshal(createJobVersionParams) + log.Info("%s", req) + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + _ = getToken() + goto sendjob + } + + if res.StatusCode() != http.StatusOK { + var temp models.ErrorResult + if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { + log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + } + BootFileErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.BootFileUrl + "'." + DataSetErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.DataUrl + "'." + if temp.ErrorMsg == BootFileErrorMsg { + log.Error("启动文件错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("启动文件错误!") + } + if temp.ErrorMsg == DataSetErrorMsg { + log.Error("数据集错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("数据集错误!") + } + return &result, fmt.Errorf("createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + } + + if !result.IsSuccess { + log.Error("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) + return &result, fmt.Errorf("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) + } + + return &result, nil +} + func GetResourceSpecs() (*models.GetResourceSpecsResult, error) { checkSetting() client := getRestyClient()