Browse Source

增加6.8的接口。

Signed-off-by: zouap <zouap@pcl.ac.cn>
tags/v1.22.11.2^2
zouap 3 years ago
parent
commit
7efc877097
3 changed files with 72 additions and 24 deletions
  1. +3
    -0
      routers/api/v1/api.go
  2. +29
    -12
      routers/api/v1/repo/modelmanage.go
  3. +40
    -12
      routers/repo/ai_model_manage.go

+ 3
- 0
routers/api/v1/api.go View File

@@ -993,6 +993,9 @@ func RegisterRoutes(m *macaron.Macaron) {
m.Get("/query_model_byId", repo.QueryModelById)
m.Get("/query_model_for_predict", repo.QueryModelListForPredict)
m.Get("/query_modelfile_for_predict", repo.QueryModelFileForPredict)
m.Get("/query_train_job", repo.QueryTrainJobList)
m.Get("/query_train_model", repo.QueryTrainModelList)
m.Get("/query_train_job_version", repo.QueryTrainJobVersionList)
m.Post("/create_model_convert", repo.CreateModelConvert)
m.Get("/show_model_convert_page")
m.Get("/:id", repo.GetCloudbrainModelConvertTask)


+ 29
- 12
routers/api/v1/repo/modelmanage.go View File

@@ -5,6 +5,7 @@ import (

"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/storage"
routerRepo "code.gitea.io/gitea/routers/repo"
)

@@ -47,22 +48,38 @@ func QueryModelListForPredict(ctx *context.APIContext) {
routerRepo.QueryModelListForPredict(ctx.Context)
}

func QueryTrainModelList(ctx *context.APIContext) {
result, err := routerRepo.QueryTrainModelFileById(ctx.Context)
if err != nil {
log.Info("query error." + err.Error())
}
re := convertFileFormat(result)
ctx.JSON(http.StatusOK, re)
}

func convertFileFormat(result []storage.FileInfo) []FileInfo {
re := make([]FileInfo, 0)
if result != nil {
for _, file := range result {
tmpFile := FileInfo{
FileName: file.FileName,
ModTime: file.ModTime,
IsDir: file.IsDir,
Size: file.Size,
ParenDir: file.ParenDir,
UUID: file.UUID,
}
re = append(re, tmpFile)
}
}
return re
}

func QueryModelFileForPredict(ctx *context.APIContext) {
log.Info("QueryModelFileForPredict by api.")
id := ctx.Query("id")
result := routerRepo.QueryModelFileByID(id)
re := make([]FileInfo, len(result))
for _, file := range result {
tmpFile := FileInfo{
FileName: file.FileName,
ModTime: file.ModTime,
IsDir: file.IsDir,
Size: file.Size,
ParenDir: file.ParenDir,
UUID: file.UUID,
}
re = append(re, tmpFile)
}
re := convertFileFormat(result)
ctx.JSON(http.StatusOK, re)
}



+ 40
- 12
routers/repo/ai_model_manage.go View File

@@ -528,23 +528,33 @@ func QueryTrainJobList(ctx *context.Context) {

}

func QueryTrainModelList(ctx *context.Context) {
log.Info("query train job list. start.")
jobName := ctx.Query("jobName")
taskType := ctx.QueryInt("type")
VersionName := ctx.Query("versionName")
if VersionName == "" {
VersionName = ctx.Query("VersionName")
func QueryTrainModelFileById(ctx *context.Context) ([]storage.FileInfo, error) {
JobID := ctx.Query("jobId")
VersionListTasks, count, err := models.QueryModelTrainJobVersionList(JobID)
if err == nil {
if count == 1 {
task := VersionListTasks[0]
jobName := task.JobName
taskType := task.Type
VersionName := task.VersionName
modelDbResult, err := getModelFromObjectSave(jobName, taskType, VersionName)
return modelDbResult, err
}
}
log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err)
return nil, errors.New("Not found task.")
}

func getModelFromObjectSave(jobName string, taskType int, VersionName string) ([]storage.FileInfo, error) {
if taskType == models.TypeCloudBrainTwo {
objectkey := path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, VersionName) + "/"
modelDbResult, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, objectkey)
log.Info("bucket=" + setting.Bucket + " objectkey=" + objectkey)
if err != nil {
log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err)
return nil, err
} else {
ctx.JSON(200, modelDbResult)
return
return modelDbResult, nil
}
} else if taskType == models.TypeCloudBrainOne {
modelSrcPrefix := setting.CBCodePathPrefix + jobName + "/model/"
@@ -552,12 +562,30 @@ func QueryTrainModelList(ctx *context.Context) {
modelDbResult, err := storage.GetAllObjectByBucketAndPrefixMinio(bucketName, modelSrcPrefix)
if err != nil {
log.Info("get TypeCloudBrainOne TrainJobListModel failed:", err)
return nil, err
} else {
ctx.JSON(200, modelDbResult)
return
return modelDbResult, nil
}
}
ctx.JSON(200, "")
return nil, errors.New("Not support.")
}

func QueryTrainModelList(ctx *context.Context) {
log.Info("query train job list. start.")
jobName := ctx.Query("jobName")
taskType := ctx.QueryInt("type")
VersionName := ctx.Query("versionName")
if VersionName == "" {
VersionName = ctx.Query("VersionName")
}
modelDbResult, err := getModelFromObjectSave(jobName, taskType, VersionName)
if err != nil {
log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err)
ctx.JSON(200, "")
} else {
ctx.JSON(200, modelDbResult)
return
}
}

func DownloadSingleModelFile(ctx *context.Context) {


Loading…
Cancel
Save