|
|
|
@@ -55,8 +55,12 @@ if __name__ == "__main__": |
|
|
|
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)') |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1)) |
|
|
|
if args.dataset_name == "cifar10": |
|
|
|
cfg = alexnet_cifar10_cfg |
|
|
|
if device_num > 1: |
|
|
|
cfg.learning_rate = cfg.learning_rate * device_num |
|
|
|
cfg.epoch_size = cfg.epoch_size * 2 |
|
|
|
elif args.dataset_name == "imagenet": |
|
|
|
cfg = alexnet_imagenet_cfg |
|
|
|
else: |
|
|
|
@@ -65,14 +69,11 @@ if __name__ == "__main__": |
|
|
|
device_target = args.device_target |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
context.set_context(save_graphs=False) |
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1)) |
|
|
|
|
|
|
|
if device_target == "Ascend": |
|
|
|
context.set_context(device_id=args.device_id) |
|
|
|
|
|
|
|
if device_num > 1: |
|
|
|
cfg.learning_rate = cfg.learning_rate * device_num |
|
|
|
cfg.epoch_size = cfg.epoch_size * 2 |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True) |
|
|
|
|