diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 1909173a89..89b0288b86 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -440,59 +440,19 @@ class _Executor: Str, the full phase of the cell. Bool, if the graph has been compiled before, return False, else return True. """ - from mindspore import nn - from mindspore.ops.composite import GradOperation - - class InputsToAttrCell(nn.Cell): - """The cell that converts non-tensor inputs to attr.""" - - def __init__(self, net, args_names, non_tensor_inputs): - super(InputsToAttrCell, self).__init__() - self.net = net - self.args_names = args_names - self.non_tensor_inputs = non_tensor_inputs - self.inputs_to_attr = True - - def construct(self, *tensor_inputs): - real_inputs = () - index = 0 - for i in args_names: - if i in self.non_tensor_inputs.keys(): - real_inputs += (self.non_tensor_inputs[i],) - else: - real_inputs += (tensor_inputs[index],) - index += 1 - return self.net(*real_inputs) args_names, args_list = _generate_pip_args(obj, *args) - if not hasattr(obj, "inputs_to_attr"): - dic = dict(zip(args_names, args_list)) - key = generate_key(phase, dic) - obj.phase_prefix = str(key[1]) - if 'export' in phase: - phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) - else: - phase = obj.phase_prefix + phase + '.' + str(obj.create_time) - - if phase in self.compile_cache.keys(): - logger.debug("%r graph has existed.", phase) - return phase, False - - if getattr(obj, "support_non_tensor_inputs", None): - for i in obj.__dict__.values(): - if isinstance(i, GradOperation): - raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, " - "only support forward net.") - attrs = {} - inputs = [] - for key, value in dic.items(): - if not isinstance(value, (Tensor, MetaTensor)): - attrs[key] = value - else: - inputs.append(value) - if attrs: - inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) - return self.compile(inputs_to_attr_cell, *inputs, phase=phase) + dic = dict(zip(args_names, args_list)) + key = generate_key(phase, dic) + obj.phase_prefix = str(key[1]) + if 'export' in phase: + phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) + else: + phase = obj.phase_prefix + phase + '.' + str(obj.create_time) + + if phase in self.compile_cache.keys(): + logger.debug("%r graph has existed.", phase) + return phase, False obj.check_names() _check_full_batch() diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 301126504c..370941a4e7 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -107,7 +107,6 @@ class Cell(Cell_): self._bprop_debug = False self.cell_type = None self._auto_parallel_compile_and_run = False - self._support_non_tensor_inputs = False def __getstate__(self): base = Cell_.__getstate__(self) @@ -119,27 +118,6 @@ class Cell(Cell_): self.__dict__ = dict_ self._attr_synced = False - @property - def support_non_tensor_inputs(self): - """ - Whether support non tensor inputs in outermost net in GRAPH MODE. - This property only used in forward net, and is not supported in grad net. - The default value of the property is the `False`, that is, - it does not support passing non tensor inputs to the outermost net. - If you want to support, set the property to the `True`. - - """ - return self._support_non_tensor_inputs - - @support_non_tensor_inputs.setter - def support_non_tensor_inputs(self, value): - """ - Set attr 'support_non_tensor_inputs'. - """ - if not isinstance(value, bool): - raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.") - self._support_non_tensor_inputs = value - @property def _cell_tag(self): # `` to `xxxxxxx` @@ -666,11 +644,6 @@ class Cell(Cell_): """ Defines the computation to be performed. This method must be overridden by all subclasses. - Note: - The outermost net only supports tensor inputs by default. - If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`. - Refer to the property `support_non_tensor_inputs` description. - Returns: Tensor, returns the computed result. """ diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index 85fc8504bd..d6bc8167f0 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): class TestNet(nn.Cell): def __init__(self): super(TestNet, self).__init__() - self.support_non_tensor_inputs = False def construct(self, tuple_a, z, list_m, w, s, dict_n): return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]