@@ -15,6 +15,7 @@
 """builtin_operations"""
 import numpy as np
 from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@@ -173,11 +174,11 @@ def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
 
+hyper_map = C.HyperMap()
+
 def mixed_precision_cast(dst_type, x):
     """Implement `mixed_precision_cast`."""
-    if isinstance(x, tuple):
-        res = list()
-        for item in x:
-            res.append(F.cast(item, dst_type))
-        return tuple(res)
-    return F.cast(x, dst_type)
+    def cast_inner(data):
+        return F.cast(data, dst_type)
+
+    return hyper_map(cast_inner, x)
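The rewrite above swaps a hand-rolled loop that only understood flat tuples for `C.HyperMap`, which applies a function across arbitrarily nested sequences. A minimal plain-Python sketch of that behavior (`toy_hyper_map` is illustrative, not MindSpore's `HyperMap`):

    # Recursively apply fn over nested tuples/lists, mirroring what the
    # HyperMap-based mixed_precision_cast gains over the old flat-tuple loop.
    def toy_hyper_map(fn, x):
        if isinstance(x, (tuple, list)):
            return type(x)(toy_hyper_map(fn, item) for item in x)
        return fn(x)

    print(toy_hyper_map(float, (1, (2, 3), [4, 5])))  # (1.0, (2.0, 3.0), [4.0, 5.0])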
@@ -61,6 +61,7 @@ class Parameter:
         self._is_init = False
         self._sliced = False
         self.is_param_ps = False
+        self._cast_type = None
         self.init_in_server = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
             self.init_data()
@@ -103,6 +104,16 @@ class Parameter:
             raise ValueError("The type of the name should be `str` or `None`.")
         self._value.name = name_
 
+    @property
+    def cast_type(self):
+        return self._cast_type
+
+    @cast_type.setter
+    def cast_type(self, dst_type):
+        if dst_type not in (mstype.float16, mstype.float32, None):
+            raise ValueError("The cast type of the parameter should be mstype.float16, mstype.float32 or `None`.")
+        self._cast_type = dst_type
+
     @property
     def sliced(self):
         """Get slice status of the parameter."""
@@ -278,7 +278,7 @@ class SparseTensor:
     Returns:
         SparseTensor, composed of `indices`, `values`, `dense_shape`.
 
     Examples:
         >>> class Net(nn.Cell):
         >>>     def __init__(self, dense_shape):
         >>>         super(Net, self).__init__()
@@ -286,6 +286,8 @@ class Cell:
             if context.get_context("mode") == context.PYNATIVE_MODE:
                 if name in self.__dict__:
                     del self.__dict__[name]
+                if name in params:
+                    del params[name]
             params_list[name] = value
         else:
             object.__setattr__(self, name, value)
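The two added lines close a PyNative-mode gap: when an attribute name is rebound, the old entry must be purged from the `params` dict as well as `__dict__`, or later lookups can still see the stale object. A toy model of the bookkeeping (the dict names below are illustrative):

    def set_attr(obj_dict, params, params_list, name, value):
        # purge every stale binding before recording the new one
        if name in obj_dict:
            del obj_dict[name]
        if name in params:
            del params[name]
        params_list[name] = value

    d, params, plist = {"w": "old"}, {"w": "stale"}, {}
    set_attr(d, params, plist, "w", "new")
    assert "w" not in d and "w" not in params and plist["w"] == "new"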
@@ -499,9 +501,11 @@ class Cell:
         """
         if hasattr(self, "_mindspore_flags"):
             if self._mindspore_flags.get('fp16'):
-                return cast(param, mstype.float16)
-            if self._mindspore_flags.get('fp32'):
-                return cast(param, mstype.float32)
+                param.cast_type = mstype.float16
+            elif self._mindspore_flags.get('fp32'):
+                param.cast_type = mstype.float32
+            else:
+                param.cast_type = None
         return param
 
     def insert_child_to_cell(self, child_name, child):
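The point of this hunk is that `cast_param` no longer returns a float16 copy of the parameter; it records the intended precision on the original object and lets `_run_op` (further down) perform the cast at execution time. A minimal sketch, not MindSpore code, of why that preserves parameter identity:

    class P:
        def __init__(self, data):
            self.data = data
            self.cast_type = None

    def eager_cast(param):
        return P(float(param.data))      # new object: in-place updates to it are lost

    def deferred_cast(param):
        param.cast_type = "float16"      # same object: cast happens per-op later
        return param

    p = P(1)
    assert deferred_cast(p) is p
    assert eager_cast(p) is not p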
@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
 tensor_operator_registry.register('shape', shape)
 # support GE backend for no compare operators
 tensor_operator_registry.register('vm_compare', BP.vm_compare)
+tensor_operator_registry.register('cast', cast)
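Registering `cast` here lets the low-level primitive code in `_run_op` fetch it by name at call time instead of importing it directly, which would create an import cycle between the tensor and ops modules. The registry pattern in miniature (the `Registry` class below is illustrative, not MindSpore's):

    class Registry(dict):
        def register(self, name, fn):
            self[name] = fn

    registry = Registry()
    registry.register('cast', lambda x, dst: dst(x))

    cast = registry.get('cast')   # looked up lazily, no import cycle
    print(cast(3, float))         # 3.0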
@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
         self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self._update_parameter = True
 
 
 class BNTrainingReduce(PrimitiveWithInfer):
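FusedBatchNorm sets the flag because it writes its moving statistics back into its input Parameters; if `_run_op` substituted casted copies, the write-back would land on temporaries. A toy demonstration of the failure the flag prevents (all names are illustrative):

    class P:
        def __init__(self, v):
            self.v = v

    def update_in_place(param):      # models a parameter-updating primitive
        param.v += 1

    p = P(0)
    casted = P(float(p.v))           # models a cast: produces a new object
    update_in_place(casted)
    print(p.v)                       # still 0 — the update was silently lost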
@@ -18,6 +18,8 @@
 import inspect
 import copy
 from mindspore.common.api import _wrap_func
+from mindspore.common import Parameter
+from mindspore.common._register_for_tensor import tensor_operator_registry
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind
@@ -49,6 +51,7 @@ class Primitive(Primitive_):
         self.name = name
         self.attrs = {}
         self.init_attrs = {"name": name}
+        self._update_parameter = False
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)
@@ -189,6 +192,11 @@ class Primitive(Primitive_):
         # for checking output number with kernel implementation
         self.add_prim_attr("output_names", outputs)
 
+    @property
+    def update_parameter(self):
+        """Whether the primitive updates the value of a parameter in place."""
+        return self._update_parameter
+
 
 class PrimitiveWithInfer(Primitive):
     """
@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    output = real_run_op(obj, op_name, args)
+    cast = tensor_operator_registry.get("cast")
+    if op_name == "Cast" or obj.update_parameter:
+        cast_args = args
+    else:
+        cast_args = list()
+        for arg in args:
+            if isinstance(arg, Parameter):
+                if arg.cast_type:
+                    cast_args.append(cast(arg, arg.cast_type))
+                else:
+                    cast_args.append(arg)
+            else:
+                cast_args.append(arg)
+    output = real_run_op(obj, op_name, tuple(cast_args))
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:
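Pulled out of the diff, the new dispatch in `_run_op` is: run `Cast` itself and parameter-updating ops on the raw arguments; for everything else, replace each Parameter whose `cast_type` is set with a casted copy. A standalone sketch of that logic with stand-in names (`Arg`, `toy_cast`, and the op names are illustrative):

    class Arg:
        def __init__(self, cast_type=None):
            self.cast_type = cast_type

    def toy_cast(arg, dst_type):
        return Arg(None)                     # models producing a new casted tensor

    def prepare_args(op_name, updates_parameter, args):
        if op_name == "Cast" or updates_parameter:
            return tuple(args)               # raw args: keep Parameter identity
        return tuple(
            toy_cast(a, a.cast_type) if getattr(a, "cast_type", None) else a
            for a in args
        )

    a = Arg("float16")
    assert prepare_args("Cast", False, [a])[0] is a        # Cast sees the raw arg
    assert prepare_args("ApplyBN", True, [a])[0] is a      # in-place op keeps identity
    assert prepare_args("MatMul", False, [a])[0] is not a  # normal op gets the cast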