From 5de0aa0dbcaac4a6d3219dcd47a0c555675320c8 Mon Sep 17 00:00:00 2001 From: zouap Date: Thu, 29 Sep 2022 17:21:14 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AF=84=E6=B5=8B=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E4=BB=A3=E7=A0=81=E7=BC=96=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zouap --- models/cloudbrain.go | 40 +++++++++++++++++++-------------------- modules/aisafety/resty.go | 8 -------- routers/repo/aisafety.go | 39 +++++++++++++++++++++++++++----------- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 24be7b989..cb0fb3421 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -175,26 +175,26 @@ type Cloudbrain struct { ImageID string //grampus image_id AiCenter string //grampus ai center: center_id+center_name - TrainUrl string //输出模型的obs路径 - BranchName string //分支名称 - Parameters string //传给modelarts的param参数 - BootFile string //启动文件 - DataUrl string //数据集的obs路径 - LogUrl string //日志输出的obs路径 - PreVersionId int64 //父版本的版本id - FlavorCode string //modelarts上的规格id - Description string `xorm:"varchar(256)"` //描述 - WorkServerNumber int //节点数 - FlavorName string //规格名称 - EngineName string //引擎名称 - TotalVersionCount int //任务的所有版本数量,包括删除的 - LabelName string //标签名称 - ModelName string //模型名称 - ModelVersion string //模型版本 - CkptName string //权重文件名称 - PreTrainModelUrl string //预训练模型地址 - ResultUrl string //推理结果的obs路径 - + TrainUrl string //输出模型的obs路径 + BranchName string //分支名称 + Parameters string //传给modelarts的param参数 + BootFile string //启动文件 + DataUrl string //数据集的obs路径 + LogUrl string //日志输出的obs路径 + PreVersionId int64 //父版本的版本id + FlavorCode string //modelarts上的规格id + Description string `xorm:"varchar(256)"` //描述 + WorkServerNumber int //节点数 + FlavorName string //规格名称 + EngineName string //引擎名称 + TotalVersionCount int //任务的所有版本数量,包括删除的 + LabelName string //标签名称 + ModelName string //模型名称 + ModelVersion string //模型版本 + CkptName string //权重文件名称 + PreTrainModelUrl string //预训练模型地址 + ResultUrl string //推理结果的obs路径 + ResultJson string `xorm:"varchar(4000)"` User *User `xorm:"-"` Repo *Repository `xorm:"-"` BenchmarkType string `xorm:"-"` //算法评测,模型评测 diff --git a/modules/aisafety/resty.go b/modules/aisafety/resty.go index 9f7ebf5d2..be6468529 100644 --- a/modules/aisafety/resty.go +++ b/modules/aisafety/resty.go @@ -237,7 +237,6 @@ func GetTaskStatus(jobID string) (*ReturnMsg, error) { Get(HOST + "/v1/external/eval-standard/query?serialNo=" + jobID) log.Info("url=" + HOST + "/v1/external/eval-standard/query?serialNo=" + jobID) - responseStr := string(res.Body()) log.Info("GetTaskStatus responseStr=" + responseStr + " res code=" + fmt.Sprint(res.StatusCode())) @@ -245,13 +244,6 @@ func GetTaskStatus(jobID string) (*ReturnMsg, error) { log.Info("error =" + err.Error()) return nil, fmt.Errorf("Get task status error: %v", err) } else { - log.Info("finished.") - // var reMap ReturnMsg - // err = json.Unmarshal(res.Body(), &reMap) - // if err == nil { return &reMap, nil - // } else { - // return nil, fmt.Errorf("get error,code not 0") - // } } } diff --git a/routers/repo/aisafety.go b/routers/repo/aisafety.go index a8bd6a2a2..f845c3918 100644 --- a/routers/repo/aisafety.go +++ b/routers/repo/aisafety.go @@ -76,19 +76,17 @@ func GetAiSafetyTask(ctx *context.Context) { } func syncAiSafetyTaskStatus(job *models.Cloudbrain) { - if job.Type == models.TypeCloudBrainTwo { - if isTaskNotFinished(job.Status) { + if isTaskNotFinished(job.Status) { + if job.Type == models.TypeCloudBrainTwo { queryTaskStatusFromCloudbrainTwo(job) - } - } else if job.Type == models.TypeCloudBrainOne { - if isTaskNotFinished(job.Status) { + } else if job.Type == models.TypeCloudBrainOne { queryTaskStatusFromCloudbrain(job) + } + } else { + if job.Status == string(models.ModelSafetyTesting) { + queryTaskStatusFromModelSafetyTestServer(job) } else { - if job.Status == string(models.ModelSafetyTesting) { - queryTaskStatusFromModelSafetyTestServer(job) - } else { - log.Info("The job is finished. status=" + job.Status) - } + log.Info("The job is finished. status=" + job.Status) } } } @@ -233,8 +231,27 @@ func queryTaskStatusFromModelSafetyTestServer(job *models.Cloudbrain) { result, err := aisafety.GetTaskStatus(job.PreVersionName) if err == nil { if result.Code == "0" { - + if result.Data.Status == 1 { + log.Info("The task is running....") + } else { + if result.Data.Code == 0 { + job.ResultJson = result.Data.StandardJson + err = models.UpdateJob(job) + if err != nil { + log.Error("UpdateJob failed:", err) + } + } + } + } else { + log.Info("The task is failed.") + job.Status = string(models.JobFailed) + err = models.UpdateJob(job) + if err != nil { + log.Error("UpdateJob failed:", err) + } } + } else { + log.Info("The task not found.....") } }