From 9fd463c8f48ed539a9de0e215948924991e1c4e4 Mon Sep 17 00:00:00 2001 From: zouap Date: Tue, 12 Jul 2022 09:19:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=95=8C=E9=9D=A2Bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zouap --- models/cloudbrain.go | 18 +++++++++++ modules/modelarts/modelarts.go | 4 +-- modules/modelarts/resty.go | 55 ++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 19258e2c5..fa636803b 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -984,6 +984,11 @@ type CreateTrainJobVersionParams struct { Config TrainJobVersionConfig `json:"config"` } +type CreateTrainJobVersionUserImageParams struct { + Description string `json:"job_desc"` + Config TrainJobVersionUserImageConfig `json:"config"` +} + type TrainJobVersionConfig struct { WorkServerNum int `json:"worker_server_num"` AppUrl string `json:"app_url"` //训练作业的代码目录 @@ -996,6 +1001,19 @@ type TrainJobVersionConfig struct { Flavor Flavor `json:"flavor"` PoolID string `json:"pool_id"` PreVersionId int64 `json:"pre_version_id"` +} + +type TrainJobVersionUserImageConfig struct { + WorkServerNum int `json:"worker_server_num"` + AppUrl string `json:"app_url"` //训练作业的代码目录 + BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 + Parameter []Parameter `json:"parameter"` + DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL + TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL + LogUrl string `json:"log_url"` + Flavor Flavor `json:"flavor"` + PoolID string `json:"pool_id"` + PreVersionId int64 `json:"pre_version_id"` UserImageUrl string `json:"user_image_url"` UserCommand string `json:"user_command"` } diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 9c68421bb..1f39d0fac 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -438,9 +438,9 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job var createErr error log.Info(" req.EngineID =" + fmt.Sprint(req.EngineID)) if req.EngineID < 0 { - jobResult, createErr = createTrainJobVersion(models.CreateTrainJobVersionParams{ + jobResult, createErr = createTrainJobVersionUserImage(models.CreateTrainJobVersionUserImageParams{ Description: req.Description, - Config: models.TrainJobVersionConfig{ + Config: models.TrainJobVersionUserImageConfig{ WorkServerNum: req.WorkServerNumber, AppUrl: req.CodeObsPath, BootFileUrl: req.BootFileUrl, diff --git a/modules/modelarts/resty.go b/modules/modelarts/resty.go index a9c868eaa..46c273a8b 100755 --- a/modules/modelarts/resty.go +++ b/modules/modelarts/resty.go @@ -639,6 +639,61 @@ sendjob: return &result, nil } +func createTrainJobVersionUserImage(createJobVersionParams models.CreateTrainJobVersionUserImageParams, jobID string) (*models.CreateTrainJobResult, error) { + checkSetting() + client := getRestyClient() + var result models.CreateTrainJobResult + + retry := 0 + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetAuthToken(TOKEN). + SetBody(createJobVersionParams). + SetResult(&result). + Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob + "/" + jobID + "/versions") + + if err != nil { + return nil, fmt.Errorf("resty create train-job version: %s", err) + } + + req, _ := json.Marshal(createJobVersionParams) + log.Info("%s", req) + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + _ = getToken() + goto sendjob + } + + if res.StatusCode() != http.StatusOK { + var temp models.ErrorResult + if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { + log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + } + BootFileErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.BootFileUrl + "'." + DataSetErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.DataUrl + "'." + if temp.ErrorMsg == BootFileErrorMsg { + log.Error("启动文件错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("启动文件错误!") + } + if temp.ErrorMsg == DataSetErrorMsg { + log.Error("数据集错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("数据集错误!") + } + return &result, fmt.Errorf("createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + } + + if !result.IsSuccess { + log.Error("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) + return &result, fmt.Errorf("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) + } + + return &result, nil +} + func GetResourceSpecs() (*models.GetResourceSpecsResult, error) { checkSetting() client := getRestyClient()