From: @zhangbuxue Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -115,7 +115,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str | |||||
| if (!parse::ConvertData(arg.second, &converted)) { | if (!parse::ConvertData(arg.second, &converted)) { | ||||
| MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; | MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; | ||||
| } | } | ||||
| args_spec.push_back(abstract::FromValue(converted, true)); | |||||
| bool broaden = converted->isa<Tensor>() || converted->isa<MetaTensor>(); | |||||
| args_spec.push_back(abstract::FromValue(converted, broaden)); | |||||
| } | } | ||||
| if (g_args_cache.count(args_spec) == 0) { | if (g_args_cache.count(args_spec) == 0) { | ||||
| static int64_t key = 0; | static int64_t key = 0; | ||||
| @@ -413,18 +413,54 @@ class _Executor: | |||||
| Str, the full phase of the cell. | Str, the full phase of the cell. | ||||
| Bool, if the graph has been compiled before, return False, else return True. | Bool, if the graph has been compiled before, return False, else return True. | ||||
| """ | """ | ||||
| args_names, args_list = _generate_pip_args(obj, *args) | |||||
| dic = dict(zip(args_names, args_list)) | |||||
| key = generate_key(phase, dic) | |||||
| self.phase_prefix = str(key[1]) | |||||
| if 'export' in phase: | |||||
| phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) | |||||
| else: | |||||
| phase = self.phase_prefix + phase + '.' + str(obj.create_time) | |||||
| from mindspore import nn | |||||
| class InputsToAttrCell(nn.Cell): | |||||
| """The cell that converts non-tensor inputs to attr.""" | |||||
| def __init__(self, net, args_names, non_tensor_inputs): | |||||
| super(InputsToAttrCell, self).__init__() | |||||
| self.net = net | |||||
| self.args_names = args_names | |||||
| self.non_tensor_inputs = non_tensor_inputs | |||||
| self.inputs_to_attr = True | |||||
| def construct(self, *tensor_inputs): | |||||
| real_inputs = () | |||||
| index = 0 | |||||
| for i in args_names: | |||||
| if i in self.non_tensor_inputs.keys(): | |||||
| real_inputs += (self.non_tensor_inputs[i],) | |||||
| else: | |||||
| real_inputs += (tensor_inputs[index],) | |||||
| index += 1 | |||||
| return self.net(*real_inputs) | |||||
| if phase in self.compile_cache.keys(): | |||||
| logger.debug("%r graph has existed.", phase) | |||||
| return phase, False | |||||
| args_names, args_list = _generate_pip_args(obj, *args) | |||||
| if not hasattr(obj, "inputs_to_attr"): | |||||
| dic = dict(zip(args_names, args_list)) | |||||
| key = generate_key(phase, dic) | |||||
| self.phase_prefix = str(key[1]) | |||||
| if 'export' in phase: | |||||
| phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) | |||||
| else: | |||||
| phase = self.phase_prefix + phase + '.' + str(obj.create_time) | |||||
| if phase in self.compile_cache.keys(): | |||||
| logger.debug("%r graph has existed.", phase) | |||||
| return phase, False | |||||
| if getattr(obj, "support_non_tensor_inputs", None): | |||||
| attrs = {} | |||||
| inputs = [] | |||||
| for key, value in dic.items(): | |||||
| if not isinstance(value, (Tensor, MetaTensor)): | |||||
| attrs[key] = value | |||||
| else: | |||||
| inputs.append(value) | |||||
| if attrs: | |||||
| inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) | |||||
| return self.compile(inputs_to_attr_cell, *inputs, phase=phase) | |||||
| obj.check_names() | obj.check_names() | ||||
| _check_full_batch() | _check_full_batch() | ||||
| @@ -30,7 +30,7 @@ from ..ops.primitive import Primitive | |||||
| from ..ops.operations import HookBackward | from ..ops.operations import HookBackward | ||||
| from ..ops.functional import cast | from ..ops.functional import cast | ||||
| from ..parallel._tensor import _load_tensor_by_layout | from ..parallel._tensor import _load_tensor_by_layout | ||||
| from ..common.tensor import Tensor | |||||
| from ..common.tensor import Tensor, MetaTensor | |||||
| class Cell(Cell_): | class Cell(Cell_): | ||||
| @@ -104,6 +104,7 @@ class Cell(Cell_): | |||||
| self._already_run = False | self._already_run = False | ||||
| self.cell_type = None | self.cell_type = None | ||||
| self._auto_parallel_compile_and_run = False | self._auto_parallel_compile_and_run = False | ||||
| self._support_non_tensor_inputs = False | |||||
| @property | @property | ||||
| def already_run(self): | def already_run(self): | ||||
| @@ -119,6 +120,23 @@ class Cell(Cell_): | |||||
| self.__dict__ = dict_ | self.__dict__ = dict_ | ||||
| self._attr_synced = False | self._attr_synced = False | ||||
| @property | |||||
| def support_non_tensor_inputs(self): | |||||
| """ | |||||
| Whether support non tensor inputs in cell `construct` method. | |||||
| This property only used in forward net, is not supported in grad net. | |||||
| """ | |||||
| return self._support_non_tensor_inputs | |||||
| @support_non_tensor_inputs.setter | |||||
| def support_non_tensor_inputs(self, value): | |||||
| """ | |||||
| Set attr 'support_non_tensor_inputs'. | |||||
| """ | |||||
| if not isinstance(value, bool): | |||||
| raise ValueError("When set 'support_non_tensor_inputs' for cell, the value should be bool.") | |||||
| self._support_non_tensor_inputs = value | |||||
| @property | @property | ||||
| def _cell_tag(self): | def _cell_tag(self): | ||||
| # `<class 'xxxxxxx'>` to `xxxxxxx` | # `<class 'xxxxxxx'>` to `xxxxxxx` | ||||
| @@ -553,14 +571,19 @@ class Cell(Cell_): | |||||
| self._auto_parallel_compile_and_run = True | self._auto_parallel_compile_and_run = True | ||||
| self.compile(*inputs) | self.compile(*inputs) | ||||
| new_inputs = [] | |||||
| for i in inputs: | |||||
| if isinstance(i, (Tensor, MetaTensor)): | |||||
| new_inputs.append(i) | |||||
| if self._auto_parallel_mode: | if self._auto_parallel_mode: | ||||
| if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: | |||||
| if new_inputs and isinstance(new_inputs[0], Tensor) and inputs[0].virtual_flag: | |||||
| # get parallel inputs in sink mode, parallel inputs set in _executor.compile | # get parallel inputs in sink mode, parallel inputs set in _executor.compile | ||||
| parallel_inputs_run = self._parallel_inputs_run | parallel_inputs_run = self._parallel_inputs_run | ||||
| else: | else: | ||||
| parallel_inputs_run = inputs | |||||
| parallel_inputs_run = new_inputs | |||||
| return _executor(self, *parallel_inputs_run, phase=self.phase) | return _executor(self, *parallel_inputs_run, phase=self.phase) | ||||
| return _executor(self, *inputs, phase=self.phase) | |||||
| return _executor(self, *new_inputs, phase=self.phase) | |||||
| def auto_parallel_compile_and_run(self): | def auto_parallel_compile_and_run(self): | ||||
| return self._auto_parallel_compile_and_run | return self._auto_parallel_compile_and_run | ||||
| @@ -94,7 +94,7 @@ def restrict_int_index(data_shape, tuple_indexes): | |||||
| for i, index in enumerate(tuple_indexes): | for i, index in enumerate(tuple_indexes): | ||||
| if isinstance(index, mstype.Int): | if isinstance(index, mstype.Int): | ||||
| if index < -data_shape[i] or index >= data_shape[i]: | if index < -data_shape[i] or index >= data_shape[i]: | ||||
| const_utils.raise_index_error("The index is out of the data's special dimension range.") | |||||
| raise_index_error("The index is out of the data's special dimension range.") | |||||
| elif index < 0: | elif index < 0: | ||||
| tuple_indexes_new += (tuple_indexes[i]+data_shape[i],) | tuple_indexes_new += (tuple_indexes[i]+data_shape[i],) | ||||
| else: | else: | ||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test outermost net pass scalar tuple list dict""" | |||||
| import pytest | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops import composite as C | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def test_outermost_net_pass_scalar_tuple_list_dict(): | |||||
| class TestNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(TestNet, self).__init__() | |||||
| self.support_non_tensor_inputs = True | |||||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | |||||
| return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"] | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.forward_net = net | |||||
| self.sens = Tensor(np.ones((2, 2), np.float32) * 5) | |||||
| self.grad_all = C.GradOperation(get_all=True) | |||||
| def construct(self, tuple_a, z, list_m, w, s, dict_n): | |||||
| return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n) | |||||
| x = Tensor(np.ones((2, 2), np.float32)) | |||||
| y = Tensor(np.ones((2, 2), np.float32) * 2) | |||||
| z = Tensor(np.ones((2, 2), np.float32) * 3) | |||||
| w = Tensor(np.ones((2, 2), np.float32) * 4) | |||||
| arg_t0 = (x, y, z, w) | |||||
| arg_t1 = (w, y, z, w) | |||||
| arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||||
| arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]] | |||||
| args_d0 = {"x": x, "y": y} | |||||
| args_d1 = {"x": x, "y": y} | |||||
| forward_net = TestNet() | |||||
| forward_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||||
| forward_net(arg_t1, z, arg_l1, x, 6, args_d1) | |||||
| grad_net = GradNet(forward_net) | |||||
| with pytest.raises(TypeError) as err: | |||||
| grad_net(arg_t0, z, arg_l0, w, 6, args_d0) | |||||
| assert "For 'graph mode', the 0th arg" in str(err.value) | |||||