|
|
|
@@ -15,10 +15,7 @@ |
|
|
|
1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 |
|
|
|
2、用户需要调用c2net的python sdk包 |
|
|
|
""" |
|
|
|
|
|
|
|
import os |
|
|
|
os.system("pip install openi-test") |
|
|
|
os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH"))) |
|
|
|
import argparse |
|
|
|
from config import mnist_cfg as cfg |
|
|
|
from dataset import create_dataset |
|
|
|
@@ -56,10 +53,9 @@ if __name__ == "__main__": |
|
|
|
#获取预训练模型路径 |
|
|
|
mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts" |
|
|
|
|
|
|
|
device_num = int(os.getenv('RANK_SIZE')) |
|
|
|
context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target) |
|
|
|
#使用数据集的方式 |
|
|
|
ds_train = create_dataset(os.path.join(mnistdata_path + "/MNISTData", "train"), cfg.batch_size) |
|
|
|
ds_train = create_dataset(os.path.join(mnistdata_path, "train"), cfg.batch_size) |
|
|
|
network = LeNet5(cfg.num_classes) |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) |
|
|
|
|