From 4880969f82fe87cfcb328e16cf39a8ccc8d70744 Mon Sep 17 00:00:00 2001 From: zouap Date: Thu, 9 Jun 2022 15:43:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=BB=A3=E7=A0=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: zouap --- routers/repo/ai_model_convert.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/routers/repo/ai_model_convert.go b/routers/repo/ai_model_convert.go index 87b11ecd6..7f7bbfcab 100644 --- a/routers/repo/ai_model_convert.go +++ b/routers/repo/ai_model_convert.go @@ -99,10 +99,12 @@ func SaveModelConvert(ctx *context.Context) { UserId: ctx.User.ID, } models.SaveModelConvert(modelConvert) - if modelConvert.SrcEngine == PYTORCH_ENGINE { + if modelConvert.SrcEngine == PYTORCH_ENGINE || modelConvert.SrcEngine == TENSORFLOW_ENGINE { + log.Info("create gpu train job.") err = createGpuTrainJob(modelConvert, ctx, task.Path) } else { //create npu job + log.Info("create npu train job.") createNpuTrainJob(modelConvert, ctx, task.Path) } @@ -386,7 +388,7 @@ func DeleteModelConvert(ctx *context.Context) { } func isCloudBrainTask(task *models.AiModelConvert) bool { - if task.SrcEngine == PYTORCH_ENGINE { + if task.SrcEngine == PYTORCH_ENGINE || task.SrcEngine == TENSORFLOW_ENGINE { return true } return false