| @@ -19,6 +19,7 @@ | |||||
| import os | import os | ||||
| import argparse | import argparse | ||||
| from config import mnist_cfg as cfg | from config import mnist_cfg as cfg | ||||
| from dataset import create_dataset | |||||
| from dataset_distributed import create_dataset_parallel | from dataset_distributed import create_dataset_parallel | ||||
| from lenet import LeNet5 | from lenet import LeNet5 | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| @@ -29,7 +30,6 @@ from mindspore.train import Model | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.communication.management import init, get_rank | from mindspore.communication.management import init, get_rank | ||||
| import time | import time | ||||
| #导入openi包 | |||||
| from c2net.context import prepare, upload_output | from c2net.context import prepare, upload_output | ||||
| @@ -50,10 +50,7 @@ parser.add_argument('--epoch_size', | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 | ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 | ||||
| args, unknown = parser.parse_known_args() | args, unknown = parser.parse_known_args() | ||||
| MnistDataset_mindspore_path = '' | |||||
| Mindspore_MNIST_Example_Model_path = '' | |||||
| output_path = '' | |||||
| device_num = int(os.getenv('RANK_SIZE')) | device_num = int(os.getenv('RANK_SIZE')) | ||||
| #使用多卡时 | #使用多卡时 | ||||
| # set device_id and init for multi-card training | # set device_id and init for multi-card training | ||||
| @@ -63,32 +60,19 @@ if __name__ == "__main__": | |||||
| init() | init() | ||||
| #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data | #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data | ||||
| local_rank=int(os.getenv('RANK_ID')) | local_rank=int(os.getenv('RANK_ID')) | ||||
| if local_rank%8==0: | |||||
| #初始化导入数据集和预训练模型到容器内 | |||||
| c2net_context = prepare() | |||||
| #获取数据集路径 | |||||
| MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" | |||||
| #获取预训练模型路径 | |||||
| Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" | |||||
| output_path = c2net_context.output_path | |||||
| #Set a cache file to determine whether the data has been copied to obs. | |||||
| #If this file exists during multi-card training, there is no need to copy the dataset multiple times. | |||||
| f = open("/cache/download_input.txt", 'w') | |||||
| f.close() | |||||
| try: | |||||
| if os.path.exists("/cache/download_input.txt"): | |||||
| print("download_input succeed") | |||||
| except Exception as e: | |||||
| print("download_input failed") | |||||
| while not os.path.exists("/cache/download_input.txt"): | |||||
| time.sleep(1) | |||||
| ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) | |||||
| #初始化导入数据集和预训练模型到容器内 | |||||
| c2net_context = prepare() | |||||
| #获取数据集路径 | |||||
| MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore" | |||||
| #获取预训练模型路径 | |||||
| Mindspore_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Mindspore_MNIST_Example_Model" | |||||
| output_path = c2net_context.output_path | |||||
| ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size) | |||||
| network = LeNet5(cfg.num_classes) | network = LeNet5(cfg.num_classes) | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | ||||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | ||||
| load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) | |||||
| #load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) | |||||
| if args.device_target != "Ascend": | if args.device_target != "Ascend": | ||||
| model = Model(network, | model = Model(network, | ||||
| net_loss, | net_loss, | ||||