From: @zhangbuxue
Tag: tags/v1.1.0
@@ -115,7 +115,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
     if (!parse::ConvertData(arg.second, &converted)) {
       MS_LOG(EXCEPTION) << "GenerateKey convert arg failed";
     }
-    args_spec.push_back(abstract::FromValue(converted, true));
+    bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>();
+    args_spec.push_back(abstract::FromValue(converted, broaden));
   }
   if (g_args_cache.count(args_spec) == 0) {
     static int64_t key = 0;
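
The keying change above means non-tensor arguments are no longer broadened away: only Tensor/MetaTensor values have their concrete values erased, while scalars, tuples, lists, and dicts stay concrete and therefore distinguish cache entries. A rough Python analogue of the rule (the names here are illustrative, not the real C++ API):

    from mindspore.common.tensor import Tensor, MetaTensor

    def make_arg_key(arg):
        """Sketch of the per-argument cache key after this change."""
        if isinstance(arg, (Tensor, MetaTensor)):
            # Broaden: keep only the abstract dtype/shape, not the value.
            return ("tensor", str(arg.dtype), tuple(arg.shape))
        # Non-tensor args keep their concrete value, so two calls that
        # differ only in a scalar/tuple/dict map to different graph keys.
        return ("const", type(arg).__name__, repr(arg))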
@@ -413,18 +413,54 @@ class _Executor:
             Str, the full phase of the cell.
             Bool, if the graph has been compiled before, return False, else return True.
         """
-        args_names, args_list = _generate_pip_args(obj, *args)
-        dic = dict(zip(args_names, args_list))
-        key = generate_key(phase, dic)
-        self.phase_prefix = str(key[1])
-        if 'export' in phase:
-            phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
-        else:
-            phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+        from mindspore import nn
+
+        class InputsToAttrCell(nn.Cell):
+            """The cell that converts non-tensor inputs to attrs."""
+
+            def __init__(self, net, args_names, non_tensor_inputs):
+                super(InputsToAttrCell, self).__init__()
+                self.net = net
+                self.args_names = args_names
+                self.non_tensor_inputs = non_tensor_inputs
+                self.inputs_to_attr = True
+
+            def construct(self, *tensor_inputs):
+                real_inputs = ()
+                index = 0
+                for i in self.args_names:
+                    if i in self.non_tensor_inputs.keys():
+                        real_inputs += (self.non_tensor_inputs[i],)
+                    else:
+                        real_inputs += (tensor_inputs[index],)
+                        index += 1
+                return self.net(*real_inputs)
-
-        if phase in self.compile_cache.keys():
-            logger.debug("%r graph has existed.", phase)
-            return phase, False
+
+        args_names, args_list = _generate_pip_args(obj, *args)
+        if not hasattr(obj, "inputs_to_attr"):
+            dic = dict(zip(args_names, args_list))
+            key = generate_key(phase, dic)
+            self.phase_prefix = str(key[1])
+            if 'export' in phase:
+                phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
+            else:
+                phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+
+            if phase in self.compile_cache.keys():
+                logger.debug("%r graph has existed.", phase)
+                return phase, False
+
+            if getattr(obj, "support_non_tensor_inputs", None):
+                attrs = {}
+                inputs = []
+                for key, value in dic.items():
+                    if not isinstance(value, (Tensor, MetaTensor)):
+                        attrs[key] = value
+                    else:
+                        inputs.append(value)
+                if attrs:
+                    inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
+                    return self.compile(inputs_to_attr_cell, *inputs, phase=phase)

         obj.check_names()
         _check_full_batch()
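
Net effect of the `compile` changes: when the outermost cell opts in (and is not already an `InputsToAttrCell`, hence the `inputs_to_attr` guard), the arguments are split by name into tensors, which stay positional graph inputs, and non-tensors, which are captured by the wrapper cell and replayed as constants inside `construct`. A framework-free sketch of that split, with `is_tensor` standing in for the `isinstance(..., (Tensor, MetaTensor))` test:

    def split_args(args_names, args_list, is_tensor):
        """Partition call arguments the way compile() does above."""
        attrs, tensors = {}, []
        for name, value in zip(args_names, args_list):
            if is_tensor(value):
                tensors.append(value)   # stays a positional graph input
            else:
                attrs[name] = value     # baked into the wrapper as a constant
        return attrs, tensors

    # e.g. split_args(["x", "s"], [x_tensor, 6], is_tensor) -> ({"s": 6}, [x_tensor])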
@@ -30,7 +30,7 @@ from ..ops.primitive import Primitive
 from ..ops.operations import HookBackward
 from ..ops.functional import cast
 from ..parallel._tensor import _load_tensor_by_layout
-from ..common.tensor import Tensor
+from ..common.tensor import Tensor, MetaTensor
@@ -104,6 +104,7 @@ class Cell(Cell_):
         self._already_run = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
+        self._support_non_tensor_inputs = False

     @property
     def already_run(self):
@@ -119,6 +120,23 @@ class Cell(Cell_):
         self.__dict__ = dict_
         self._attr_synced = False

+    @property
+    def support_non_tensor_inputs(self):
+        """
+        Whether the cell `construct` method supports non-tensor inputs.
+        This property is only used for the forward net; it is not supported in the grad net.
+        """
+        return self._support_non_tensor_inputs
+
+    @support_non_tensor_inputs.setter
+    def support_non_tensor_inputs(self, value):
+        """
+        Set attr 'support_non_tensor_inputs'.
+        """
+        if not isinstance(value, bool):
+            raise ValueError("When setting 'support_non_tensor_inputs' for a cell, the value should be a bool.")
+        self._support_non_tensor_inputs = value
+
     @property
     def _cell_tag(self):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
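
For context, a minimal opt-in example (assuming this patch is applied; `AddBias` is an illustrative toy net, not part of the change):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class AddBias(nn.Cell):
        """Toy net whose second argument is a plain Python scalar."""
        def construct(self, x, bias):
            return x + bias

    net = AddBias()
    net.support_non_tensor_inputs = True  # non-bool values raise ValueError
    out = net(Tensor(np.ones((2, 2), np.float32)), 5)  # the scalar becomes an attr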
@@ -553,14 +571,19 @@ class Cell(Cell_):
         self._auto_parallel_compile_and_run = True
         self.compile(*inputs)

+        new_inputs = []
+        for i in inputs:
+            if isinstance(i, (Tensor, MetaTensor)):
+                new_inputs.append(i)
+
         if self._auto_parallel_mode:
-            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
+            if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag:
                 # get parallel inputs in sink mode, parallel inputs set in _executor.compile
                 parallel_inputs_run = self._parallel_inputs_run
             else:
-                parallel_inputs_run = inputs
+                parallel_inputs_run = new_inputs
             return _executor(self, *parallel_inputs_run, phase=self.phase)
-        return _executor(self, *inputs, phase=self.phase)
+        return _executor(self, *new_inputs, phase=self.phase)

     def auto_parallel_compile_and_run(self):
         return self._auto_parallel_compile_and_run
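
The run-time filtering above mirrors the compile-time split: once non-tensor arguments have been folded into the wrapper cell, the compiled graph's formal parameters are tensors only, so the dispatch call must drop everything else. Illustratively:

    import numpy as np
    from mindspore.common.tensor import Tensor, MetaTensor

    inputs = (Tensor(np.ones((2, 2), np.float32)), 6, "tag")
    new_inputs = [i for i in inputs if isinstance(i, (Tensor, MetaTensor))]
    assert len(new_inputs) == 1  # 6 and "tag" no longer exist as graph inputs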
@@ -94,7 +94,7 @@ def restrict_int_index(data_shape, tuple_indexes):
     for i, index in enumerate(tuple_indexes):
         if isinstance(index, mstype.Int):
             if index < -data_shape[i] or index >= data_shape[i]:
-                const_utils.raise_index_error("The index is out of the data's special dimension range.")
+                raise_index_error("The index is out of the data's special dimension range.")
             elif index < 0:
                 tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
             else:
@@ -0,0 +1,63 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test outermost net pass scalar tuple list dict"""
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_outermost_net_pass_scalar_tuple_list_dict():
+    class TestNet(nn.Cell):
+        def __init__(self):
+            super(TestNet, self).__init__()
+            self.support_non_tensor_inputs = True
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.forward_net = net
+            self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+            self.grad_all = C.GradOperation(get_all=True)
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n)
+
+    x = Tensor(np.ones((2, 2), np.float32))
+    y = Tensor(np.ones((2, 2), np.float32) * 2)
+    z = Tensor(np.ones((2, 2), np.float32) * 3)
+    w = Tensor(np.ones((2, 2), np.float32) * 4)
+    arg_t0 = (x, y, z, w)
+    arg_t1 = (w, y, z, w)
+    arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    args_d0 = {"x": x, "y": y}
+    args_d1 = {"x": x, "y": y}
+    forward_net = TestNet()
+    forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
+    grad_net = GradNet(forward_net)
+    with pytest.raises(TypeError) as err:
+        grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    assert "For 'graph mode', the 0th arg" in str(err.value)