|
|
|
@@ -1790,38 +1790,92 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { |
|
|
|
auto func_graph = std::make_shared<FuncGraph>(); |
|
|
|
func_graph->debug_info()->set_name("top"); |
|
|
|
// Generate and copy a ValueNode, or a CNode with its child nodes |
|
|
|
static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr ¶m_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(param_node); |
|
|
|
if (param_node->isa<ValueNode>()) { |
|
|
|
return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value()); |
|
|
|
} |
|
|
|
|
|
|
|
// Parameter default value is CNode. |
|
|
|
std::size_t index = 0; |
|
|
|
std::vector<AnfNodePtr> old_cnodes; |
|
|
|
old_cnodes.emplace_back(param_node); |
|
|
|
auto res = func_graph->NewCNode({}); |
|
|
|
std::vector<CNodePtr> new_cnodes; |
|
|
|
new_cnodes.emplace_back(res); |
|
|
|
while (index < old_cnodes.size()) { |
|
|
|
auto current = old_cnodes[index]; |
|
|
|
auto current_new_cnode = new_cnodes[index]; |
|
|
|
index++; |
|
|
|
MS_EXCEPTION_IF_NULL(current); |
|
|
|
if (current->isa<CNode>()) { |
|
|
|
auto &inputs = current->cast<CNodePtr>()->inputs(); |
|
|
|
for (auto it = inputs.begin(); it != inputs.end(); it++) { |
|
|
|
AnfNodePtr input = *it; |
|
|
|
if (input != nullptr && input->isa<CNode>()) { |
|
|
|
old_cnodes.emplace_back(input); |
|
|
|
auto new_cnode = func_graph->NewCNode({}); |
|
|
|
new_cnodes.emplace_back(new_cnode); |
|
|
|
current_new_cnode->add_input(new_cnode); |
|
|
|
} else if (input->isa<ValueNode>()) { |
|
|
|
current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value())); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
// def top(*arg, *kwargs): |
|
|
|
auto param_vargs = func_graph->add_parameter(); |
|
|
|
auto args_name = "args"; |
|
|
|
param_vargs->set_name(args_name); |
|
|
|
param_vargs->debug_info()->set_name(args_name); |
|
|
|
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { |
|
|
|
auto current_graph = dyn_cast<FuncGraph>(cell_ptr); |
|
|
|
if (current_graph == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString(); |
|
|
|
} |
|
|
|
|
|
|
|
auto param_vkwargs = func_graph->add_parameter(); |
|
|
|
args_name = "kwargs"; |
|
|
|
param_vkwargs->set_name(args_name); |
|
|
|
param_vkwargs->debug_info()->set_name(args_name); |
|
|
|
auto func_graph = std::make_shared<FuncGraph>(); |
|
|
|
func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper"); |
|
|
|
|
|
|
|
func_graph->set_has_vararg(true); |
|
|
|
func_graph->set_has_kwarg(true); |
|
|
|
func_graph->set_kwonlyargs_count(0); |
|
|
|
// Copy all parameters information |
|
|
|
for (auto ¶ : current_graph->parameters()) { |
|
|
|
auto param = func_graph->add_parameter(); |
|
|
|
auto orig_param = para->cast<ParameterPtr>(); |
|
|
|
auto name = orig_param->name(); |
|
|
|
param->set_name(name); |
|
|
|
param->debug_info()->set_name(name); |
|
|
|
} |
|
|
|
func_graph->set_has_vararg(current_graph->has_vararg()); |
|
|
|
func_graph->set_has_kwarg(current_graph->has_kwarg()); |
|
|
|
func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count()); |
|
|
|
// Copy all default values |
|
|
|
for (auto &d : current_graph->parameter_default_value()) { |
|
|
|
func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second)); |
|
|
|
} |
|
|
|
|
|
|
|
// cell_obj |
|
|
|
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); |
|
|
|
parse::UpdateFuncGraphFlags(cell, func_graph); |
|
|
|
// top graph's construct flag |
|
|
|
if (py::hasattr(cell, "construct")) { |
|
|
|
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
// ret = cell_obj(*arg, *kwargs) |
|
|
|
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs}); |
|
|
|
|
|
|
|
// return ret |
|
|
|
func_graph->set_output(call_fn); |
|
|
|
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); |
|
|
|
auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg(); |
|
|
|
if (!unpacking) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(cell_ptr)); |
|
|
|
auto ¶ms = func_graph->parameters(); |
|
|
|
(void)std::transform(params.begin(), params.end(), std::back_inserter(inputs), |
|
|
|
[](AnfNodePtr node) -> AnfNodePtr { return node; }); |
|
|
|
func_graph->set_output(func_graph->NewCNode(inputs)); |
|
|
|
} else { |
|
|
|
// ret = cell_obj(*arg, *kwargs) |
|
|
|
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters()); |
|
|
|
// return ret |
|
|
|
func_graph->set_output(call_fn); |
|
|
|
} |
|
|
|
return func_graph; |
|
|
|
} |
|
|
|
} // namespace parse |
|
|
|
|