From 559c741ccee9215eb75eb2a535d6830002ef6c7f Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 25 Sep 2020 17:02:11 +0800 Subject: [PATCH] improve the way passing ags of partial --- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 2 +- mindspore/core/ir/func_graph_extends.cc | 41 ++-- .../ut/python/pipeline/parse/test_partial.py | 223 ++++++++++++++++++ 3 files changed, 249 insertions(+), 17 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_partial.py diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 1dcbc0814b..04347a76ff 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -915,7 +915,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast List"; MS_EXCEPTION_IF_NULL(block); - py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); + py::list elts = python_adapter::GetPyObjAttr(node, "elts"); if (elts.size() == 0) { auto empty_list = std::vector(); return NewValueNode(std::make_shared(empty_list)); diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 2416da0823..133a072510 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -126,6 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; + std::set key_ward_para_nodes; for (const auto &kwarg : kwarg_list) { MS_EXCEPTION_IF_NULL(kwarg); std::string kw_param_name = kwarg->get_key(); @@ -146,7 +147,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, return param != nullptr && param->name() == param_name; }); if (find_kw_arg_in_list) { - MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name; + MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name; } p->set_name(param_name); p->debug_info()->set_name(param_name); @@ -159,12 +160,14 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, } else { auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); // multiply values found given for parameter - if (node_itr != specialized_parameter_list->end()) { - MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name; + if (node_itr != specialized_parameter_list->end() && + key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) { + MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name; } else { specialized_parameter_list->push_back(param_node); auto extract_node = specialized_graph->NewCNode( {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); + key_ward_para_nodes.insert(param_node); (void)repl_nodes->emplace(param_node, extract_node); } } @@ -199,10 +202,7 @@ bool FuncGraph::NeedGenerate(const std::vector } // if the graph is generated for specific input, do not need to generate again - if (is_generated()) { - return false; - } - return true; + return !is_generated(); } void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, @@ -232,20 +232,23 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { std::vector kwarg_list; + std::vector pos_arg_indexes; size_t arguments_count = args_spec_list.size(); - for (const auto &arg : args_spec_list) { - // if it is a keyword argument - MS_EXCEPTION_IF_NULL(arg); - if (arg->isa()) { - kwarg_list.push_back(dyn_cast(arg)); + for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) { + MS_EXCEPTION_IF_NULL(args_spec_list[i]); + if (args_spec_list[i]->isa()) { + kwarg_list.push_back(args_spec_list[i]->cast()); + } else { + pos_arg_indexes.push_back(i); } } + if (!NeedGenerate(kwarg_list)) { return shared_from_base(); } FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); size_t kwarg_count = kwarg_list.size(); - int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); + int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_); int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); int variable_args_count = pos_args_input_count - pos_args_count; std::vector specialized_parameter_list; @@ -265,8 +268,14 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) // append hyper parameter to specialized_parameter_list MS_EXCEPTION_IF_NULL(specialized_graph); auto params = specialized_graph->parameters(); - (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); + specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_), + params.end()); + std::vector specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(), + specialized_parameter_list.end()); + for (size_t i = 0; i < pos_arg_indexes.size(); i++) { + specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i], + specialized_parameter_list[i]); + } std::shared_ptr manager = mindspore::Manage(specialized_graph, false); auto tr = manager->Transact(); @@ -275,7 +284,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) << node_pair.second->DebugString(); (void)tr.Replace(node_pair.first, node_pair.second); } - tr.SetParameters(specialized_graph, specialized_parameter_list); + tr.SetParameters(specialized_graph, specialized_parameter_list_update); tr.Commit(); specialized_graph->set_has_kwarg(false); specialized_graph->set_has_vararg(false); diff --git a/tests/ut/python/pipeline/parse/test_partial.py b/tests/ut/python/pipeline/parse/test_partial.py new file mode 100644 index 0000000000..ac7d50d658 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_partial.py @@ -0,0 +1,223 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test partial""" +from functools import partial + +import numpy as np +import pytest + +from mindspore import nn, Tensor, context + +context.set_context(mode=context.GRAPH_MODE) + +def test_partial_pos_arg(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self, x, y, z): + f = partial(self.show, x) + ret = f(y, z) + return ret + + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) + y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) + z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) + net = Net() + net(x, y, z) + +def test_partial_key_ward_arg(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self, x, y, z): + f = partial(self.show, x=x) + ret = f(y=y, z=z) + return ret + + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) + y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) + z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) + net = Net() + net(x, y, z) + +def test_partial_key_ward_arg_update(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self, x, y, z): + f = partial(self.show, x=x, y=y) + ret = f(y=y, z=z) + return ret + + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) + y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) + z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) + net = Net() + net(x, y, z) + + +def test_partial_key_ward_arg_and_pos_arg(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self, x, y, z): + f = partial(self.show, y=y) + ret = f(2, z=z) + return ret + + x = Tensor(np.arange(3).reshape((3,)).astype(np.float32)) + y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32)) + z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)) + net = Net() + net(x, y, z) + + +def test_partial_pos_arg_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, 1) + ret = f(2, 3) + return ret + + net = Net() + assert net() == (1, 2, 3) + +def test_partial_key_ward_arg_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, x=1) + ret = f(y=2, z=3) + return ret + + net = Net() + assert net() == (1, 2, 3) + +def test_partial_key_ward_arg_update_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, x=1, y=2) + ret = f(y=3, z=4) + return ret + + net = Net() + assert net() == (1, 3, 4) + + +def test_partial_key_ward_arg_and_pos_arg_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, y=2) + ret = f(1, z=3) + return ret + + net = Net() + assert net() == (1, 2, 3) + + +def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, x=1) + ret = f(1, 2, 3) + return ret + + net = Net() + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: x" in str(ex.value) + + +def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, y=2) + ret = f(1, 2, z=3) + return ret + + net = Net() + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: y" in str(ex.value) + + +def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def show(self, x, y, z): + return x, y, z + + def construct(self): + f = partial(self.show, z=1) + ret = f(1, 2, 3) + return ret + + net = Net() + with pytest.raises(TypeError) as ex: + net() + assert "Multiply values for specific argument: z" in str(ex.value)