diff --git a/routers/repo/ai_model_convert.go b/routers/repo/ai_model_convert.go index 72dffc1e4..85824ee01 100644 --- a/routers/repo/ai_model_convert.go +++ b/routers/repo/ai_model_convert.go @@ -36,9 +36,10 @@ const ( Success = "S000" GPU_PYTORCH_IMAGE = "dockerhub.pcl.ac.cn:5000/user-images/openi:tensorRT_7_zouap" - PytorchBootFile = "convert_pytorch.py" - MindsporeBootFile = "convert_mindspore.py" - TensorFlowBootFile = "convert_tensorflow.py" + PytorchBootFile = "convert_pytorch.py" + MindsporeBootFile = "convert_mindspore.py" + TensorFlowNpuBootFile = "convert_tensorflow.py" + TensorFlowGpuBootFile = "convert_tensorflow_gpu.py" ConvertRepoPath = "https://git.openi.org.cn/zouap/npu_test" @@ -98,7 +99,7 @@ func SaveModelConvert(ctx *context.Context) { UserId: ctx.User.ID, } models.SaveModelConvert(modelConvert) - if modelConvert.SrcEngine == PYTORCH_ENGINE { + if modelConvert.SrcEngine == PYTORCH_ENGINE || modelConvert.SrcEngine == TENSORFLOW_ENGINE { err = createGpuTrainJob(modelConvert, ctx, task.Path) } else { //create npu job @@ -188,7 +189,7 @@ func createNpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context bootfile := MindsporeBootFile if modelConvert.SrcEngine == TENSORFLOW_ENGINE { engineId = int64(NPU_TENSORFLOW_IMAGE_ID) - bootfile = TensorFlowBootFile + bootfile = TensorFlowNpuBootFile } req := &modelarts.GenerateTrainJobReq{ JobName: modelConvert.ID, @@ -265,7 +266,9 @@ func downloadConvertCode(repopath string, codePath, branchName string) error { func createGpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context, modelRelativePath string) error { command := "" if modelConvert.SrcEngine == PYTORCH_ENGINE { - command = getPytorchModelConvertCommand(modelConvert.ID, modelConvert.ModelPath, modelConvert) + command = getGpuModelConvertCommand(modelConvert.ID, modelConvert.ModelPath, modelConvert, PytorchBootFile) + } else if modelConvert.SrcEngine == TENSORFLOW_ENGINE { + command = getGpuModelConvertCommand(modelConvert.ID, modelConvert.ModelPath, modelConvert, TensorFlowGpuBootFile) } log.Info("command=" + command) @@ -353,7 +356,7 @@ func createGpuTrainJob(modelConvert *models.AiModelConvert, ctx *context.Context return nil } -func getPytorchModelConvertCommand(name string, modelFile string, modelConvert *models.AiModelConvert) string { +func getGpuModelConvertCommand(name string, modelFile string, modelConvert *models.AiModelConvert, bootfile string) string { var command string intputshape := strings.Split(modelConvert.InputShape, ",") @@ -367,7 +370,7 @@ func getPytorchModelConvertCommand(name string, modelFile string, modelConvert * h = intputshape[2] w = intputshape[3] } - command += "python3 /code/" + PytorchBootFile + " --model " + modelFile + " --n " + n + " --c " + c + " --h " + h + " --w " + w + " > " + ModelMountPath + "/" + name + "-" + LogFile + command += "python3 /code/" + bootfile + " --model " + modelFile + " --n " + n + " --c " + c + " --h " + h + " --w " + w + " > " + ModelMountPath + "/" + name + "-" + LogFile return command }