Browse Source

提交后端功能代码。

Signed-off-by: zouap <zouap@pcl.ac.cn>
tags/v1.22.10.1^2
zouap 3 years ago
parent
commit
0f4a71ea9a
4 changed files with 163 additions and 16 deletions
  1. +1
    -1
      modules/aisafety/resty.go
  2. +16
    -0
      modules/setting/setting.go
  3. +140
    -11
      routers/repo/aisafety.go
  4. +6
    -4
      routers/routes/routes.go

+ 1
- 1
modules/aisafety/resty.go View File

@@ -227,7 +227,7 @@ func GetTaskStatus(jobID string) (map[string]interface{}, error) {


if err != nil { if err != nil {
log.Info("error =" + err.Error()) log.Info("error =" + err.Error())
return nil, fmt.Errorf("resty GetJob: %v", err)
return nil, fmt.Errorf("Get task status error: %v", err)
} else { } else {
reMap := make(map[string]interface{}) reMap := make(map[string]interface{})
err = json.Unmarshal(res.Body(), &reMap) err = json.Unmarshal(res.Body(), &reMap)


+ 16
- 0
modules/setting/setting.go View File

@@ -707,6 +707,13 @@ var (
NPU_MINDSPORE_IMAGE_ID int NPU_MINDSPORE_IMAGE_ID int
NPU_TENSORFLOW_IMAGE_ID int NPU_TENSORFLOW_IMAGE_ID int
}{} }{}

ModelSafetyTest = struct {
BaseDataSetName string
BaseDataSetUUID string
CombatDataSetName string
CombatDataSetUUID string
}{}
) )


// DateLang transforms standard language locale name to corresponding value in datetime plugin. // DateLang transforms standard language locale name to corresponding value in datetime plugin.
@@ -1527,6 +1534,15 @@ func NewContext() {
getGrampusConfig() getGrampusConfig()
getModelartsCDConfig() getModelartsCDConfig()
getModelConvertConfig() getModelConvertConfig()
getModelSafetyConfig()
}

func getModelSafetyConfig() {
sec := Cfg.Section("model_safety_test")
ModelSafetyTest.BaseDataSetName = sec.Key("BaseDataSetName").MustString("ImageNet1000_100基础数据集;CIFAR10_1000基础数据集")
ModelSafetyTest.BaseDataSetUUID = sec.Key("BaseDataSetUUID").MustString("0fa81800-e95e-42f4-ab40-2c3ca83f2344;6eaab665-1c68-45fc-ad05-c070f2db092e")
ModelSafetyTest.CombatDataSetName = sec.Key("CombatDataSetName").MustString("ImageNet1000_100_FGSM;CIFAR10_1000_FGSM.zip")
ModelSafetyTest.CombatDataSetUUID = sec.Key("CombatDataSetUUID").MustString("9ba30d3f-83e1-4f9f-849d-6f93217e2ca3;23825796-e4f3-4cf8-b697-9963048cef42")
} }


