@@ -915,7 +915,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n
 AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
   MS_LOG(DEBUG) << "Process ast List";
   MS_EXCEPTION_IF_NULL(block);
-  py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
+  py::list elts = python_adapter::GetPyObjAttr(node, "elts");
   if (elts.size() == 0) {
     auto empty_list = std::vector<ValuePtr>();
     return NewValueNode(std::make_shared<ValueList>(empty_list));
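The `py::tuple` to `py::list` change matches what CPython's `ast` module actually hands back: the `elts` field of an `ast.List` node is a plain Python `list`, not a `tuple`. A minimal plain-Python check (illustrative only, not MindSpore code):

```python
import ast

# The "elts" attribute of an ast.List node is a Python list, which is why
# the parser fetches it as py::list rather than py::tuple.
node = ast.parse("[1, 2, 3]", mode="eval").body
assert isinstance(node, ast.List)
assert isinstance(node.elts, list) and len(node.elts) == 3
```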
@@ -126,6 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
   std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
   std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
+  std::set<AnfNodePtr> key_ward_para_nodes;
   for (const auto &kwarg : kwarg_list) {
     MS_EXCEPTION_IF_NULL(kwarg);
     std::string kw_param_name = kwarg->get_key();
@@ -146,7 +147,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
         return param != nullptr && param->name() == param_name;
       });
       if (find_kw_arg_in_list) {
-        MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name;
+        MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
       }
       p->set_name(param_name);
       p->debug_info()->set_name(param_name);
@@ -159,12 +160,14 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
     } else {
       auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
       // multiply values found given for parameter
-      if (node_itr != specialized_parameter_list->end()) {
-        MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name;
+      if (node_itr != specialized_parameter_list->end() &&
+          key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) {
+        MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
       } else {
         specialized_parameter_list->push_back(param_node);
         auto extract_node = specialized_graph->NewCNode(
           {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
+        key_ward_para_nodes.insert(param_node);
         (void)repl_nodes->emplace(param_node, extract_node);
       }
     }
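The switch from `MS_LOG(EXCEPTION)` to `MS_EXCEPTION(TypeError)` lines the graph-mode error up with what plain Python raises when the same argument is bound twice, which is also what the new tests at the end of this change expect with `pytest.raises(TypeError)`. For comparison, a plain-Python illustration (not MindSpore code):

```python
def show(x, y, z):
    return x, y, z

try:
    show(1, 2, 3, x=4)  # x is bound both positionally and by keyword
except TypeError as e:
    print(e)  # e.g. "show() got multiple values for argument 'x'"
```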
@@ -199,10 +202,7 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>
   }
   // if the graph is generated for specific input, do not need to generate again
-  if (is_generated()) {
-    return false;
-  }
-  return true;
+  return !is_generated();
 }
 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
@@ -232,20 +232,23 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
 FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
+  std::vector<size_t> pos_arg_indexes;
   size_t arguments_count = args_spec_list.size();
-  for (const auto &arg : args_spec_list) {
-    // if it is a keyword argument
-    MS_EXCEPTION_IF_NULL(arg);
-    if (arg->isa<abstract::AbstractKeywordArg>()) {
-      kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg));
+  for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
+    if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
+      kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
+    } else {
+      pos_arg_indexes.push_back(i);
     }
   }
   if (!NeedGenerate(kwarg_list)) {
     return shared_from_base<FuncGraph>();
   }
   FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
   size_t kwarg_count = kwarg_list.size();
-  int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count());
+  int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
   int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
   int variable_args_count = pos_args_input_count - pos_args_count;
   std::vector<AnfNodePtr> specialized_parameter_list;
@@ -265,8 +268,14 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   // append hyper parameter to specialized_parameter_list
   MS_EXCEPTION_IF_NULL(specialized_graph);
   auto params = specialized_graph->parameters();
-  (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(),
-                       std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; });
+  specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
+                                    params.end());
+  std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
+                                                            specialized_parameter_list.end());
+  for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
+    specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
+                                             specialized_parameter_list[i]);
+  }
   std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
   auto tr = manager->Transact();
@@ -275,7 +284,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
                   << node_pair.second->DebugString();
     (void)tr.Replace(node_pair.first, node_pair.second);
   }
-  tr.SetParameters(specialized_graph, specialized_parameter_list);
+  tr.SetParameters(specialized_graph, specialized_parameter_list_update);
   tr.Commit();
   specialized_graph->set_has_kwarg(false);
   specialized_graph->set_has_vararg(false);
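The parameter-reordering block added to `GenerateGraph` above can be summarised with a rough Python sketch (the standalone function and the sample data are illustrative, not part of the change): the first `len(pos_arg_indexes)` entries of the specialized parameter list are the positional parameters, and each is re-inserted at the index its argument occupied in the original call, so keyword arguments supplied before positional ones (e.g. via `functools.partial`) keep their positions.

```python
def reorder_parameters(specialized_parameter_list, pos_arg_indexes):
    """Illustrative mirror of the C++ insertion loop above."""
    # Start from everything after the leading positional parameters
    # (keyword-extracted parameters, hyper parameters, ...).
    updated = list(specialized_parameter_list[len(pos_arg_indexes):])
    # Re-insert each positional parameter at its original argument index.
    for i, original_index in enumerate(pos_arg_indexes):
        updated.insert(original_index, specialized_parameter_list[i])
    return updated


# e.g. partial(show, y=y) called as f(2, z=z): the positional argument sits
# at index 1, between the two keyword arguments.
assert reorder_parameters(["pos_0", "kw_y", "kw_z"], [1]) == ["kw_y", "pos_0", "kw_z"]
```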
@@ -0,0 +1,223 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test partial"""
+from functools import partial
+
+import numpy as np
+import pytest
+
+from mindspore import nn, Tensor, context
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_partial_pos_arg():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self, x, y, z):
+            f = partial(self.show, x)
+            ret = f(y, z)
+            return ret
+
+    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
+    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
+    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
+    net = Net()
+    net(x, y, z)
+
+
+def test_partial_key_ward_arg():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self, x, y, z):
+            f = partial(self.show, x=x)
+            ret = f(y=y, z=z)
+            return ret
+
+    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
+    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
+    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
+    net = Net()
+    net(x, y, z)
+
+
+def test_partial_key_ward_arg_update():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self, x, y, z):
+            f = partial(self.show, x=x, y=y)
+            ret = f(y=y, z=z)
+            return ret
+
+    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
+    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
+    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
+    net = Net()
+    net(x, y, z)
+
+
+def test_partial_key_ward_arg_and_pos_arg():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self, x, y, z):
+            f = partial(self.show, y=y)
+            ret = f(2, z=z)
+            return ret
+
+    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
+    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
+    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
+    net = Net()
+    net(x, y, z)
+
+
+def test_partial_pos_arg_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, 1)
+            ret = f(2, 3)
+            return ret
+
+    net = Net()
+    assert net() == (1, 2, 3)
+
+
+def test_partial_key_ward_arg_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, x=1)
+            ret = f(y=2, z=3)
+            return ret
+
+    net = Net()
+    assert net() == (1, 2, 3)
+
+
+def test_partial_key_ward_arg_update_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, x=1, y=2)
+            ret = f(y=3, z=4)
+            return ret
+
+    net = Net()
+    assert net() == (1, 3, 4)
+
+
+def test_partial_key_ward_arg_and_pos_arg_const():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, y=2)
+            ret = f(1, z=3)
+            return ret
+
+    net = Net()
+    assert net() == (1, 2, 3)
+
+
+def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, x=1)
+            ret = f(1, 2, 3)
+            return ret
+
+    net = Net()
+    with pytest.raises(TypeError) as ex:
+        net()
+    assert "Multiply values for specific argument: x" in str(ex.value)
+
+
+def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, y=2)
+            ret = f(1, 2, z=3)
+            return ret
+
+    net = Net()
+    with pytest.raises(TypeError) as ex:
+        net()
+    assert "Multiply values for specific argument: y" in str(ex.value)
+
+
+def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def show(self, x, y, z):
+            return x, y, z
+
+        def construct(self):
+            f = partial(self.show, z=1)
+            ret = f(1, 2, 3)
+            return ret
+
+    net = Net()
+    with pytest.raises(TypeError) as ex:
+        net()
+    assert "Multiply values for specific argument: z" in str(ex.value)