|
|
@@ -26,6 +26,7 @@ from ..common.parameter import Parameter, ParameterTuple |
|
|
from .._c_expression import init_backend |
|
|
from .._c_expression import init_backend |
|
|
from ..ops.primitive import Primitive |
|
|
from ..ops.primitive import Primitive |
|
|
from ..ops.operations import HookBackward |
|
|
from ..ops.operations import HookBackward |
|
|
|
|
|
from ..ops.functional import cast |
|
|
from ..parallel._tensor import _load_tensor_by_layout |
|
|
from ..parallel._tensor import _load_tensor_by_layout |
|
|
from ..common.tensor import Tensor |
|
|
from ..common.tensor import Tensor |
|
|
|
|
|
|
|
|
@@ -60,6 +61,7 @@ class Cell: |
|
|
def __init__(self, auto_prefix=True, flags=None): |
|
|
def __init__(self, auto_prefix=True, flags=None): |
|
|
self._params = OrderedDict() |
|
|
self._params = OrderedDict() |
|
|
self._cells = OrderedDict() |
|
|
self._cells = OrderedDict() |
|
|
|
|
|
self._params_list = OrderedDict() |
|
|
self.training = False |
|
|
self.training = False |
|
|
self.requires_grad = False |
|
|
self.requires_grad = False |
|
|
self.pynative = False |
|
|
self.pynative = False |
|
|
@@ -188,11 +190,22 @@ class Cell: |
|
|
if '_params' in self.__dict__: |
|
|
if '_params' in self.__dict__: |
|
|
params = self.__dict__['_params'] |
|
|
params = self.__dict__['_params'] |
|
|
if name in params: |
|
|
if name in params: |
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE: |
|
|
|
|
|
return self.cast_param(params[name]) |
|
|
return params[name] |
|
|
return params[name] |
|
|
if '_cells' in self.__dict__: |
|
|
if '_cells' in self.__dict__: |
|
|
cells = self.__dict__['_cells'] |
|
|
cells = self.__dict__['_cells'] |
|
|
if name in cells: |
|
|
if name in cells: |
|
|
return cells[name] |
|
|
return cells[name] |
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE and '_params_list' in self.__dict__: |
|
|
|
|
|
params_list = self.__dict__['_params_list'] |
|
|
|
|
|
if name in params_list: |
|
|
|
|
|
para_list = params_list[name] |
|
|
|
|
|
cast_list = list() |
|
|
|
|
|
for para in para_list: |
|
|
|
|
|
cast_list.append(self.cast_param(para)) |
|
|
|
|
|
para_list = ParameterTuple(cast_list) |
|
|
|
|
|
return para_list |
|
|
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) |
|
|
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name)) |
|
|
|
|
|
|
|
|
def __del__(self): |
|
|
def __del__(self): |
|
|
@@ -225,10 +238,21 @@ class Cell: |
|
|
cell.set_grad(True) |
|
|
cell.set_grad(True) |
|
|
else: |
|
|
else: |
|
|
_pynative_exec.set_grad_flag(False) |
|
|
_pynative_exec.set_grad_flag(False) |
|
|
|
|
|
cast_inputs = list() |
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'): |
|
|
|
|
|
for item in inputs: |
|
|
|
|
|
cast_inputs.append(cast(item, mstype.float16)) |
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'): |
|
|
|
|
|
for item in inputs: |
|
|
|
|
|
cast_inputs.append(cast(item, mstype.float32)) |
|
|
|
|
|
if cast_inputs: |
|
|
|
|
|
cast_inputs = tuple(cast_inputs) |
|
|
|
|
|
else: |
|
|
|
|
|
cast_inputs = inputs |
|
|
if self.enable_hook: |
|
|
if self.enable_hook: |
|
|
output = self._hook_construct(*inputs) |
|
|
|
|
|
|
|
|
output = self._hook_construct(*cast_inputs) |
|
|
else: |
|
|
else: |
|
|
output = self.construct(*inputs) |
|
|
|
|
|
|
|
|
output = self.construct(*cast_inputs) |
|
|
if isinstance(output, Parameter): |
|
|
if isinstance(output, Parameter): |
|
|
output = output.data |
|
|
output = output.data |
|
|
if self.requires_grad is True: |
|
|
if self.requires_grad is True: |
|
|
@@ -241,6 +265,7 @@ class Cell: |
|
|
def __setattr__(self, name, value): |
|
|
def __setattr__(self, name, value): |
|
|
cells = self.__dict__.get('_cells') |
|
|
cells = self.__dict__.get('_cells') |
|
|
params = self.__dict__.get('_params') |
|
|
params = self.__dict__.get('_params') |
|
|
|
|
|
params_list = self.__dict__.get('_params_list') |
|
|
if isinstance(value, Parameter): |
|
|
if isinstance(value, Parameter): |
|
|
if params is None: |
|
|
if params is None: |
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.") |
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.") |
|
|
@@ -256,7 +281,12 @@ class Cell: |
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.") |
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.") |
|
|
for item in value: |
|
|
for item in value: |
|
|
self.insert_param_to_cell(item.name, item, check_name=False) |
|
|
self.insert_param_to_cell(item.name, item, check_name=False) |
|
|
object.__setattr__(self, name, value) |
|
|
|
|
|
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE: |
|
|
|
|
|
if name in self.__dict__: |
|
|
|
|
|
del self.__dict__[name] |
|
|
|
|
|
params_list[name] = value |
|
|
|
|
|
else: |
|
|
|
|
|
object.__setattr__(self, name, value) |
|
|
elif isinstance(value, Cell): |
|
|
elif isinstance(value, Cell): |
|
|
if cells is None: |
|
|
if cells is None: |
|
|
raise AttributeError("Can not assign cells before Cell.__init__() call.") |
|
|
raise AttributeError("Can not assign cells before Cell.__init__() call.") |
|
|
@@ -458,6 +488,19 @@ class Cell: |
|
|
raise TypeError("The type of parameter should be 'Parameter' if not None.") |
|
|
raise TypeError("The type of parameter should be 'Parameter' if not None.") |
|
|
self._params[param_name] = param |
|
|
self._params[param_name] = param |
|
|
|
|
|
|
|
|
|
|
|
def cast_param(self, param): |
|
|
|
|
|
""" |
|
|
|
|
|
Cast parameter according to auto mix precison level in pynative mode. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
param (Parameter): The parameter to cast. |
|
|
|
|
|
""" |
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'): |
|
|
|
|
|
return cast(param, mstype.float16) |
|
|
|
|
|
if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'): |
|
|
|
|
|
return cast(param, mstype.float32) |
|
|
|
|
|
return param |
|
|
|
|
|
|
|
|
def insert_child_to_cell(self, child_name, child): |
|
|
def insert_child_to_cell(self, child_name, child): |
|
|
""" |
|
|
""" |
|
|
Adds a child cell to the current cell. |
|
|
Adds a child cell to the current cell. |
|
|
|