Browse Source

update

liuzx-patch-1
liuzx 2 years ago
parent
commit
bc3c4701d2
1 changed files with 1 additions and 5 deletions
  1. +1
    -5
      npu_mnist_example/train_npu.py

+ 1
- 5
npu_mnist_example/train_npu.py View File

@@ -15,10 +15,7 @@
1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
2、用户需要调用c2net的python sdk包
"""

import os
os.system("pip install openi-test")
os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH")))
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
@@ -56,10 +53,9 @@ if __name__ == "__main__":
#获取预训练模型路径
mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts"
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
#使用数据集的方式
ds_train = create_dataset(os.path.join(mnistdata_path + "/MNISTData", "train"), cfg.batch_size)
ds_train = create_dataset(os.path.join(mnistdata_path, "train"), cfg.batch_size)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)


Loading…
Cancel
Save