| @@ -440,59 +440,19 @@ class _Executor: | |||||
| Str, the full phase of the cell. | Str, the full phase of the cell. | ||||
| Bool, if the graph has been compiled before, return False, else return True. | Bool, if the graph has been compiled before, return False, else return True. | ||||
| """ | """ | ||||
| from mindspore import nn | |||||
| from mindspore.ops.composite import GradOperation | |||||
| class InputsToAttrCell(nn.Cell): | |||||
| """The cell that converts non-tensor inputs to attr.""" | |||||
| def __init__(self, net, args_names, non_tensor_inputs): | |||||
| super(InputsToAttrCell, self).__init__() | |||||
| self.net = net | |||||
| self.args_names = args_names | |||||
| self.non_tensor_inputs = non_tensor_inputs | |||||
| self.inputs_to_attr = True | |||||
| def construct(self, *tensor_inputs): | |||||
| real_inputs = () | |||||
| index = 0 | |||||
| for i in args_names: | |||||
| if i in self.non_tensor_inputs.keys(): | |||||
| real_inputs += (self.non_tensor_inputs[i],) | |||||
| else: | |||||
| real_inputs += (tensor_inputs[index],) | |||||
| index += 1 | |||||
| return self.net(*real_inputs) | |||||
| args_names, args_list = _generate_pip_args(obj, *args) | args_names, args_list = _generate_pip_args(obj, *args) | ||||
| if not hasattr(obj, "inputs_to_attr"): | |||||
| dic = dict(zip(args_names, args_list)) | |||||
| key = generate_key(phase, dic) | |||||
| obj.phase_prefix = str(key[1]) | |||||
| if 'export' in phase: | |||||
| phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) | |||||
| else: | |||||
| phase = obj.phase_prefix + phase + '.' + str(obj.create_time) | |||||
| if phase in self.compile_cache.keys(): | |||||
| logger.debug("%r graph has existed.", phase) | |||||
| return phase, False | |||||
| if getattr(obj, "support_non_tensor_inputs", None): | |||||
| for i in obj.__dict__.values(): | |||||
| if isinstance(i, GradOperation): | |||||
| raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, " | |||||
| "only support forward net.") | |||||
| attrs = {} | |||||
| inputs = [] | |||||
| for key, value in dic.items(): | |||||
| if not isinstance(value, (Tensor, MetaTensor)): | |||||
| attrs[key] = value | |||||
| else: | |||||
| inputs.append(value) | |||||
| if attrs: | |||||
| inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) | |||||
| return self.compile(inputs_to_attr_cell, *inputs, phase=phase) | |||||
| dic = dict(zip(args_names, args_list)) | |||||
| key = generate_key(phase, dic) | |||||
| obj.phase_prefix = str(key[1]) | |||||
| if 'export' in phase: | |||||
| phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) | |||||
| else: | |||||
| phase = obj.phase_prefix + phase + '.' + str(obj.create_time) | |||||
| if phase in self.compile_cache.keys(): | |||||
| logger.debug("%r graph has existed.", phase) | |||||
| return phase, False | |||||
| obj.check_names() | obj.check_names() | ||||
| _check_full_batch() | _check_full_batch() | ||||
| @@ -107,7 +107,6 @@ class Cell(Cell_): | |||||
| self._bprop_debug = False | self._bprop_debug = False | ||||
| self.cell_type = None | self.cell_type = None | ||||
| self._auto_parallel_compile_and_run = False | self._auto_parallel_compile_and_run = False | ||||
| self._support_non_tensor_inputs = False | |||||
| def __getstate__(self): | def __getstate__(self): | ||||
| base = Cell_.__getstate__(self) | base = Cell_.__getstate__(self) | ||||
| @@ -119,27 +118,6 @@ class Cell(Cell_): | |||||
| self.__dict__ = dict_ | self.__dict__ = dict_ | ||||
| self._attr_synced = False | self._attr_synced = False | ||||
| @property | |||||
| def support_non_tensor_inputs(self): | |||||
| """ | |||||
| Whether support non tensor inputs in outermost net in GRAPH MODE. | |||||
| This property only used in forward net, and is not supported in grad net. | |||||
| The default value of the property is the `False`, that is, | |||||
| it does not support passing non tensor inputs to the outermost net. | |||||
| If you want to support, set the property to the `True`. | |||||
| """ | |||||
| return self._support_non_tensor_inputs | |||||
| @support_non_tensor_inputs.setter | |||||
| def support_non_tensor_inputs(self, value): | |||||
| """ | |||||
| Set attr 'support_non_tensor_inputs'. | |||||
| """ | |||||
| if not isinstance(value, bool): | |||||
| raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.") | |||||
| self._support_non_tensor_inputs = value | |||||
| @property | @property | ||||
| def _cell_tag(self): | def _cell_tag(self): | ||||
| # `<class 'xxxxxxx'>` to `xxxxxxx` | # `<class 'xxxxxxx'>` to `xxxxxxx` | ||||
| @@ -666,11 +644,6 @@ class Cell(Cell_): | |||||
| """ | """ | ||||
| Defines the computation to be performed. This method must be overridden by all subclasses. | Defines the computation to be performed. This method must be overridden by all subclasses. | ||||
| Note: | |||||
| The outermost net only supports tensor inputs by default. | |||||
| If want to support non tensor inputs, set the property `support_non_tensor_inputs` to the `True`. | |||||
| Refer to the property `support_non_tensor_inputs` description. | |||||
| Returns: | Returns: | ||||
| Tensor, returns the computed result. | Tensor, returns the computed result. | ||||
| """ | """ | ||||
| @@ -27,7 +27,6 @@ def test_outermost_net_pass_scalar_tuple_list_dict(): | |||||
| class TestNet(nn.Cell): | class TestNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(TestNet, self).__init__() | super(TestNet, self).__init__() | ||||
| self.support_non_tensor_inputs = False | |||||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | def construct(self, tuple_a, z, list_m, w, s, dict_n): | ||||
| return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] | return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] | ||||