|
|
|
@@ -52,8 +52,6 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
static std::shared_ptr<std::map<ParamInfoPtr, ParameterPtr>> python_paras; |
|
|
|
void ClearPythonParasMap() { python_paras = nullptr; } |
|
|
|
namespace { |
|
|
|
const int kSummaryGetItem = 2; |
|
|
|
const size_t max_depth = 128; |
|
|
|
@@ -681,19 +679,18 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
ParameterPtr new_parameter = nullptr; |
|
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
|
if (python_paras == nullptr) { |
|
|
|
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>(); |
|
|
|
} |
|
|
|
auto iter = python_paras->find(param_value); |
|
|
|
if (iter != python_paras->end()) { |
|
|
|
new_parameter = iter->second; |
|
|
|
if (param_value != nullptr) { |
|
|
|
new_parameter = param_value->parameter(); |
|
|
|
if (new_parameter == nullptr) { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
param_value->set_parameter(new_parameter); |
|
|
|
} |
|
|
|
} else { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
if (param_value != nullptr) { |
|
|
|
(*python_paras)[param_value] = new_parameter; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
new_parameter->IncreaseUsedGraphCount(); |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
valid_inputs->push_back(true); |
|
|
|
@@ -1126,10 +1123,10 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(sub_func_graph); |
|
|
|
if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { |
|
|
|
if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; |
|
|
|
} |
|
|
|
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; |
|
|
|
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()]; |
|
|
|
|
|
|
|
ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph); |
|
|
|
new_value_node->set_abstract(value_node->abstract()); |
|
|
|
@@ -1155,19 +1152,19 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph |
|
|
|
|
|
|
|
auto param_value = GetParamDefaultValue(anf); |
|
|
|
ParameterPtr new_parameter = nullptr; |
|
|
|
if (python_paras == nullptr) { |
|
|
|
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>(); |
|
|
|
} |
|
|
|
auto iter = python_paras->find(param_value); |
|
|
|
if (iter != python_paras->end()) { |
|
|
|
new_parameter = iter->second; |
|
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
|
if (param_value != nullptr) { |
|
|
|
new_parameter = param_value->parameter(); |
|
|
|
if (new_parameter == nullptr) { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
param_value->set_parameter(new_parameter); |
|
|
|
} |
|
|
|
} else { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
if (param_value != nullptr) { |
|
|
|
(*python_paras)[param_value] = new_parameter; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
new_parameter->IncreaseUsedGraphCount(); |
|
|
|
|
|
|
|
return new_parameter; |
|
|
|
@@ -1423,7 +1420,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
front_backend_graph_map_[func_graph] = graph; |
|
|
|
front_backend_graph_map_[func_graph.get()] = graph; |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
for (const auto &node : node_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -1446,15 +1443,15 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
} |
|
|
|
// Create child kernel graph according ValueNode<FuncGraph> |
|
|
|
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); |
|
|
|
if (front_backend_graph_map_.find(child_graph) == front_backend_graph_map_.end()) { |
|
|
|
if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) { |
|
|
|
(void)ConstructKernelGraph(child_graph, all_out_graph); |
|
|
|
} |
|
|
|
(void)CreateValueNodeKernelGraph(node, graph.get()); |
|
|
|
auto &parent_graph = parent_graphs_[front_backend_graph_map_[child_graph]->graph_id()]; |
|
|
|
auto &parent_graph = parent_graphs_[front_backend_graph_map_[child_graph.get()]->graph_id()]; |
|
|
|
auto parent_graph_it = |
|
|
|
std::find(parent_graph.begin(), parent_graph.end(), front_backend_graph_map_[func_graph]->graph_id()); |
|
|
|
std::find(parent_graph.begin(), parent_graph.end(), front_backend_graph_map_[func_graph.get()]->graph_id()); |
|
|
|
if (parent_graph_it == parent_graph.end()) { |
|
|
|
parent_graph.push_back(front_backend_graph_map_[func_graph]->graph_id()); |
|
|
|
parent_graph.push_back(front_backend_graph_map_[func_graph.get()]->graph_id()); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
|