From bc3c4701d2dca5ff6af46ca90901b16ed874afae Mon Sep 17 00:00:00 2001 From: liuzx Date: Fri, 5 Jan 2024 16:48:06 +0800 Subject: [PATCH] update --- npu_mnist_example/train_npu.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/npu_mnist_example/train_npu.py b/npu_mnist_example/train_npu.py index c1c776c..5546bc3 100644 --- a/npu_mnist_example/train_npu.py +++ b/npu_mnist_example/train_npu.py @@ -15,10 +15,7 @@ 1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 2、用户需要调用c2net的python sdk包 """ - import os -os.system("pip install openi-test") -os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH"))) import argparse from config import mnist_cfg as cfg from dataset import create_dataset @@ -56,10 +53,9 @@ if __name__ == "__main__": #获取预训练模型路径 mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts" - device_num = int(os.getenv('RANK_SIZE')) context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target) #使用数据集的方式 - ds_train = create_dataset(os.path.join(mnistdata_path + "/MNISTData", "train"), cfg.batch_size) + ds_train = create_dataset(os.path.join(mnistdata_path, "train"), cfg.batch_size) network = LeNet5(cfg.num_classes) net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)