| @@ -817,7 +817,7 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g | |||||
| } | } | ||||
| // construct graph include one op | // construct graph include one op | ||||
| auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | |||||
| auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true); | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| opt::RunOpAscendBackendIRFusionOptimization(graph); | opt::RunOpAscendBackendIRFusionOptimization(graph); | ||||
| // kernel select | // kernel select | ||||
| @@ -1569,7 +1569,8 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr | |||||
| std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, | std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) { | |||||
| const std::vector<int64_t> &tensors_mask, | |||||
| bool is_ascend) { | |||||
| auto graph = std::make_shared<KernelGraph>(); | auto graph = std::make_shared<KernelGraph>(); | ||||
| graph->set_graph_id(graph_sum_); | graph->set_graph_id(graph_sum_); | ||||
| graph_sum_++; | graph_sum_++; | ||||
| @@ -1612,7 +1613,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf | |||||
| graph->set_execution_order(exe_order); | graph->set_execution_order(exe_order); | ||||
| graph->UpdateGraphDynamicAttr(); | graph->UpdateGraphDynamicAttr(); | ||||
| // set output | // set output | ||||
| CreateOutputNode(cnode, graph); | |||||
| if (is_ascend) { | |||||
| graph->set_output(cnode); | |||||
| } else { | |||||
| CreateOutputNode(cnode, graph); | |||||
| } | |||||
| graph->SetInputNodes(); | graph->SetInputNodes(); | ||||
| auto manager = MakeManager({graph}); | auto manager = MakeManager({graph}); | ||||
| if (manager != nullptr) { | if (manager != nullptr) { | ||||
| @@ -180,7 +180,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| // create a single run op graph | // create a single run op graph | ||||
| std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask); | |||||
| const std::vector<int64_t> &tensors_mask, bool is_ascend = false); | |||||
| // create a new kernel graph and update the graph sum | // create a new kernel graph and update the graph sum | ||||
| KernelGraphPtr NewKernelGraph(); | KernelGraphPtr NewKernelGraph(); | ||||
| std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); | std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); | ||||