diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 5ecbf3a5f..6e2c74b0b 100755 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -993,6 +993,9 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/query_model_byId", repo.QueryModelById) m.Get("/query_model_for_predict", repo.QueryModelListForPredict) m.Get("/query_modelfile_for_predict", repo.QueryModelFileForPredict) + m.Get("/query_train_job", repo.QueryTrainJobList) + m.Get("/query_train_model", repo.QueryTrainModelList) + m.Get("/query_train_job_version", repo.QueryTrainJobVersionList) m.Post("/create_model_convert", repo.CreateModelConvert) m.Get("/show_model_convert_page") m.Get("/:id", repo.GetCloudbrainModelConvertTask) diff --git a/routers/api/v1/repo/modelmanage.go b/routers/api/v1/repo/modelmanage.go index 253d51e00..2c1fd9f01 100644 --- a/routers/api/v1/repo/modelmanage.go +++ b/routers/api/v1/repo/modelmanage.go @@ -5,6 +5,7 @@ import ( "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/storage" routerRepo "code.gitea.io/gitea/routers/repo" ) @@ -47,22 +48,38 @@ func QueryModelListForPredict(ctx *context.APIContext) { routerRepo.QueryModelListForPredict(ctx.Context) } +func QueryTrainModelList(ctx *context.APIContext) { + result, err := routerRepo.QueryTrainModelFileById(ctx.Context) + if err != nil { + log.Info("query error." + err.Error()) + } + re := convertFileFormat(result) + ctx.JSON(http.StatusOK, re) +} + +func convertFileFormat(result []storage.FileInfo) []FileInfo { + re := make([]FileInfo, 0) + if result != nil { + for _, file := range result { + tmpFile := FileInfo{ + FileName: file.FileName, + ModTime: file.ModTime, + IsDir: file.IsDir, + Size: file.Size, + ParenDir: file.ParenDir, + UUID: file.UUID, + } + re = append(re, tmpFile) + } + } + return re +} + func QueryModelFileForPredict(ctx *context.APIContext) { log.Info("QueryModelFileForPredict by api.") id := ctx.Query("id") result := routerRepo.QueryModelFileByID(id) - re := make([]FileInfo, len(result)) - for _, file := range result { - tmpFile := FileInfo{ - FileName: file.FileName, - ModTime: file.ModTime, - IsDir: file.IsDir, - Size: file.Size, - ParenDir: file.ParenDir, - UUID: file.UUID, - } - re = append(re, tmpFile) - } + re := convertFileFormat(result) ctx.JSON(http.StatusOK, re) } diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go index ca90ff432..99f032572 100644 --- a/routers/repo/ai_model_manage.go +++ b/routers/repo/ai_model_manage.go @@ -528,23 +528,33 @@ func QueryTrainJobList(ctx *context.Context) { } -func QueryTrainModelList(ctx *context.Context) { - log.Info("query train job list. start.") - jobName := ctx.Query("jobName") - taskType := ctx.QueryInt("type") - VersionName := ctx.Query("versionName") - if VersionName == "" { - VersionName = ctx.Query("VersionName") +func QueryTrainModelFileById(ctx *context.Context) ([]storage.FileInfo, error) { + JobID := ctx.Query("jobId") + VersionListTasks, count, err := models.QueryModelTrainJobVersionList(JobID) + if err == nil { + if count == 1 { + task := VersionListTasks[0] + jobName := task.JobName + taskType := task.Type + VersionName := task.VersionName + modelDbResult, err := getModelFromObjectSave(jobName, taskType, VersionName) + return modelDbResult, err + } } + log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err) + return nil, errors.New("Not found task.") +} + +func getModelFromObjectSave(jobName string, taskType int, VersionName string) ([]storage.FileInfo, error) { if taskType == models.TypeCloudBrainTwo { objectkey := path.Join(setting.TrainJobModelPath, jobName, setting.OutPutPath, VersionName) + "/" modelDbResult, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, objectkey) log.Info("bucket=" + setting.Bucket + " objectkey=" + objectkey) if err != nil { log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err) + return nil, err } else { - ctx.JSON(200, modelDbResult) - return + return modelDbResult, nil } } else if taskType == models.TypeCloudBrainOne { modelSrcPrefix := setting.CBCodePathPrefix + jobName + "/model/" @@ -552,12 +562,30 @@ func QueryTrainModelList(ctx *context.Context) { modelDbResult, err := storage.GetAllObjectByBucketAndPrefixMinio(bucketName, modelSrcPrefix) if err != nil { log.Info("get TypeCloudBrainOne TrainJobListModel failed:", err) + return nil, err } else { - ctx.JSON(200, modelDbResult) - return + return modelDbResult, nil } } - ctx.JSON(200, "") + return nil, errors.New("Not support.") +} + +func QueryTrainModelList(ctx *context.Context) { + log.Info("query train job list. start.") + jobName := ctx.Query("jobName") + taskType := ctx.QueryInt("type") + VersionName := ctx.Query("versionName") + if VersionName == "" { + VersionName = ctx.Query("VersionName") + } + modelDbResult, err := getModelFromObjectSave(jobName, taskType, VersionName) + if err != nil { + log.Info("get TypeCloudBrainTwo TrainJobListModel failed:", err) + ctx.JSON(200, "") + } else { + ctx.JSON(200, modelDbResult) + return + } } func DownloadSingleModelFile(ctx *context.Context) {