| @@ -17,7 +17,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | |||||
| #include <set> | |||||
| #include "ir/graph_utils.h" | #include "ir/graph_utils.h" | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| @@ -62,6 +62,11 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne | |||||
| AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| const std::set<std::string> no_need_to_convert_nodes = {kStackOpName}; | |||||
| auto node_type = AnfAlgo::GetCNodeName(cnode); | |||||
| if (no_need_to_convert_nodes.find(node_type) != no_need_to_convert_nodes.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> new_inputs; | std::vector<AnfNodePtr> new_inputs; | ||||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | ||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| @@ -119,6 +119,7 @@ constexpr auto kTransDataOpName = "TransData"; | |||||
| constexpr auto kStackInitOpName = "StackInit"; | constexpr auto kStackInitOpName = "StackInit"; | ||||
| constexpr auto kStackPushOpName = "StackPush"; | constexpr auto kStackPushOpName = "StackPush"; | ||||
| constexpr auto kStackPopOpName = "StackPop"; | constexpr auto kStackPopOpName = "StackPop"; | ||||
| constexpr auto kStackOpName = "Stack"; | |||||
| constexpr auto kStackDestroyOpName = "StackDestroy"; | constexpr auto kStackDestroyOpName = "StackDestroy"; | ||||
| constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad"; | constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad"; | ||||
| constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad"; | constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad"; | ||||
| @@ -72,6 +72,7 @@ class ParamInfo { | |||||
| this->be_cloned_ = true; | this->be_cloned_ = true; | ||||
| this->be_cloned_index_.push_back(index); | this->be_cloned_index_.push_back(index); | ||||
| clone->init_in_server_ = this->init_in_server_; | clone->init_in_server_ = this->init_in_server_; | ||||
| clone->ClearParameter(); | |||||
| return clone; | return clone; | ||||
| } | } | ||||
| @@ -88,6 +89,7 @@ class ParamInfo { | |||||
| void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; } | void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; } | ||||
| ParameterPtr parameter() { return parameter_; } | ParameterPtr parameter() { return parameter_; } | ||||
| void set_parameter(const ParameterPtr ¶meter) { parameter_ = parameter; } | void set_parameter(const ParameterPtr ¶meter) { parameter_ = parameter; } | ||||
| void ClearParameter() { parameter_ = nullptr; } | |||||
| private: | private: | ||||
| std::string name_{"Parameter"}; | std::string name_{"Parameter"}; | ||||