Browse Source

!5148 change group conv dtype in gpu resnext50

Merge pull request !5148 from zhaoting/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
40a3a7146f
2 changed files with 2 additions and 6 deletions
  1. +0
    -3
      model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py
  2. +2
    -3
      model_zoo/official/cv/resnext50/train.py

+ 0
- 3
model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py View File

@@ -44,9 +44,6 @@ def auto_mixed_precision(network):
elif name == 'fc':
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
change = True
elif name == 'conv2':
subcell.to_float(mstype.float32)
change = True
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
change = True


+ 2
- 3
model_zoo/official/cv/resnext50/train.py View File

@@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.config import config


@@ -273,8 +272,8 @@ def train(cloud_args=None):
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
else:
auto_mixed_precision(network)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O2")

# checkpoint save
progress_cb = ProgressMonitor(args)


Loading…
Cancel
Save