From 0fb745427b958e62003546b44dabaed4d5b3f405 Mon Sep 17 00:00:00 2001 From: zouap Date: Fri, 16 Sep 2022 11:18:18 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BB=A3=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zouap --- models/ai_model_manage.go | 16 +++++++++ routers/repo/ai_model_manage.go | 58 ++++++++++++++++++++------------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/models/ai_model_manage.go b/models/ai_model_manage.go index 0ea01d6e5..97cae95a0 100644 --- a/models/ai_model_manage.go +++ b/models/ai_model_manage.go @@ -286,6 +286,22 @@ func ModifyModelDescription(id string, description string) error { return nil } +func ModifyModelStatus(id string, modelSize int64, status int, modelPath string) error { + var sess *xorm.Session + sess = x.ID(id) + defer sess.Close() + re, err := sess.Cols("size", "status", "path").Update(&AiModelManage{ + Size: modelSize, + Status: status, + Path: modelPath, + }) + if err != nil { + return err + } + log.Info("success to update ModelStatus from db.re=" + fmt.Sprint((re))) + return nil +} + func ModifyModelNewProperty(id string, new int, versioncount int) error { var sess *xorm.Session sess = x.ID(id) diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go index d01539a75..1b295660a 100644 --- a/routers/repo/ai_model_manage.go +++ b/routers/repo/ai_model_manage.go @@ -27,6 +27,9 @@ const ( MODEL_LATEST = 1 MODEL_NOT_LATEST = 0 MODEL_MAX_SIZE = 1024 * 1024 * 1024 + STATUS_COPY_MODEL = 1 + STATUS_FINISHED = 0 + STATUS_ERROR = 2 ) func saveModelByParameters(jobId string, versionName string, name string, version string, label string, description string, engine int, ctx *context.Context) error { @@ -62,13 +65,9 @@ func saveModelByParameters(jobId string, versionName string, name string, versio modelSelectedFile := ctx.Query("modelSelectedFile") //download model zip //train type if aiTask.ComputeResource == models.NPUResource { - modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "", aiTask.TrainUrl, modelSelectedFile) - if err != nil { - log.Info("download model from CloudBrainTwo faild." + err.Error()) - return err - } cloudType = models.TypeCloudBrainTwo } else if aiTask.ComputeResource == models.GPUResource { + cloudType = models.TypeCloudBrainOne var ResourceSpecs *models.ResourceSpecs json.Unmarshal([]byte(setting.ResourceSpecs), &ResourceSpecs) for _, tmp := range ResourceSpecs.ResourceSpec { @@ -77,24 +76,8 @@ func saveModelByParameters(jobId string, versionName string, name string, versio aiTask.FlavorName = flaverName } } - modelPath, modelSize, err = downloadModelFromCloudBrainOne(id, aiTask.JobName, "", aiTask.TrainUrl, modelSelectedFile) - if err != nil { - log.Info("download model from CloudBrainOne faild." + err.Error()) - return err - } - cloudType = models.TypeCloudBrainOne } - // else if cloudType == models.TypeC2Net { - // if aiTask.ComputeResource == models.NPUResource { - // modelPath, modelSize, err = downloadModelFromCloudBrainTwo(id, aiTask.JobName, "", aiTask.TrainUrl, modelSelectedFile) - // if err != nil { - // log.Info("download model from CloudBrainTwo faild." + err.Error()) - // return err - // } - // } else if aiTask.ComputeResource == models.GPUResource { - - // } - // } + accuracy := make(map[string]string) accuracy["F1"] = "" accuracy["Recall"] = "" @@ -123,6 +106,7 @@ func saveModelByParameters(jobId string, versionName string, name string, versio Engine: int64(engine), TrainTaskInfo: string(aiTaskJson), Accuracy: string(accuracyJson), + Status: STATUS_COPY_MODEL, } err = models.SaveModelToDb(model) @@ -146,11 +130,41 @@ func saveModelByParameters(jobId string, versionName string, name string, versio models.UpdateRepositoryUnits(ctx.Repo.Repository, units, deleteUnitTypes) + go asyncToCopyModel(aiTask, id, modelSelectedFile) + log.Info("save model end.") notification.NotifyOtherTask(ctx.User, ctx.Repo.Repository, id, name, models.ActionCreateNewModelTask) return nil } +func asyncToCopyModel(aiTask *models.Cloudbrain, id string, modelSelectedFile string) { + if aiTask.ComputeResource == models.NPUResource { + modelPath, modelSize, err := downloadModelFromCloudBrainTwo(id, aiTask.JobName, "", aiTask.TrainUrl, modelSelectedFile) + if err != nil { + updateStatus(id, 0, STATUS_ERROR, modelPath) + log.Info("download model from CloudBrainTwo faild." + err.Error()) + } else { + updateStatus(id, modelSize, STATUS_FINISHED, modelPath) + } + } else if aiTask.ComputeResource == models.GPUResource { + + modelPath, modelSize, err := downloadModelFromCloudBrainOne(id, aiTask.JobName, "", aiTask.TrainUrl, modelSelectedFile) + if err != nil { + updateStatus(id, 0, STATUS_ERROR, modelPath) + log.Info("download model from CloudBrainOne faild." + err.Error()) + } else { + updateStatus(id, modelSize, STATUS_FINISHED, modelPath) + } + } +} + +func updateStatus(id string, modelSize int64, status int, modelPath string) { + err := models.ModifyModelStatus(id, modelSize, STATUS_FINISHED, modelPath) + if err != nil { + log.Info("update status error." + err.Error()) + } +} + func SaveNewNameModel(ctx *context.Context) { if !ctx.Repo.CanWrite(models.UnitTypeModelManage) { ctx.Error(403, ctx.Tr("repo.model_noright"))