func getModelConvertConfig() { func getModelConvertConfig() {


+ 140
- 11
routers/repo/aisafety.go View File

@@ -1,8 +1,10 @@
package repo package repo


import ( import (
"bufio"
"encoding/json" "encoding/json"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
@@ -22,6 +24,10 @@ import (
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
) )


const (
tplModelSafetyTestCreate = "repo/modelsafety/new"
)

func CloudBrainAiSafetyCreateTest(ctx *context.Context) { func CloudBrainAiSafetyCreateTest(ctx *context.Context) {
log.Info("start to create CloudBrainAiSafetyCreate") log.Info("start to create CloudBrainAiSafetyCreate")
uuid := uuid.NewV4() uuid := uuid.NewV4()
@@ -57,34 +63,148 @@ func CloudBrainAiSafetyCreateTest(ctx *context.Context) {


func GetAiSafetyTask(ctx *context.Context) { func GetAiSafetyTask(ctx *context.Context) {
var ID = ctx.Params(":jobid") var ID = ctx.Params(":jobid")
task, err := models.GetCloudbrainByJobIDWithDeleted(ID)
getAiSafetyTaskStatusFromCloudbrain(ID)
}

func getAiSafetyTaskStatusFromCloudbrain(ID string) {
job, err := models.GetCloudbrainByJobIDWithDeleted(ID)
if err != nil { if err != nil {
log.Error("GetCloudbrainByJobID failed:" + err.Error()) log.Error("GetCloudbrainByJobID failed:" + err.Error())
ctx.NotFound(ctx.Req.URL.RequestURI(), nil)
return return
} }
if task.Type == models.TypeCloudBrainTwo {
if job.Type == models.TypeCloudBrainTwo {

} else if job.Type == models.TypeCloudBrainOne {
if isTaskNotFinished(job.Status) {
log.Info("The task not finished,name=" + job.DisplayJobName)
jobResult, err := cloudbrain.GetJob(job.JobID)

result, err := models.ConvertToJobResultPayload(jobResult.Payload)
if err != nil {
log.Error("ConvertToJobResultPayload failed:", err)
return
}
job.Status = result.JobStatus.State
if result.JobStatus.State != string(models.JobWaiting) && result.JobStatus.State != string(models.JobFailed) {
taskRoles := result.TaskRoles
taskRes, _ := models.ConvertToTaskPod(taskRoles[cloudbrain.SubTaskName].(map[string]interface{}))
job.Status = taskRes.TaskStatuses[0].State
}

if result.JobStatus.State != string(models.JobSucceeded) {
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
} else {
//
job.Status = string(models.ModelSafetyTesting)
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
//send msg to beihang
sendGpuInferenceResultToTest(job)
}


} else if task.Type == models.TypeCloudBrainOne {
} else {
if job.Status == string(models.ModelSafetyTesting) {
//
result, err := aisafety.GetTaskStatus(job.PreVersionName)
if err == nil {
if result["code"] != nil {

}
}
}
}


} }
} }


func isTaskFinished(status string) bool {
func sendGpuInferenceResultToTest(job *models.Cloudbrain) {
datasetname := job.DatasetName
datasetnames := strings.Split(datasetname, ";")
indicator := job.LabelName

req := aisafety.TaskReq{
UnionId: job.JobID,
EvalName: job.DisplayJobName,
EvalContent: job.Description,
TLPath: "test",
Indicators: strings.Split(indicator, ";"),
CDName: datasetnames[1],
BDName: datasetnames[0],
}

resultDir := "/model"
prefix := "/" + setting.CBCodePathPrefix + job.JobName + resultDir
files, err := storage.GetOneLevelAllObjectUnderDirMinio(setting.Attachment.Minio.Bucket, prefix, "")
if err != nil {
log.Error("query cloudbrain one model failed: %v", err)
return
}
jsonContent := ""
for _, file := range files {
if strings.HasSuffix(file.FileName, "result.json") {
path := storage.GetMinioPath(job.JobName+resultDir+"/", file.FileName)
log.Info("path=" + path)
reader, err := os.Open(path)
defer reader.Close()
if err == nil {
r := bufio.NewReader(reader)
for {
line, error := r.ReadString('\n')
if error == io.EOF {
log.Info("read file completed.")
break
}
if error != nil {
log.Info("read file error." + error.Error())
break
}
jsonContent += line
}
}
break
}
}
if jsonContent != "" {
serialNo, err := aisafety.CreateSafetyTask(req, jsonContent)
if err == nil {
//update serial no to db
job.PreVersionName = serialNo
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
} else {
log.Info("The json is null. so set it failed.")
//update task failed.
job.Status = string(models.JobFailed)
err = models.UpdateJob(job)
if err != nil {
log.Error("UpdateJob failed:", err)
}
}
}

func isTaskNotFinished(status string) bool {
if status == string(models.ModelArtsTrainJobRunning) || status == string(models.ModelArtsTrainJobWaiting) { if status == string(models.ModelArtsTrainJobRunning) || status == string(models.ModelArtsTrainJobWaiting) {
return false
return true
} }
if status == string(models.JobWaiting) || status == string(models.JobRunning) { if status == string(models.JobWaiting) || status == string(models.JobRunning) {
return false
return true
} }


if status == string(models.ModelArtsTrainJobUnknown) || status == string(models.ModelArtsTrainJobInit) { if status == string(models.ModelArtsTrainJobUnknown) || status == string(models.ModelArtsTrainJobInit) {
return false
return true
} }
if status == string(models.ModelArtsTrainJobImageCreating) || status == string(models.ModelArtsTrainJobSubmitTrying) { if status == string(models.ModelArtsTrainJobImageCreating) || status == string(models.ModelArtsTrainJobSubmitTrying) {
return false
return true
} }
return true
return false
} }


func StopAiSafetyTask(ctx *context.Context) { func StopAiSafetyTask(ctx *context.Context) {
@@ -95,7 +215,16 @@ func DelAiSafetyTask(ctx *context.Context) {


} }


func CloudBrainAiSafetyCreate(ctx *context.Context) {
func AiSafetyCreateForGet(ctx *context.Context) {
ctx.Data["PageIsCloudBrain"] = true
ctx.Data["BaseDataSetName"] = setting.ModelSafetyTest.BaseDataSetName
ctx.Data["BaseDataSetUUID"] = setting.ModelSafetyTest.BaseDataSetUUID
ctx.Data["CombatDataSetName"] = setting.ModelSafetyTest.CombatDataSetName
ctx.Data["CombatDataSetUUID"] = setting.ModelSafetyTest.CombatDataSetUUID
ctx.HTML(200, tplModelSafetyTestCreate)
}

func AiSafetyCreateForPost(ctx *context.Context) {
ctx.Data["PageIsCloudBrain"] = true ctx.Data["PageIsCloudBrain"] = true
displayJobName := ctx.Query("DisplayJobName") displayJobName := ctx.Query("DisplayJobName")
jobName := util.ConvertDisplayJobNameToJobName(displayJobName) jobName := util.ConvertDisplayJobNameToJobName(displayJobName)


+ 6
- 4
routers/routes/routes.go View File

@@ -6,15 +6,16 @@ package routes


import ( import (
"bytes" "bytes"
"code.gitea.io/gitea/routers/reward/point"
"code.gitea.io/gitea/routers/task"
"code.gitea.io/gitea/services/reward"
"encoding/gob" "encoding/gob"
"net/http" "net/http"
"path" "path"
"text/template" "text/template"
"time" "time"


"code.gitea.io/gitea/routers/reward/point"
"code.gitea.io/gitea/routers/task"
"code.gitea.io/gitea/services/reward"

"code.gitea.io/gitea/modules/slideimage" "code.gitea.io/gitea/modules/slideimage"


"code.gitea.io/gitea/routers/image" "code.gitea.io/gitea/routers/image"
@@ -1231,7 +1232,8 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("", reqRepoCloudBrainWriter, repo.GetAiSafetyTask) m.Get("", reqRepoCloudBrainWriter, repo.GetAiSafetyTask)
m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.StopAiSafetyTask) m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.StopAiSafetyTask)
m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.DelAiSafetyTask) m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.DelAiSafetyTask)
m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.CloudBrainAiSafetyCreate)
m.Get("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.AiSafetyCreateForGet)
m.Post("/create", reqWechatBind, reqRepoCloudBrainWriter, repo.AiSafetyCreateForPost)
}) })
}, context.RepoRef()) }, context.RepoRef())




Loading…
Cancel
Save