|
|
|
@@ -17,7 +17,7 @@ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include <memory> |
|
|
|
#include <utility> |
|
|
|
#include <set> |
|
|
|
|
|
|
|
#include "ir/graph_utils.h" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
@@ -62,6 +62,11 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne |
|
|
|
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const std::set<std::string> no_need_to_convert_nodes = {kStackOpName}; |
|
|
|
auto node_type = AnfAlgo::GetCNodeName(cnode); |
|
|
|
if (no_need_to_convert_nodes.find(node_type) != no_need_to_convert_nodes.end()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); |
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
|