|
|
|
@@ -852,6 +852,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
graph->set_summary_node_exist(true); |
|
|
|
} |
|
|
|
opt::BackendCommonOptimization(graph); |
|
|
|
graph->SetInputNodes(); |
|
|
|
graph->SetOptimizerFlag(); |
|
|
|
return graph; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -971,11 +973,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
if (kernel_graph->input_ctrl_tensors()) { |
|
|
|
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> input_nodes; |
|
|
|
for (const auto &input_node : kernel_graph->inputs()) { |
|
|
|
auto params = AnfAlgo::GetAllOutput(input_node); |
|
|
|
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); |
|
|
|
} |
|
|
|
auto &input_nodes = kernel_graph->input_nodes(); |
|
|
|
|
|
|
|
if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() |
|
|
|
<< ", input_ctrl_size:" << input_ctrl_size; |
|
|
|
@@ -1026,19 +1025,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<tensor::TensorPtr> SessionBasic::GetNeedLockInputTensors(const GraphId &graph_id, |
|
|
|
std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id, |
|
|
|
const std::vector<tensor::TensorPtr> &inputs) { |
|
|
|
auto graph = GetGraph(graph_id); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
bool has_optimizer = false; |
|
|
|
for (const auto &cnode : graph->execution_order()) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) { |
|
|
|
has_optimizer = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!has_optimizer) { |
|
|
|
if (!graph->has_optimizer()) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
std::vector<tensor::TensorPtr> result; |
|
|
|
@@ -1339,6 +1330,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf |
|
|
|
graph->set_execution_order(exe_order); |
|
|
|
// set output |
|
|
|
CreateOutputNode(cnode, graph); |
|
|
|
graph->SetInputNodes(); |
|
|
|
return graph; |
|
|
|
} |
|
|
|
|
|
|
|
|