Browse Source

!1717 fix bug introduced by gpu support

Merge pull request !1717 from gengdongjie/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
9eee55969a
5 changed files with 9 additions and 5 deletions
  1. +1
    -1
      example/resnet101_imagenet2012/config.py
  2. +1
    -1
      example/resnet50_cifar10/config.py
  3. +4
    -2
      example/resnet50_cifar10/train.py
  4. +1
    -1
      example/resnet50_imagenet2012/config.py
  5. +2
    -0
      example/resnet50_imagenet2012/train.py

+ 1
- 1
example/resnet101_imagenet2012/config.py View File

@@ -29,7 +29,7 @@ config = ed({
"image_height": 224, "image_height": 224,
"image_width": 224, "image_width": 224,
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10, "keep_checkpoint_max": 10,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"warmup_epochs": 0, "warmup_epochs": 0,


+ 1
- 1
example/resnet50_cifar10/config.py View File

@@ -28,7 +28,7 @@ config = ed({
"image_height": 224, "image_height": 224,
"image_width": 224, "image_width": 224,
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_steps": 1950,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10, "keep_checkpoint_max": 10,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"warmup_epochs": 5, "warmup_epochs": 5,


+ 4
- 2
example/resnet50_cifar10/train.py View File

@@ -43,6 +43,8 @@ args_opt = parser.parse_args()


if __name__ == '__main__': if __name__ == '__main__':
target = args_opt.device_target target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend": if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
@@ -80,13 +82,13 @@ if __name__ == '__main__':
else: else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True)
amp_level="O2", keep_batchnorm_fp32=False)


time_cb = TimeMonitor(data_size=step_size) time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor() loss_cb = LossMonitor()
cb = [time_cb, loss_cb] cb = [time_cb, loss_cb]
if config.save_checkpoint: if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
keep_checkpoint_max=config.keep_checkpoint_max) keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb] cb += [ckpt_cb]


+ 1
- 1
example/resnet50_imagenet2012/config.py View File

@@ -29,7 +29,7 @@ config = ed({
"image_height": 224, "image_height": 224,
"image_width": 224, "image_width": 224,
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10, "keep_checkpoint_max": 10,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"warmup_epochs": 0, "warmup_epochs": 0,


+ 2
- 0
example/resnet50_imagenet2012/train.py View File

@@ -46,6 +46,8 @@ args_opt = parser.parse_args()


if __name__ == '__main__': if __name__ == '__main__':
target = args_opt.device_target target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend": if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))


Loading…
Cancel
Save