|
|
@@ -126,6 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, |
|
|
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; |
|
|
|
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> key_ward_para_nodes; |
|
|
for (const auto &kwarg : kwarg_list) { |
|
|
for (const auto &kwarg : kwarg_list) { |
|
|
MS_EXCEPTION_IF_NULL(kwarg); |
|
|
MS_EXCEPTION_IF_NULL(kwarg); |
|
|
std::string kw_param_name = kwarg->get_key(); |
|
|
std::string kw_param_name = kwarg->get_key(); |
|
|
@@ -146,7 +147,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, |
|
|
return param != nullptr && param->name() == param_name; |
|
|
return param != nullptr && param->name() == param_name; |
|
|
}); |
|
|
}); |
|
|
if (find_kw_arg_in_list) { |
|
|
if (find_kw_arg_in_list) { |
|
|
MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name; |
|
|
|
|
|
|
|
|
MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name; |
|
|
} |
|
|
} |
|
|
p->set_name(param_name); |
|
|
p->set_name(param_name); |
|
|
p->debug_info()->set_name(param_name); |
|
|
p->debug_info()->set_name(param_name); |
|
|
@@ -159,12 +160,14 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, |
|
|
} else { |
|
|
} else { |
|
|
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); |
|
|
auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); |
|
|
// multiply values found given for parameter |
|
|
// multiply values found given for parameter |
|
|
if (node_itr != specialized_parameter_list->end()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name; |
|
|
|
|
|
|
|
|
if (node_itr != specialized_parameter_list->end() && |
|
|
|
|
|
key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) { |
|
|
|
|
|
MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name; |
|
|
} else { |
|
|
} else { |
|
|
specialized_parameter_list->push_back(param_node); |
|
|
specialized_parameter_list->push_back(param_node); |
|
|
auto extract_node = specialized_graph->NewCNode( |
|
|
auto extract_node = specialized_graph->NewCNode( |
|
|
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); |
|
|
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); |
|
|
|
|
|
key_ward_para_nodes.insert(param_node); |
|
|
(void)repl_nodes->emplace(param_node, extract_node); |
|
|
(void)repl_nodes->emplace(param_node, extract_node); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -199,10 +202,7 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// if the graph is generated for specific input, do not need to generate again |
|
|
// if the graph is generated for specific input, do not need to generate again |
|
|
if (is_generated()) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
|
|
|
return !is_generated(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, |
|
|
void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, |
|
|
@@ -232,20 +232,23 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, |
|
|
|
|
|
|
|
|
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { |
|
|
FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { |
|
|
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; |
|
|
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; |
|
|
|
|
|
std::vector<size_t> pos_arg_indexes; |
|
|
size_t arguments_count = args_spec_list.size(); |
|
|
size_t arguments_count = args_spec_list.size(); |
|
|
for (const auto &arg : args_spec_list) { |
|
|
|
|
|
// if it is a keyword argument |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(arg); |
|
|
|
|
|
if (arg->isa<abstract::AbstractKeywordArg>()) { |
|
|
|
|
|
kwarg_list.push_back(dyn_cast<abstract::AbstractKeywordArg>(arg)); |
|
|
|
|
|
|
|
|
for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(args_spec_list[i]); |
|
|
|
|
|
if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) { |
|
|
|
|
|
kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>()); |
|
|
|
|
|
} else { |
|
|
|
|
|
pos_arg_indexes.push_back(i); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (!NeedGenerate(kwarg_list)) { |
|
|
if (!NeedGenerate(kwarg_list)) { |
|
|
return shared_from_base<FuncGraph>(); |
|
|
return shared_from_base<FuncGraph>(); |
|
|
} |
|
|
} |
|
|
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>()); |
|
|
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>()); |
|
|
size_t kwarg_count = kwarg_list.size(); |
|
|
size_t kwarg_count = kwarg_list.size(); |
|
|
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); |
|
|
|
|
|
|
|
|
int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_); |
|
|
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); |
|
|
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); |
|
|
int variable_args_count = pos_args_input_count - pos_args_count; |
|
|
int variable_args_count = pos_args_input_count - pos_args_count; |
|
|
std::vector<AnfNodePtr> specialized_parameter_list; |
|
|
std::vector<AnfNodePtr> specialized_parameter_list; |
|
|
@@ -265,8 +268,14 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) |
|
|
// append hyper parameter to specialized_parameter_list |
|
|
// append hyper parameter to specialized_parameter_list |
|
|
MS_EXCEPTION_IF_NULL(specialized_graph); |
|
|
MS_EXCEPTION_IF_NULL(specialized_graph); |
|
|
auto params = specialized_graph->parameters(); |
|
|
auto params = specialized_graph->parameters(); |
|
|
(void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), |
|
|
|
|
|
std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); |
|
|
|
|
|
|
|
|
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_), |
|
|
|
|
|
params.end()); |
|
|
|
|
|
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(), |
|
|
|
|
|
specialized_parameter_list.end()); |
|
|
|
|
|
for (size_t i = 0; i < pos_arg_indexes.size(); i++) { |
|
|
|
|
|
specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i], |
|
|
|
|
|
specialized_parameter_list[i]); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); |
|
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); |
|
|
auto tr = manager->Transact(); |
|
|
auto tr = manager->Transact(); |
|
|
@@ -275,7 +284,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) |
|
|
<< node_pair.second->DebugString(); |
|
|
<< node_pair.second->DebugString(); |
|
|
(void)tr.Replace(node_pair.first, node_pair.second); |
|
|
(void)tr.Replace(node_pair.first, node_pair.second); |
|
|
} |
|
|
} |
|
|
tr.SetParameters(specialized_graph, specialized_parameter_list); |
|
|
|
|
|
|
|
|
tr.SetParameters(specialized_graph, specialized_parameter_list_update); |
|
|
tr.Commit(); |
|
|
tr.Commit(); |
|
|
specialized_graph->set_has_kwarg(false); |
|
|
specialized_graph->set_has_kwarg(false); |
|
|
specialized_graph->set_has_vararg(false); |
|
|
specialized_graph->set_has_vararg(false); |
|
|
|