
fix batchnorm under mixed precision in pynative mode

tags/v0.7.0-beta
kingfo committed 5 years ago
parent commit fab9fac109
7 changed files with 50 additions and 11 deletions
  1. mindspore/_extends/builtin_operations.py  +7 -6
  2. mindspore/common/parameter.py  +11 -0
  3. mindspore/common/tensor.py  +1 -1
  4. mindspore/nn/cell.py  +7 -3
  5. mindspore/ops/functional.py  +1 -0
  6. mindspore/ops/operations/nn_ops.py  +1 -0
  7. mindspore/ops/primitive.py  +22 -1

mindspore/_extends/builtin_operations.py  +7 -6

@@ -15,6 +15,7 @@
 """builtin_operations"""
 import numpy as np
 from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype

@@ -173,11 +174,11 @@ def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x

+
+hyper_map = C.HyperMap()
+
 def mixed_precision_cast(dst_type, x):
     """Implement `mixed_precision_cast`."""
-    if isinstance(x, tuple):
-        res = list()
-        for item in x:
-            res.append(F.cast(item, dst_type))
-        return tuple(res)
-    return F.cast(x, dst_type)
+    def cast_inner(data):
+        return F.cast(data, dst_type)
+    return hyper_map(cast_inner, x)
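
Note: `C.HyperMap` applies a function across every element of a nested structure, so the rewritten `mixed_precision_cast` is no longer limited to one flat tuple. A minimal pure-Python sketch of that behavior (`apply_nested` is illustrative only, not MindSpore API):

def apply_nested(fn, x):
    """Apply fn to every leaf of an arbitrarily nested tuple."""
    if isinstance(x, tuple):
        return tuple(apply_nested(fn, item) for item in x)
    return fn(x)

# the old loop handled only one flat tuple; nesting now works too:
print(apply_nested(float, (1, (2, 3))))   # (1.0, (2.0, 3.0))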

mindspore/common/parameter.py  +11 -0

@@ -61,6 +61,7 @@ class Parameter:
         self._is_init = False
         self._sliced = False
         self.is_param_ps = False
+        self._cast_type = None
         self.init_in_server = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
             self.init_data()

@@ -103,6 +104,16 @@ class Parameter:
             raise ValueError("The type of the name should be `str` or `None`.")
         self._value.name = name_

+    @property
+    def cast_type(self):
+        return self._cast_type
+
+    @cast_type.setter
+    def cast_type(self, dst_type):
+        if dst_type not in (mstype.float16, mstype.float32, None):
+            raise ValueError("The cast_type should be one of [mstype.float16, mstype.float32] or `None`.")
+        self._cast_type = dst_type
+
     @property
     def sliced(self):
         """Get slice status of the parameter."""


mindspore/common/tensor.py  +1 -1

@@ -278,7 +278,7 @@ class SparseTensor:
     Returns:
         SparseTensor, composed of `indices`, `values`, `dense_shape`.

-    Examples:
+    Examples:
         >>> class Net(nn.Cell):
         >>>     def __init__(self, dense_shape):
         >>>         super(Net, self).__init__()


mindspore/nn/cell.py  +7 -3

@@ -286,6 +286,8 @@ class Cell:
         if context.get_context("mode") == context.PYNATIVE_MODE:
             if name in self.__dict__:
                 del self.__dict__[name]
+            if name in params:
+                del params[name]
             params_list[name] = value
         else:
             object.__setattr__(self, name, value)

@@ -499,9 +501,11 @@ class Cell:
         """
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                return cast(param, mstype.float16)
-            if self._mindspore_flags.get('fp32'):
-                return cast(param, mstype.float32)
+                param.cast_type = mstype.float16
+            elif self._mindspore_flags.get('fp32'):
+                param.cast_type = mstype.float32
+            else:
+                param.cast_type = None
         return param

     def insert_child_to_cell(self, child_name, child):
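
Note: `cast_param` used to return an already-cast copy, so ops that write back into a parameter (batchnorm's moving statistics) updated a temporary instead of the real `Parameter`. After this change the cell only tags the parameter; the actual cast is deferred to `_run_op` (last file below). A stand-alone sketch of the tagging side, with illustrative names:

class ParamStub:
    """Stand-in for Parameter: a value plus a pending-cast tag."""
    def __init__(self, value):
        self.value = value
        self.cast_type = None

def cast_param_stub(param, fp16=False):
    # analogue of Cell.cast_param after this commit: tag, don't convert
    param.cast_type = float if fp16 else None
    return param

p = cast_param_stub(ParamStub(1), fp16=True)
assert p.value == 1 and p.cast_type is float   # same object, just tagged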


mindspore/ops/functional.py  +1 -0

@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
 tensor_operator_registry.register('shape', shape)
 #support GE backend for no compare operators
 tensor_operator_registry.register('vm_compare', BP.vm_compare)
+tensor_operator_registry.register('cast', cast)

mindspore/ops/operations/nn_ops.py  +1 -0

@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self._update_parameter = True


 class BNTrainingReduce(PrimitiveWithInfer):
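
Note: `FusedBatchNorm` writes back into its `moving_mean`/`moving_variance` parameters, so it opts out of the per-op cast through the new flag. A hypothetical primitive with the same in-place behavior would presumably do likewise (sketch, not part of the commit):

from mindspore.ops.primitive import Primitive

class MyInplaceOp(Primitive):
    """Hypothetical op whose kernel updates a Parameter in place."""
    def __init__(self):
        super().__init__("MyInplaceOp")
        self._update_parameter = True   # _run_op will pass Parameters uncast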


mindspore/ops/primitive.py  +22 -1

@@ -18,6 +18,8 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
+from mindspore.common import Parameter
+from mindspore.common._register_for_tensor import tensor_operator_registry
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind

@@ -49,6 +51,7 @@ class Primitive(Primitive_):
         self.name = name
         self.attrs = {}
         self.init_attrs = {"name": name}
+        self._update_parameter = False
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)

@@ -189,6 +192,11 @@ class Primitive(Primitive_):
         # for checking output number with kernel implementation
         self.add_prim_attr("output_names", outputs)

+    @property
+    def update_parameter(self):
+        """Whether the primitive will update the value of parameter."""
+        return self._update_parameter
+

 class PrimitiveWithInfer(Primitive):
     """

@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    output = real_run_op(obj, op_name, args)
+    cast = tensor_operator_registry.get("cast")
+    if op_name == "Cast" or obj.update_parameter:
+        cast_args = args
+    else:
+        cast_args = list()
+        for arg in args:
+            if isinstance(arg, Parameter):
+                if arg.cast_type:
+                    cast_args.append(cast(arg, arg.cast_type))
+                else:
+                    cast_args.append(arg)
+            else:
+                cast_args.append(arg)
+    output = real_run_op(obj, op_name, tuple(cast_args))
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
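
Taken together: `Cell.cast_param` tags each parameter, and `_run_op` performs the cast just before kernel launch, skipping `Cast` itself (no double cast) and parameter-updating primitives such as `FusedBatchNorm`, so their in-place updates hit the real parameter. A pure-Python emulation of the new argument-filtering loop (`ParamStub` and `cast_stub` stand in for `Parameter` and the registered 'cast' op; names are illustrative):

class ParamStub:
    def __init__(self, value, cast_type=None):
        self.value, self.cast_type = value, cast_type

def cast_stub(param, dst_type):
    return dst_type(param.value)

def prepare_args(op_name, updates_parameter, args):
    """Mirror of the cast_args logic added to _run_op."""
    if op_name == "Cast" or updates_parameter:
        return tuple(args)               # leave everything uncast
    out = []
    for arg in args:
        if isinstance(arg, ParamStub) and arg.cast_type:
            out.append(cast_stub(arg, arg.cast_type))
        else:
            out.append(arg)
    return tuple(out)

p = ParamStub(1, cast_type=float)
assert prepare_args("Add", False, [p]) == (1.0,)         # read-only op: cast
assert prepare_args("FusedBatchNorm", True, [p]) == (p,)  # in-place op: as-is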

