|
|
|
@@ -13,25 +13,27 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""cell""" |
|
|
|
import inspect |
|
|
|
import time |
|
|
|
import gc |
|
|
|
import inspect |
|
|
|
import os |
|
|
|
import time |
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
import numpy |
|
|
|
|
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT |
|
|
|
from .. import context |
|
|
|
from .._c_expression import init_backend, Cell_ |
|
|
|
from .._checkparam import Validator |
|
|
|
from ..common import dtype as mstype |
|
|
|
from ..common.api import _executor, _pynative_exec |
|
|
|
from .._checkparam import Validator |
|
|
|
from ..common.parameter import Parameter, ParameterTuple |
|
|
|
from .._c_expression import init_backend, Cell_ |
|
|
|
from ..ops.primitive import Primitive |
|
|
|
from ..ops.operations import HookBackward |
|
|
|
from ..common.tensor import Tensor, MetaTensor |
|
|
|
from ..ops.functional import cast |
|
|
|
from ..ops.operations import HookBackward |
|
|
|
from ..ops.primitive import Primitive |
|
|
|
from ..parallel._tensor import _load_tensor_by_layout |
|
|
|
from ..common.tensor import Tensor, MetaTensor |
|
|
|
|
|
|
|
|
|
|
|
class Cell(Cell_): |
|
|
|
@@ -393,10 +395,10 @@ class Cell(Cell_): |
|
|
|
raise AttributeError("Can not assign params before Cell.__init__() call.") |
|
|
|
if name in self.__dict__: |
|
|
|
if self.__dict__[name] is not None: |
|
|
|
raise TypeError("Expected type is not in (Parameter, Cell), but got Parameter.") |
|
|
|
raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.") |
|
|
|
del self.__dict__[name] |
|
|
|
if cells and name in cells: |
|
|
|
raise TypeError("Expected type is Cell, but got Parameter.") |
|
|
|
raise TypeError("The type of value should be Cell, but got Parameter.") |
|
|
|
self.insert_param_to_cell(name, value) |
|
|
|
elif isinstance(value, ParameterTuple): |
|
|
|
if params is None: |
|
|
|
@@ -417,7 +419,7 @@ class Cell(Cell_): |
|
|
|
if name in self.__dict__: |
|
|
|
del self.__dict__[name] |
|
|
|
if params and name in params: |
|
|
|
raise TypeError("Expected type is Parameter, but got Cell.") |
|
|
|
raise TypeError("The type of value should be Parameter, but got Cell.") |
|
|
|
if self._auto_prefix: |
|
|
|
value.update_parameters_name(name + '.') |
|
|
|
cells[name] = value |
|
|
|
@@ -427,12 +429,13 @@ class Cell(Cell_): |
|
|
|
if isinstance(value, Tensor) and self._params[name] is not None: |
|
|
|
self._params[name].set_data(value) |
|
|
|
elif value is not None: |
|
|
|
raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value))) |
|
|
|
raise TypeError(f"The type of value should be Parameter or ParameterTuple, " |
|
|
|
f"but got {type(value).__name__}.") |
|
|
|
else: |
|
|
|
self.insert_param_to_cell(name, None) |
|
|
|
elif cells and name in cells: |
|
|
|
if value is not None: |
|
|
|
raise TypeError("Expected type is cell, but got {}.".format(type(value))) |
|
|
|
raise TypeError(f"The type of value should be cell, but got {type(value).__name__}.") |
|
|
|
self._cells[name] = None |
|
|
|
elif isinstance(value, Tensor): |
|
|
|
if context.get_context("mode") == context.PYNATIVE_MODE: |
|
|
|
@@ -705,6 +708,7 @@ class Cell(Cell_): |
|
|
|
new_p = param.init_data(layout, set_sliced=set_sliced) |
|
|
|
replace[param] = new_p |
|
|
|
return new_p |
|
|
|
|
|
|
|
# replace all original usage. |
|
|
|
cells = self.cells_and_names() |
|
|
|
for _, cell in cells: |
|
|
|
|