diff --git a/models/cloudbrain.go b/models/cloudbrain.go index fa22e497b..57dc74ef7 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -216,7 +216,7 @@ type CloudbrainsOptions struct { JobType string VersionName string IsLatestVersion string - JobTypeNot bool + JobTypeNot bool } type TaskPod struct { @@ -650,6 +650,28 @@ type Config struct { Flavor Flavor `json:"flavor"` PoolID string `json:"pool_id"` } +type CreateInferenceJobParams struct { + JobName string `json:"job_name"` + Description string `json:"job_desc"` + InfConfig InfConfig `json:"config"` + WorkspaceID string `json:"workspace_id"` +} + +type InfConfig struct { + WorkServerNum int `json:"worker_server_num"` + AppUrl string `json:"app_url"` //训练作业的代码目录 + BootFileUrl string `json:"boot_file_url"` //训练作业的代码启动文件,需要在代码目录下 + Parameter []Parameter `json:"parameter"` + DataUrl string `json:"data_url"` //训练作业需要的数据集OBS路径URL + EngineID int64 `json:"engine_id"` + // TrainUrl string `json:"train_url"` //训练作业的输出文件OBS路径URL + LogUrl string `json:"log_url"` + //UserImageUrl string `json:"user_image_url"` + //UserCommand string `json:"user_command"` + CreateVersion bool `json:"create_version"` + Flavor Flavor `json:"flavor"` + PoolID string `json:"pool_id"` +} type CreateTrainJobVersionParams struct { Description string `json:"job_desc"` diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 5891182c0..620bdae37 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -47,7 +47,7 @@ const ( TrainUrl = "train_url" DataUrl = "data_url" ResultUrl = "result_url" - CkptName = "ckpt_name" + CkptUrl = "ckpt_url" PerPage = 10 IsLatestVersion = "1" NotLatestVersion = "0" @@ -477,16 +477,16 @@ func GetVersionOutputPathByTotalVersionCount(TotalVersionCount int) (VersionOutp } func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (err error) { - jobResult, err := createTrainJob(models.CreateTrainJobParams{ + jobResult, err := createInferenceJob(models.CreateInferenceJobParams{ JobName: req.JobName, Description: req.Description, - Config: models.Config{ + InfConfig: models.InfConfig{ WorkServerNum: req.WorkServerNumber, AppUrl: req.CodeObsPath, BootFileUrl: req.BootFileUrl, DataUrl: req.DataUrl, EngineID: req.EngineID, - TrainUrl: req.TrainUrl, + // TrainUrl: req.TrainUrl, LogUrl: req.LogUrl, PoolID: req.PoolID, CreateVersion: true, diff --git a/modules/modelarts/resty.go b/modules/modelarts/resty.go index 07f26ceb7..0baa95787 100755 --- a/modules/modelarts/resty.go +++ b/modules/modelarts/resty.go @@ -874,3 +874,59 @@ sendjob: return &result, nil } + +func createInferenceJob(createJobParams models.CreateInferenceJobParams) (*models.CreateTrainJobResult, error) { + checkSetting() + client := getRestyClient() + var result models.CreateTrainJobResult + + retry := 0 + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetAuthToken(TOKEN). + SetBody(createJobParams). + SetResult(&result). + Post(HOST + "/v1/" + setting.ProjectID + urlTrainJob) + + if err != nil { + return nil, fmt.Errorf("resty create train-job: %s", err) + } + + req, _ := json.Marshal(createJobParams) + log.Info("%s", req) + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + _ = getToken() + goto sendjob + } + + if res.StatusCode() != http.StatusOK { + var temp models.ErrorResult + if err = json.Unmarshal([]byte(res.String()), &temp); err != nil { + log.Error("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + return &result, fmt.Errorf("json.Unmarshal failed(%s): %v", res.String(), err.Error()) + } + log.Error("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + BootFileErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.BootFileUrl + "'." + DataSetErrorMsg := "Invalid OBS path '" + createJobParams.InfConfig.DataUrl + "'." + if temp.ErrorMsg == BootFileErrorMsg { + log.Error("启动文件错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("启动文件错误!") + } + if temp.ErrorMsg == DataSetErrorMsg { + log.Error("数据集错误!createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + return &result, fmt.Errorf("数据集错误!") + } + return &result, fmt.Errorf("createTrainJob failed(%d):%s(%s)", res.StatusCode(), temp.ErrorCode, temp.ErrorMsg) + } + + if !result.IsSuccess { + log.Error("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) + return &result, fmt.Errorf("createTrainJob failed(%s): %s", result.ErrorCode, result.ErrorMsg) + } + + return &result, nil +} diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 4f2e5bb8a..c2864d744 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1596,6 +1596,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference modelVersion := form.ModelVersion ckptName := form.CkptName + ckptUrl := form.TrainUrl + form.CkptName + count, err := models.GetCloudbrainTrainJobCountByUserID(ctx.User.ID) if err != nil { log.Error("GetCloudbrainTrainJobCountByUserID failed:%v", err, ctx.Data["MsgID"]) @@ -1675,8 +1677,8 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference Label: modelarts.ResultUrl, Value: "s3:/" + resultObsPath, }, models.Parameter{ - Label: modelarts.CkptName, - Value: ckptName, + Label: modelarts.CkptUrl, + Value: "s3:/" + ckptUrl, }) if len(params) != 0 { err := json.Unmarshal([]byte(params), ¶meters)