From 2ba72dcb25a9510805d537e19d0210b18d4a378e Mon Sep 17 00:00:00 2001 From: liuzx Date: Thu, 30 Dec 2021 11:19:55 +0800 Subject: [PATCH] update --- models/cloudbrain.go | 1 + modules/auth/modelarts.go | 1 + modules/modelarts/modelarts.go | 2 + routers/repo/modelarts.go | 83 ++++++++++++++++++++++++++++++++++ routers/routes/routes.go | 1 + 5 files changed, 88 insertions(+) diff --git a/models/cloudbrain.go b/models/cloudbrain.go index 57dc74ef7..a8a250478 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -126,6 +126,7 @@ type Cloudbrain struct { EngineName string //引擎名称 TotalVersionCount int //任务的所有版本数量,包括删除的 + LabelName string //标签名称 ModelName string //模型名称 ModelVersion string //模型版本 CkptName string //权重文件名称 diff --git a/modules/auth/modelarts.go b/modules/auth/modelarts.go index 7d727f182..821cd72f8 100755 --- a/modules/auth/modelarts.go +++ b/modules/auth/modelarts.go @@ -62,6 +62,7 @@ type CreateModelArtsInferenceJobForm struct { VersionName string `form:"version_name" binding:"Required"` FlavorName string `form:"flaver_names" binding:"Required"` EngineName string `form:"engine_names" binding:"Required"` + LabelName string `form:"label_names" binding:"Required"` TrainUrl string `form:"train_url" binding:"Required"` ModelName string `form:"model_name" binding:"Required"` ModelVersion string `form:"model_version" binding:"Required"` diff --git a/modules/modelarts/modelarts.go b/modules/modelarts/modelarts.go index 2a6e8317a..24bf31b80 100755 --- a/modules/modelarts/modelarts.go +++ b/modules/modelarts/modelarts.go @@ -136,6 +136,7 @@ type GenerateInferenceJobReq struct { BranchName string FlavorName string EngineName string + LabelName string IsLatestVersion string VersionCount int TotalVersionCount int @@ -535,6 +536,7 @@ func GenerateInferenceJob(ctx *context.Context, req *GenerateInferenceJobReq) (e WorkServerNumber: req.WorkServerNumber, FlavorName: req.FlavorName, EngineName: req.EngineName, + LabelName: req.LabelName, IsLatestVersion: req.IsLatestVersion, VersionCount: req.VersionCount, TotalVersionCount: req.TotalVersionCount, diff --git a/routers/repo/modelarts.go b/routers/repo/modelarts.go index 3368474e4..f06123f76 100755 --- a/routers/repo/modelarts.go +++ b/routers/repo/modelarts.go @@ -1,6 +1,7 @@ package repo import ( + "archive/zip" "encoding/json" "errors" "io" @@ -1284,6 +1285,19 @@ func paramCheckCreateInferenceJob(form auth.CreateModelArtsInferenceJobForm) err return errors.New("计算节点数必须在1-25之间") } + if form.ModelName == "" { + log.Error("the ModelName(%d) must not be nil", form.ModelName) + return errors.New("模型名称不能为空") + } + if form.ModelVersion == "" { + log.Error("the ModelVersion(%d) must not be nil", form.ModelVersion) + return errors.New("模型版本不能为空") + } + if form.CkptName == "" { + log.Error("the CkptName(%d) must not be nil", form.CkptName) + return errors.New("权重文件不能为空") + } + return nil } @@ -1564,6 +1578,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference branch_name := form.BranchName FlavorName := form.FlavorName EngineName := form.EngineName + LabelName := form.LabelName isLatestVersion := modelarts.IsLatestVersion VersionCount := modelarts.VersionCount trainUrl := form.TrainUrl @@ -1694,6 +1709,7 @@ func InferenceJobCreate(ctx *context.Context, form auth.CreateModelArtsInference Params: form.Params, FlavorName: FlavorName, EngineName: EngineName, + LabelName: LabelName, IsLatestVersion: isLatestVersion, VersionCount: VersionCount, TotalVersionCount: modelarts.TotalVersionCount, @@ -2018,3 +2034,70 @@ func DeleteJobStorage(jobName string) error { return nil } + +func DownloadMultiResultFile(ctx *context.Context) { + log.Info("DownloadMultiModelFile start.") + id := ctx.Query("ID") + log.Info("id=" + id) + task, err := models.QueryModelById(id) + if err != nil { + log.Error("no such model!", err.Error()) + ctx.ServerError("no such model:", err) + return + } + if !isCanDeleteOrDownload(ctx, task) { + ctx.ServerError("no right.", errors.New(ctx.Tr("repo.model_noright"))) + return + } + + path := Model_prefix + models.AttachmentRelativePath(id) + "/" + + allFile, err := storage.GetAllObjectByBucketAndPrefix(setting.Bucket, path) + if err == nil { + //count++ + models.ModifyModelDownloadCount(id) + + returnFileName := task.Name + "_" + task.Version + ".zip" + ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+returnFileName) + ctx.Resp.Header().Set("Content-Type", "application/octet-stream") + w := zip.NewWriter(ctx.Resp) + defer w.Close() + for _, oneFile := range allFile { + if oneFile.IsDir { + log.Info("zip dir name:" + oneFile.FileName) + } else { + log.Info("zip file name:" + oneFile.FileName) + fDest, err := w.Create(oneFile.FileName) + if err != nil { + log.Info("create zip entry error, download file failed: %s\n", err.Error()) + ctx.ServerError("download file failed:", err) + return + } + body, err := storage.ObsDownloadAFile(setting.Bucket, path+oneFile.FileName) + if err != nil { + log.Info("download file failed: %s\n", err.Error()) + ctx.ServerError("download file failed:", err) + return + } else { + defer body.Close() + p := make([]byte, 1024) + var readErr error + var readCount int + // 读取对象内容 + for { + readCount, readErr = body.Read(p) + if readCount > 0 { + fDest.Write(p[:readCount]) + } + if readErr != nil { + break + } + } + } + } + } + } else { + log.Info("error,msg=" + err.Error()) + ctx.ServerError("no file to download.", err) + } +} diff --git a/routers/routes/routes.go b/routers/routes/routes.go index 6a3e89b3d..1c5fa7b1d 100755 --- a/routers/routes/routes.go +++ b/routers/routes/routes.go @@ -1038,6 +1038,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("", reqRepoCloudBrainReader, repo.InferenceJobShow) m.Post("/stop", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobStop) m.Post("/del", cloudbrain.AdminOrOwnerOrJobCreaterRight, repo.InferenceJobDel) + m.Get("/downloadall", repo.DownloadMultiResultFile) }) m.Get("/create", reqRepoCloudBrainWriter, repo.InferenceJobNew) m.Post("/create", reqRepoCloudBrainWriter, bindIgnErr(auth.CreateModelArtsInferenceJobForm{}), repo.InferenceJobCreate)