Browse Source

模型评测后端代码编写

Signed-off-by: zouap <zouap@pcl.ac.cn>
tags/v1.22.10.1^2
zouap 3 years ago
parent
commit
5de0aa0dbc
3 changed files with 48 additions and 39 deletions
  1. +20
    -20
      models/cloudbrain.go
  2. +0
    -8
      modules/aisafety/resty.go
  3. +28
    -11
      routers/repo/aisafety.go

+ 20
- 20
models/cloudbrain.go View File

@@ -175,26 +175,26 @@ type Cloudbrain struct {
ImageID string //grampus image_id
AiCenter string //grampus ai center: center_id+center_name

TrainUrl string //输出模型的obs路径
BranchName string //分支名称
Parameters string //传给modelarts的param参数
BootFile string //启动文件
DataUrl string //数据集的obs路径
LogUrl string //日志输出的obs路径
PreVersionId int64 //父版本的版本id
FlavorCode string //modelarts上的规格id
Description string `xorm:"varchar(256)"` //描述
WorkServerNumber int //节点数
FlavorName string //规格名称
EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的
LabelName string //标签名称
ModelName string //模型名称
ModelVersion string //模型版本
CkptName string //权重文件名称
PreTrainModelUrl string //预训练模型地址
ResultUrl string //推理结果的obs路径
TrainUrl string //输出模型的obs路径
BranchName string //分支名称
Parameters string //传给modelarts的param参数
BootFile string //启动文件
DataUrl string //数据集的obs路径
LogUrl string //日志输出的obs路径
PreVersionId int64 //父版本的版本id
FlavorCode string //modelarts上的规格id
Description string `xorm:"varchar(256)"` //描述
WorkServerNumber int //节点数
FlavorName string //规格名称
EngineName string //引擎名称
TotalVersionCount int //任务的所有版本数量,包括删除的
LabelName string //标签名称
ModelName string //模型名称
ModelVersion string //模型版本
CkptName string //权重文件名称
PreTrainModelUrl string //预训练模型地址
ResultUrl string //推理结果的obs路径
ResultJson string `xorm:"varchar(4000)"`
User *User `xorm:"-"`
Repo *Repository `xorm:"-"`
BenchmarkType string `xorm:"-"` //算法评测,模型评测


+ 0
- 8
modules/aisafety/resty.go View File

@@ -237,7 +237,6 @@ func GetTaskStatus(jobID string) (*ReturnMsg, error) {
Get(HOST + "/v1/external/eval-standard/query?serialNo=" + jobID)

log.Info("url=" + HOST + "/v1/external/eval-standard/query?serialNo=" + jobID)

responseStr := string(res.Body())
log.Info("GetTaskStatus responseStr=" + responseStr + " res code=" + fmt.Sprint(res.StatusCode()))

@@ -245,13 +244,6 @@ func GetTaskStatus(jobID string) (*ReturnMsg, error) {
log.Info("error =" + err.Error())
return nil, fmt.Errorf("Get task status error: %v", err)
} else {
log.Info("finished.")
// var reMap ReturnMsg
// err = json.Unmarshal(res.Body(), &reMap)
// if err == nil {
return &reMap, nil
// } else {
// return nil, fmt.Errorf("get error,code not 0")
// }
}
}

+ 28
- 11
routers/repo/aisafety.go View File

@@ -76,19 +76,17 @@ func GetAiSafetyTask(ctx *context.Context) {
}

func syncAiSafetyTaskStatus(job *models.Cloudbrain) {
if job.Type == models.TypeCloudBrainTwo {
if isTaskNotFinished(job.Status) {
if isTaskNotFinished(job.Status) {
if job.Type == models.TypeCloudBrainTwo {
queryTaskStatusFromCloudbrainTwo(job)
}
} else if job.Type == models.TypeCloudBrainOne {
if isTaskNotFinished(job.Status) {
} else if job.Type == models.TypeCloudBrainOne {
queryTaskStatusFromCloudbrain(job)
}
} else {
if job.Status == string(models.ModelSafetyTesting) {
queryTaskStatusFromModelSafetyTestServer(job)
} else {
if job.Status == string(models.ModelSafetyTesting) {
queryTaskStatusFromModelSafetyTestServer(job)
} else {
log.Info("The job is finished. status=" + job.Status)
}
log.Info("The job is finished. status=" + job.Status)
}
}
}
@@ -233,8 +231,27 @@ func queryTaskStatusFromModelSafetyTestServer(job *models.Cloudbrain) {
result, err := aisafety.GetTaskStatus(job.PreVersionName)
if err == nil {
if result.Code == "0" {

if result.Data.Status == 1 {
log.Info("The task is running....")
} else {
if result.Data.Code == 0 {
job.ResultJson = result.Data.StandardJson
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
}
} else {
log.Info("The task is failed.")
job.Status = string(models.JobFailed)
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
} else {
log.Info("The task not found.....")
}
}



Loading…
Cancel
Save