diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index ae40d73fc6..df64ea6c67 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -115,7 +115,8 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
-    args_spec.push_back(abstract::FromValue(converted, true));
+    bool broaden = converted->isa<tensor::Tensor>() || converted->isa<tensor::MetaTensor>();
+    args_spec.push_back(abstract::FromValue(converted, broaden));
   }
   if (g_args_cache.count(args_spec) == 0) {
     static int64_t key = 0;
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 02e9876ede..0c14d00c80 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -413,18 +413,54 @@ class _Executor:
             Str, the full phase of the cell.
             Bool, if the graph has been compiled before, return False, else return True.
         """
-        args_names, args_list = _generate_pip_args(obj, *args)
-        dic = dict(zip(args_names, args_list))
-        key = generate_key(phase, dic)
-        self.phase_prefix = str(key[1])
-        if 'export' in phase:
-            phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
-        else:
-            phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+        from mindspore import nn
+
+        class InputsToAttrCell(nn.Cell):
+            """The cell that converts non-tensor inputs to attributes."""
+
+            def __init__(self, net, args_names, non_tensor_inputs):
+                super(InputsToAttrCell, self).__init__()
+                self.net = net
+                self.args_names = args_names
+                self.non_tensor_inputs = non_tensor_inputs
+                self.inputs_to_attr = True
+
+            def construct(self, *tensor_inputs):
+                real_inputs = ()
+                index = 0
+                for i in self.args_names:
+                    if i in self.non_tensor_inputs.keys():
+                        real_inputs += (self.non_tensor_inputs[i],)
+                    else:
+                        real_inputs += (tensor_inputs[index],)
+                        index += 1
+                return self.net(*real_inputs)
 
-        if phase in self.compile_cache.keys():
-            logger.debug("%r graph has existed.", phase)
-            return phase, False
+        args_names, args_list = _generate_pip_args(obj, *args)
+        if not hasattr(obj, "inputs_to_attr"):
+            dic = dict(zip(args_names, args_list))
+            key = generate_key(phase, dic)
+            self.phase_prefix = str(key[1])
+            if 'export' in phase:
+                phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
+            else:
+                phase = self.phase_prefix + phase + '.' + str(obj.create_time)
+
+            if phase in self.compile_cache.keys():
+                logger.debug("%r graph has existed.", phase)
+                return phase, False
+
+            if getattr(obj, "support_non_tensor_inputs", None):
+                attrs = {}
+                inputs = []
+                for key, value in dic.items():
+                    if not isinstance(value, (Tensor, MetaTensor)):
+                        attrs[key] = value
+                    else:
+                        inputs.append(value)
+                if attrs:
+                    inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs)
+                    return self.compile(inputs_to_attr_cell, *inputs, phase=phase)
 
         obj.check_names()
         _check_full_batch()
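The api.py hunk above is the core of the change: when a cell opts in via `support_non_tensor_inputs`, compile() splits the call arguments into tensor inputs and non-tensor values, folds the latter into an `InputsToAttrCell` wrapper, and compiles the wrapper instead. Its `construct` re-interleaves the cached non-tensor values with the live tensor inputs by argument name. A minimal, framework-free sketch of that re-interleaving logic (illustrative only; `rebuild_inputs` is a hypothetical helper, not part of MindSpore):

def rebuild_inputs(args_names, non_tensor_inputs, tensor_inputs):
    # Walk the original argument names; take non-tensor values from the
    # cached dict and tensor values, in order, from the runtime inputs.
    real_inputs = ()
    index = 0
    for name in args_names:
        if name in non_tensor_inputs:
            real_inputs += (non_tensor_inputs[name],)
        else:
            real_inputs += (tensor_inputs[index],)
            index += 1
    return real_inputs

# The scalar 6 is restored at its original argument position.
assert rebuild_inputs(("a", "s", "b"), {"s": 6}, ("t0", "t1")) == ("t0", 6, "t1")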
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 777da7ef91..221ceb6b66 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -30,7 +30,7 @@ from ..ops.primitive import Primitive
 from ..ops.operations import HookBackward
 from ..ops.functional import cast
 from ..parallel._tensor import _load_tensor_by_layout
-from ..common.tensor import Tensor
+from ..common.tensor import Tensor, MetaTensor
 
 
 class Cell(Cell_):
@@ -104,6 +104,7 @@ class Cell(Cell_):
         self._already_run = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
+        self._support_non_tensor_inputs = False
 
     @property
     def already_run(self):
@@ -119,6 +120,23 @@ class Cell(Cell_):
         self.__dict__ = dict_
         self._attr_synced = False
 
+    @property
+    def support_non_tensor_inputs(self):
+        """
+        Whether the cell supports non-tensor inputs in its `construct` method.
+        This property only takes effect on the forward net; it is not supported in the grad net.
+        """
+        return self._support_non_tensor_inputs
+
+    @support_non_tensor_inputs.setter
+    def support_non_tensor_inputs(self, value):
+        """
+        Set attr 'support_non_tensor_inputs'.
+        """
+        if not isinstance(value, bool):
+            raise ValueError("When setting 'support_non_tensor_inputs' for a cell, the value should be a bool.")
+        self._support_non_tensor_inputs = value
+
     @property
     def _cell_tag(self):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -553,14 +571,19 @@ class Cell(Cell_):
         self._auto_parallel_compile_and_run = True
         self.compile(*inputs)
 
+        new_inputs = []
+        for i in inputs:
+            if isinstance(i, (Tensor, MetaTensor)):
+                new_inputs.append(i)
+
         if self._auto_parallel_mode:
-            if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
+            if new_inputs and isinstance(new_inputs[0], Tensor) and new_inputs[0].virtual_flag:
                 # get parallel inputs in sink mode, parallel inputs set in _executor.compile
                 parallel_inputs_run = self._parallel_inputs_run
             else:
-                parallel_inputs_run = inputs
+                parallel_inputs_run = new_inputs
             return _executor(self, *parallel_inputs_run, phase=self.phase)
-        return _executor(self, *inputs, phase=self.phase)
+        return _executor(self, *new_inputs, phase=self.phase)
 
     def auto_parallel_compile_and_run(self):
         return self._auto_parallel_compile_and_run
diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
index e8be059971..1f9fa4c24b 100644
--- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
+++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py
@@ -94,7 +94,7 @@ def restrict_int_index(data_shape, tuple_indexes):
     for i, index in enumerate(tuple_indexes):
         if isinstance(index, mstype.Int):
             if index < -data_shape[i] or index >= data_shape[i]:
-                const_utils.raise_index_error("The index is out of the data's special dimension range.")
+                raise_index_error("The index is out of the data's special dimension range.")
             elif index < 0:
                 tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
             else:
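The cell.py hunks mirror that contract at run time: because non-tensor arguments are baked into the compiled graph, `__call__` forwards only `Tensor`/`MetaTensor` inputs to the executor. A standalone sketch of the filtering step (illustrative; `filter_tensor_inputs` is a hypothetical stand-in using plain Python types instead of MindSpore tensors):

def filter_tensor_inputs(inputs, tensor_types):
    # Drop anything that is not tensor-like; those values were already
    # converted to attributes of the compiled graph.
    return [i for i in inputs if isinstance(i, tensor_types)]

# Stand-in demo: pretend floats are tensors, so the int and str are dropped.
assert filter_tensor_inputs((1.0, 6, 2.0, "mode"), (float,)) == [1.0, 2.0]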
diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py
new file mode 100644
index 0000000000..39c374b953
--- /dev/null
+++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_scalar_tuple_list_dict.py
@@ -0,0 +1,63 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test outermost net pass scalar tuple list dict"""
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_outermost_net_pass_scalar_tuple_list_dict():
+    class TestNet(nn.Cell):
+        def __init__(self):
+            super(TestNet, self).__init__()
+            self.support_non_tensor_inputs = True
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return z - tuple_a[2] + list_m[1][1]["x"] - w + s - dict_n["y"]
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.forward_net = net
+            self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+            self.grad_all = C.GradOperation(get_all=True)
+
+        def construct(self, tuple_a, z, list_m, w, s, dict_n):
+            return self.grad_all(self.forward_net)(tuple_a, z, list_m, w, s, dict_n)
+
+    x = Tensor(np.ones((2, 2), np.float32))
+    y = Tensor(np.ones((2, 2), np.float32) * 2)
+    z = Tensor(np.ones((2, 2), np.float32) * 3)
+    w = Tensor(np.ones((2, 2), np.float32) * 4)
+    arg_t0 = (x, y, z, w)
+    arg_t1 = (w, y, z, w)
+    arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    arg_l1 = [[x, x], [[x, y], {"x": x, "y": y, "z": x, "p": y}]]
+    args_d0 = {"x": x, "y": y}
+    args_d1 = {"x": x, "y": y}
+    forward_net = TestNet()
+    forward_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    forward_net(arg_t1, z, arg_l1, x, 6, args_d1)
+
+    grad_net = GradNet(forward_net)
+    with pytest.raises(TypeError) as err:
+        grad_net(arg_t0, z, arg_l0, w, 6, args_d0)
+    assert "For 'graph mode', the 0th arg" in str(err.value)
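Taken together, the patch lets the outermost network take scalars, tuples, lists, and dicts in graph mode, as exercised by the new unit test. A hedged usage sketch (assumes a MindSpore build that includes this patch; `ScaleNet` is a made-up example cell, not from the diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class ScaleNet(nn.Cell):
    def __init__(self):
        super(ScaleNet, self).__init__()
        self.support_non_tensor_inputs = True  # opt in to non-tensor inputs

    def construct(self, x, scale):
        # `scale` arrives as a plain Python int and is folded into the graph.
        return x * scale

net = ScaleNet()
out1 = net(Tensor(np.ones((2, 2), np.float32)), 3)
# A different scalar value yields a different compile key (see the
# GenerateKey change), so this call compiles a second graph.
out2 = net(Tensor(np.ones((2, 2), np.float32)), 5)

Note that the feature is forward-only: as the test asserts, calling a grad net with non-tensor arguments still raises a TypeError.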