Browse Source

!11270 remove attr support_non_tensor_input of cell

From: @zhangbuxue
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
eab76cc09f
3 changed files with 11 additions and 79 deletions
  1. +11
    -51
      mindspore/common/api.py
  2. +0
    -27
      mindspore/nn/cell.py
  3. +0
    -1
      tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py

+ 11
- 51
mindspore/common/api.py View File

@@ -440,59 +440,19 @@ class _Executor:
Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True.
"""
from mindspore import nn
from mindspore.ops.composite import GradOperation

class InputsToAttrCell(nn.Cell):
"""The cell that converts non-tensor inputs to attr."""

def __init__(self, net, args_names, non_tensor_inputs):
super(InputsToAttrCell, self).__init__()
self.net = net
self.args_names = args_names
self.non_tensor_inputs = non_tensor_inputs
self.inputs_to_attr = True

def construct(self, *tensor_inputs):
real_inputs = ()
index = 0
for i in args_names:
if i in self.non_tensor_inputs.keys():
real_inputs += (self.non_tensor_inputs[i],)
else:
real_inputs += (tensor_inputs[index],)
index += 1
return self.net(*real_inputs)

args_names, args_list = _generate_pip_args(obj, *args)
if not hasattr(obj, "inputs_to_attr"):
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False

if getattr(obj, "support_non_tensor_inputs", None):
for i in obj.__dict__.values():
if isinstance(i, GradOperation):
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
"only support forward net.")
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False

obj.check_names()
_check_full_batch()


+ 0
- 27
mindspore/nn/cell.py View File

@@ -107,7 +107,6 @@ class Cell(Cell_):
self._bprop_debug = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self._support_non_tensor_inputs = False

def __getstate__(self):
base = Cell_.__getstate__(self)
@@ -119,27 +118,6 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False

@property
def support_non_tensor_inputs(self):
"""
Whether support non tensor inputs in outermost net in GRAPH MODE.
This property only used in forward net, and is not supported in grad net.
The default value of the property is the `False`, that is,
it does not support passing non tensor inputs to the outermost net.
If you want to support, set the property to the `True`.

"""
return self._support_non_tensor_inputs

@support_non_tensor_inputs.setter
def support_non_tensor_inputs(self, value):
"""
Set attr 'support_non_tensor_inputs'.
"""
if not isinstance(value, bool):
raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.")
self._support_non_tensor_inputs = value

@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -666,11 +644,6 @@ class Cell(Cell_):
"""
Defines the computation to be performed. This method must be overridden by all subclasses.

Note:
The outermost net only supports tensor inputs by default.
If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`.
Refer to the property `support_non_tensor_inputs` description.

Returns:
Tensor, returns the computed result.
"""


+ 0
- 1
tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py View File

@@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
class TestNet(nn.Cell):
def __init__(self):
super(TestNet, self).__init__()
self.support_non_tensor_inputs = False

def construct(self, tuple_a, z, list_m, w, s, dict_n):
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]


Loading…
Cancel
Save