diff --git a/mindspore/common/api.py b/mindspore/common/api.py index e84bdab726..ba77828f5e 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -421,6 +421,7 @@ class _Executor: Bool, if the graph has been compiled before, return False, else return True. """ from mindspore import nn + from mindspore.ops.composite import GradOperation class InputsToAttrCell(nn.Cell): """The cell that converts non-tensor inputs to attr.""" @@ -457,17 +458,21 @@ class _Executor: logger.debug("%r graph has existed.", phase) return phase, False - if getattr(obj, "support_non_tensor_inputs", None): - attrs = {} - inputs = [] - for key, value in dic.items(): - if not isinstance(value, (Tensor, MetaTensor)): - attrs[key] = value - else: - inputs.append(value) - if attrs: - inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) - return self.compile(inputs_to_attr_cell, *inputs, phase=phase) + if getattr(obj, "support_non_tensor_inputs", None): + for i in obj.__dict__.values(): + if isinstance(i, GradOperation): + raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, " + "only support forward net.") + attrs = {} + inputs = [] + for key, value in dic.items(): + if not isinstance(value, (Tensor, MetaTensor)): + attrs[key] = value + else: + inputs.append(value) + if attrs: + inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) + return self.compile(inputs_to_attr_cell, *inputs, phase=phase) obj.check_names() _check_full_batch() diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index a126766a90..f2d91244e3 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -127,8 +127,12 @@ class Cell(Cell_): @property def support_non_tensor_inputs(self): """ - Whether support non tensor inputs in cell `construct` method. - This property only used in forward net, is not supported in grad net. + Whether support non tensor inputs in outermost net in GRAPH MODE. + This property only used in forward net, and is not supported in grad net. + The default value of the property is the `False`, that is, + it does not support passing non tensor inputs to the outermost net. + If you want to support, set the property to the `True`. + """ return self._support_non_tensor_inputs @@ -670,8 +674,9 @@ class Cell(Cell_): Defines the computation to be performed. This method must be overridden by all subclasses. Note: - The inputs of the top cell only allow Tensor. - Other types (tuple, list, int etc.) are forbidden. + The outermost net only supports tensor inputs by default. + If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`. + Refer to the property `support_non_tensor_inputs` description. Returns: Tensor, returns the computed result. diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py index 39c374b953..1964ad63e3 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py @@ -61,3 +61,9 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): with pytest.raises(TypeError) as err: grad_net(arg_t0, z, arg_l0, w, 6, args_d0) assert "For 'graph mode', the 0th arg" in str(err.value) + + grad_net.support_non_tensor_inputs = True + with pytest.raises(ValueError) as err: + grad_net(arg_t0, z, arg_l0, w, 6, args_d0) + assert "Not support set 'support_non_tensor_inputs' to the 'True' for grad net, only support forward net." \ + in str(err.value)