Browse Source

remove attr support_non_tensor_input of cell

tags/v1.2.0-rc1
buxue 4 years ago
parent
commit
7eaf84d07a
3 changed files with 11 additions and 79 deletions
  1. +11
    -51
      mindspore/common/api.py
  2. +0
    -27
      mindspore/nn/cell.py
  3. +0
    -1
      tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py

+ 11
- 51
mindspore/common/api.py View File

@@ -440,59 +440,19 @@ class _Executor:
Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True.
"""
from mindspore import nn
from mindspore.ops.composite import GradOperation

class InputsToAttrCell(nn.Cell):
"""The cell that converts non-tensor inputs to attr."""

def __init__(self, net, args_names, non_tensor_inputs):
super(InputsToAttrCell, self).__init__()
self.net = net
self.args_names = args_names
self.non_tensor_inputs = non_tensor_inputs
self.inputs_to_attr = True

def construct(self, *tensor_inputs):
real_inputs = ()
index = 0
for i in args_names:
if i in self.non_tensor_inputs.keys():
real_inputs += (self.non_tensor_inputs[i],)
else:
real_inputs += (tensor_inputs[index],)
index += 1
return self.net(*real_inputs)

args_names, args_list = _generate_pip_args(obj, *args)
if not hasattr(obj, "inputs_to_attr"):
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False

if getattr(obj, "support_non_tensor_inputs", None):
for i in obj.__dict__.values():
if isinstance(i, GradOperation):
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, "
"only support forward net.")
attrs = {}
inputs = []
for key, value in dic.items():
if not isinstance(value, (Tensor, MetaTensor)):
attrs[key] = value
else:
inputs.append(value)
if attrs:
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False

obj.check_names()
_check_full_batch()


+ 0
- 27
mindspore/nn/cell.py View File

@@ -107,7 +107,6 @@ class Cell(Cell_):
self._bprop_debug = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self._support_non_tensor_inputs = False

def __getstate__(self):
base = Cell_.__getstate__(self)
@@ -119,27 +118,6 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False

@property
def support_non_tensor_inputs(self):
"""
Whether support non tensor inputs in outermost net in GRAPH MODE.
This property only used in forward net, and is not supported in grad net.
The default value of the property is the `False`, that is,
it does not support passing non tensor inputs to the outermost net.
If you want to support, set the property to the `True`.

"""
return self._support_non_tensor_inputs

@support_non_tensor_inputs.setter
def support_non_tensor_inputs(self, value):
"""
Set attr 'support_non_tensor_inputs'.
"""
if not isinstance(value, bool):
raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.")
self._support_non_tensor_inputs = value

@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -666,11 +644,6 @@ class Cell(Cell_):
"""
Defines the computation to be performed. This method must be overridden by all subclasses.

Note:
The outermost net only supports tensor inputs by default.
If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`.
Refer to the property `support_non_tensor_inputs` description.

Returns:
Tensor, returns the computed result.
"""


+ 0
- 1
tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py View File

@@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict():
class TestNet(nn.Cell):
def __init__(self):
super(TestNet, self).__init__()
self.support_non_tensor_inputs = False

def construct(self, tuple_a, z, list_m, w, s, dict_n):
return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]


Loading…
Cancel
Save