commit 672244e0ac (tags/v0.2.0-alpha)
Author: liubuyu

    add keep_bn_fp32 parameter

4 changed files with 14 additions and 8 deletions
  1. mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc (+3 -3)
  2. mindspore/nn/optim/optimizer.py (+1 -1)
  3. mindspore/train/model.py (+9 -3)
  4. tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py (+1 -1)

mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc (+3 -3)

@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
   auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
   inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
-  inputs.push_back(addn->input(1));
+  inputs.push_back(addn->input(2));
   // scalar input should be 3rd input
   inputs.push_back(mul->input(lossscale_input_index));
   auto fusion_node = graph->NewCNode(inputs);
@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
   VarPtr Z = std::make_shared<Var>();

   VectorRef mul({prim::kPrimMul, X, Z});
-  VectorRef addn({prim::kPrimAddN, Y, mul});
+  VectorRef addn({prim::kPrimAddN, mul, Y});
   return addn;
 }

@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
   if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
     return nullptr;
   }
-  auto mul_anf = addn->input(2);
+  auto mul_anf = addn->input(1);
   if (mul_anf == nullptr) {
     return nullptr;
   }
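
Note: in short, the pass now expects the Mul output as the first operand of AddN (it was the second), and the operand indices in Process and CreateFusionNode are updated to match. A rough sketch of the computation being fused, assuming elementwise semantics (hypothetical Python, for orientation only):

    # Pattern before this commit:  AddN((y, Mul(scalar, x)))
    # Pattern after this commit:   AddN((Mul(scalar, x), y))
    # Either way the fused op computes the same thing; per the comment in
    # CreateFusionNode, the scalar goes in as the 3rd input of FusedMulAddN.
    def fused_mul_addn(x, y, scalar):
        return scalar * x + y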


mindspore/nn/optim/optimizer.py (+1 -1)

@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((gradient, weight * weight_decay))
+        return op_add((weight * weight_decay, gradient))
     return gradient
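
Note: the swap is numerically a no-op (addition commutes); presumably it places the weight-decay Mul result first in the operand tuple so the AddN built here lines up with the reordered MulAddNFusion pattern above. A quick sanity check in plain Python:

    # The decayed gradient is the same either way; only operand order changes.
    weight, gradient, weight_decay = 2.0, 0.5, 0.01
    assert weight * weight_decay + gradient == gradient + weight * weight_decay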




mindspore/train/model.py (+9 -3)

@@ -62,6 +62,7 @@ class Model:
         loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise,
             the loss is scaled by the LossScaleManager. If set, it overrides the `level` setting. It's a keyword argument.
             e.g. Use `loss_scale_manager=None` to set the value.
+        keep_batchnorm_fp32 (bool): Keep BatchNorm running in `float32`. If set, overrides the `level` setting. Default: True.

     Examples:
         >>> class Net(nn.Cell):
@@ -96,7 +97,10 @@ class Model:
         self._optimizer = optimizer
         self._loss_scale_manager = None
         self._loss_scale_manager_set = False
+        self._keep_bn_fp32 = True
         self._check_kwargs(kwargs)
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
         if 'loss_scale_manager' in kwargs:
             self._loss_scale_manager = kwargs['loss_scale_manager']
             self._loss_scale_manager_set = True
@@ -112,7 +116,7 @@ class Model:

     def _check_kwargs(self, kwargs):
         for arg in kwargs:
-            if arg not in ['loss_scale_manager']:
+            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                 raise ValueError(f"Unsupported arg '{arg}'")

     def _build_train_network(self):
@@ -124,12 +128,14 @@ class Model:
                                                   self._optimizer,
                                                   self._loss_fn,
                                                   level=self._amp_level,
-                                                  loss_scale_manager=self._loss_scale_manager)
+                                                  loss_scale_manager=self._loss_scale_manager,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
                                                   self._optimizer,
                                                   self._loss_fn,
-                                                  level=self._amp_level)
+                                                  level=self._amp_level,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # May need to check the case where loss_fn is not None but optimizer is None
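
Note: taken together, the model.py changes wire the new kwarg through to amp.build_train_network. A minimal usage sketch, assuming the constructor exposes the mixed-precision level as `amp_level` (suggested by `self._amp_level`); `net`, `loss`, and `opt` are hypothetical placeholders:

    # Hypothetical example; net, loss, and opt stand in for a real network,
    # loss function, and optimizer.
    model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2",
                  keep_batchnorm_fp32=False)  # let BatchNorm follow the amp level

    # keep_batchnorm_fp32 defaults to True, so by default BatchNorm stays in float32:
    model_default = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")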


tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py (+1 -1)

@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
     @fns
     def before(a, b):
         res = mul(scalar, a)
-        res = addn((b, res))
+        res = addn((res, b))
         return res

     @fns
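
Note: the test's `before` graph now builds `addn((res, b))` so the Mul result is the first AddN operand, mirroring the reordered `DefinePattern` in mul_addn_fusion.cc; with the old `addn((b, res))` order the pattern would no longer match. Informally:

    # What the pass should now rewrite (a sketch, not the actual `after` graph):
    #   addn((mul(scalar, a), b))  ->  FusedMulAddN(a, b, scalar)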

