|
|
|
@@ -50,10 +50,13 @@ parser.add_argument('--epoch_size', |
|
|
|
if __name__ == "__main__": |
|
|
|
###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题 |
|
|
|
args, unknown = parser.parse_known_args() |
|
|
|
|
|
|
|
MnistDataset_mindspore_path = '' |
|
|
|
Mindspore_MNIST_Example_Model_path = '' |
|
|
|
output_path = '' |
|
|
|
|
|
|
|
device_num = int(os.getenv('RANK_SIZE')) |
|
|
|
#使用多卡时 |
|
|
|
# set device_id and init for multi-card training |
|
|
|
# set device_id and init for multi-card training |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID'))) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True) |
|
|
|
|