Browse Source

!8924 modify param

From: @wukesong
Reviewed-by: @yingjy,@oacjiewen
Signed-off-by: @yingjy
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0e58beea2f
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      model_zoo/official/cv/alexnet/train.py

+ 4
- 3
model_zoo/official/cv/alexnet/train.py View File

@@ -55,8 +55,12 @@ if __name__ == "__main__":
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
args = parser.parse_args()

device_num = int(os.environ.get("DEVICE_NUM", 1))
if args.dataset_name == "cifar10":
cfg = alexnet_cifar10_cfg
if device_num > 1:
cfg.learning_rate = cfg.learning_rate * device_num
cfg.epoch_size = cfg.epoch_size * 2
elif args.dataset_name == "imagenet":
cfg = alexnet_imagenet_cfg
else:
@@ -65,14 +69,11 @@ if __name__ == "__main__":
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
device_num = int(os.environ.get("DEVICE_NUM", 1))

if device_target == "Ascend":
context.set_context(device_id=args.device_id)

if device_num > 1:
cfg.learning_rate = cfg.learning_rate * device_num
cfg.epoch_size = cfg.epoch_size * 2
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)


Loading…
Cancel
Save