| @@ -984,6 +984,11 @@ type CreateTrainJobVersionParams struct { | |||||
| Config TrainJobVersionConfig `json:"config"` | Config TrainJobVersionConfig `json:"config"` | ||||
| } | } | ||||
| type CreateTrainJobVersionUserImageParams struct { | |||||
| Description string `json:"job_desc"` | |||||
| Config TrainJobVersionUserImageConfig `json:"config"` | |||||
| } | |||||
| type TrainJobVersionConfig struct { | type TrainJobVersionConfig struct { | ||||
| WorkServerNum int `json:"worker_server_num"` | WorkServerNum int `json:"worker_server_num"` | ||||
| AppUrl string `json:"app_url"` //训练作业的代码目录 | AppUrl string `json:"app_url"` //训练作业的代码目录 | ||||
| @@ -996,6 +1001,19 @@ type TrainJobVersionConfig struct { | |||||
| Flavor Flavor `json:"flavor"` | Flavor Flavor `json:"flavor"` | ||||
| PoolID string `json:"pool_id"` | PoolID string `json:"pool_id"` | ||||
| PreVersionId int64 `json:"pre_version_id"` | PreVersionId int64 `json:"pre_version_id"` | ||||
| } | |||||
| type TrainJobVersionUserImageConfig struct { | |||||
| WorkServerNum int `json:"worker_server_num"` | |||||
| AppUrl string `json:"app_url"` //训练作业的代码目录 | |||||
| BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 | |||||
| Parameter []Parameter `json:"parameter"` | |||||
| DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL | |||||
| TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL | |||||
| LogUrl string `json:"log_url"` | |||||
| Flavor Flavor `json:"flavor"` | |||||
| PoolID string `json:"pool_id"` | |||||
| PreVersionId int64 `json:"pre_version_id"` | |||||
| UserImageUrl string `json:"user_image_url"` | UserImageUrl string `json:"user_image_url"` | ||||
| UserCommand string `json:"user_command"` | UserCommand string `json:"user_command"` | ||||
| } | } | ||||
| @@ -438,9 +438,9 @@ func GenerateTrainJobVersion(ctx *context.Context, req *GenerateTrainJobReq, job | |||||
| var createErr error | var createErr error | ||||
| log.Info(" req.EngineID =" + fmt.Sprint(req.EngineID)) | log.Info(" req.EngineID =" + fmt.Sprint(req.EngineID)) | ||||
| if req.EngineID < 0 { | if req.EngineID < 0 { | ||||
| jobResult, createErr = createTrainJobVersion(models.CreateTrainJobVersionParams{ | |||||
| jobResult, createErr = createTrainJobVersionUserImage(models.CreateTrainJobVersionUserImageParams{ | |||||
| Description: req.Description, | Description: req.Description, | ||||
| Config: models.TrainJobVersionConfig{ | |||||
| Config: models.TrainJobVersionUserImageConfig{ | |||||
| WorkServerNum: req.WorkServerNumber, | WorkServerNum: req.WorkServerNumber, | ||||
| AppUrl: req.CodeObsPath, | AppUrl: req.CodeObsPath, | ||||
| BootFileUrl: req.BootFileUrl, | BootFileUrl: req.BootFileUrl, | ||||
| @@ -639,6 +639,61 @@ sendjob: | |||||
| return &result, nil | return &result, nil | ||||
| } | } | ||||
| func createTrainJobVersionUserImage(createJobVersionParams models.CreateTrainJobVersionUserImageParams, jobID string) (*models.CreateTrainJobResult, error) { | |||||
| checkSetting() | |||||
| client := getRestyClient() | |||||
| var result models.CreateTrainJobResult | |||||
| retry := 0 | |||||
| sendjob: | |||||
| res, err := client.R(). | |||||
| SetHeader("Content-Type", "application/json"). | |||||
| SetAuthToken(TOKEN). | |||||
| SetBody(createJobVersionParams). | |||||
| SetResult(&result). | |||||
| Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob + "/" + jobID + "/versions") | |||||
| if err != nil { | |||||
| return nil, fmt.Errorf("resty create train-job version: %s", err) | |||||
| } | |||||
| req, _ := json.Marshal(createJobVersionParams) | |||||
| log.Info("%s", req) | |||||
| if res.StatusCode() == http.StatusUnauthorized && retry < 1 { | |||||
| retry++ | |||||
| _ = getToken() | |||||
| goto sendjob | |||||
| } | |||||
| if res.StatusCode() != http.StatusOK { | |||||
| var temp models.ErrorResult | |||||
| if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { | |||||
| log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||||
| return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) | |||||
| } | |||||
| BootFileErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.BootFileUrl + "'." | |||||
| DataSetErrorMsg := "Invalid OBS path '" + createJobVersionParams.Config.DataUrl + "'." | |||||
| if temp.ErrorMsg == BootFileErrorMsg { | |||||
| log.Error("启动文件错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| return &result, fmt.Errorf("启动文件错误!") | |||||
| } | |||||
| if temp.ErrorMsg == DataSetErrorMsg { | |||||
| log.Error("数据集错误!createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| return &result, fmt.Errorf("数据集错误!") | |||||
| } | |||||
| return &result, fmt.Errorf("createTrainJobVersion failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) | |||||
| } | |||||
| if !result.IsSuccess { | |||||
| log.Error("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||||
| return &result, fmt.Errorf("createTrainJobVersion failed(%s): %s", result.ErrorCode, result.ErrorMsg) | |||||
| } | |||||
| return &result, nil | |||||
| } | |||||
| func GetResourceSpecs() (*models.GetResourceSpecsResult, error) { | func GetResourceSpecs() (*models.GetResourceSpecsResult, error) { | ||||
| checkSetting() | checkSetting() | ||||
| client := getRestyClient() | client := getRestyClient() | ||||