Browse Source

!9869 add detailed and accurate description for support non tensor inputs for outermost net

From: @zhangbuxue
Reviewed-by: @zhunaipan,@c_34
Signed-off-by: @c_34
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5af9b056bc
3 changed files with 31 additions and 15 deletions
  1. +16
    -11
      mindspore/common/api.py
  2. +9
    -4
      mindspore/nn/cell.py
  3. +6
    -0
      tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py

+ 16
- 11
mindspore/common/api.py View File

@@ -421,6 +421,7 @@ class _Executor:
Bool, if the graph has been compiled before, return False, else return True.
"""
from mindspore import nn
from mindspore.ops.composite import GradOperation

class InputsToAttrCell(nn.Cell):
"""The cell that converts non-tensor inputs to attr."""
@@ -457,17 +458,21 @@ class _Executor:
logger.debug("%r graph has existed.", phase)
return phase, False

if getattr(obj, "support_non_tensor_inputs", None):
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
if getattr(obj, "support_non_tensor_inputs", None):
for i in obj.__dict__.values():
if isinstance(i, GradOperation):
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
"only support forward net.")
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)

obj.check_names()
_check_full_batch()


+ 9
- 4
mindspore/nn/cell.py View File

@@ -127,8 +127,12 @@ class Cell(Cell_):
@property
def support_non_tensor_inputs(self):
"""
Whether support non tensor inputs in cell `construct` method.
This property only used in forward net, is not supported in grad net.
Whether support non tensor inputs in outermost net in GRAPH MODE.
This property only used in forward net, and is not supported in grad net.
The default value of the property is the `False`, that is,
it does not support passing non tensor inputs to the outermost net.
If you want to support, set the property to the `True`.

"""
return self._support_non_tensor_inputs

@@ -670,8 +674,9 @@ class Cell(Cell_):
Defines the computation to be performed. This method must be overridden by all subclasses.

Note:
The inputs of the top cell only allow Tensor.
Other types (tuple, list, int etc.) are forbidden.
The outermost net only supports tensor inputs by default.
If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`.
Refer to the property `support_non_tensor_inputs` description.

Returns:
Tensor, returns the computed result.


+ 6
- 0
tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py View File

@@ -61,3 +61,9 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
with pytest.raises(TypeError) as err:
grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
assert "For 'graph mode', the 0th arg" in str(err.value)

grad_net.support_non_tensor_inputs = True
with pytest.raises(ValueError) as err:
grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
assert "Not support set 'support_non_tensor_inputs' to the 'True' for grad net, only support forward net." \
in str(err.value)

Loading…
Cancel
Save