Browse Source

clear parameter when param_info clone

pull/15809/head
jjfeing 4 years ago
parent
commit
88c92cd263
3 changed files with 9 additions and 1 deletions
  1. +6
    -1
      mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc
  2. +1
    -0
      mindspore/ccsrc/utils/utils.h
  3. +2
    -0
      mindspore/core/ir/param_info.h

+ 6
- 1
mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc View File

@@ -17,7 +17,7 @@


#include <vector> #include <vector>
#include <memory> #include <memory>
#include <utility>
#include <set>


#include "ir/graph_utils.h" #include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
@@ -62,6 +62,11 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
const std::set<std::string> no_need_to_convert_nodes = {kStackOpName};
auto node_type = AnfAlgo::GetCNodeName(cnode);
if (no_need_to_convert_nodes.find(node_type) != no_need_to_convert_nodes.end()) {
return nullptr;
}
std::vector<AnfNodePtr> new_inputs; std::vector<AnfNodePtr> new_inputs;
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
auto inputs = cnode->inputs(); auto inputs = cnode->inputs();


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -119,6 +119,7 @@ constexpr auto kTransDataOpName = "TransData";
constexpr auto kStackInitOpName = "StackInit"; constexpr auto kStackInitOpName = "StackInit";
constexpr auto kStackPushOpName = "StackPush"; constexpr auto kStackPushOpName = "StackPush";
constexpr auto kStackPopOpName = "StackPop"; constexpr auto kStackPopOpName = "StackPop";
constexpr auto kStackOpName = "Stack";
constexpr auto kStackDestroyOpName = "StackDestroy"; constexpr auto kStackDestroyOpName = "StackDestroy";
constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad"; constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad";
constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad"; constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad";


+ 2
- 0
mindspore/core/ir/param_info.h View File

@@ -72,6 +72,7 @@ class ParamInfo {
this->be_cloned_ = true; this->be_cloned_ = true;
this->be_cloned_index_.push_back(index); this->be_cloned_index_.push_back(index);
clone->init_in_server_ = this->init_in_server_; clone->init_in_server_ = this->init_in_server_;
clone->ClearParameter();
return clone; return clone;
} }


@@ -88,6 +89,7 @@ class ParamInfo {
void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; } void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }
ParameterPtr parameter() { return parameter_; } ParameterPtr parameter() { return parameter_; }
void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; } void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
void ClearParameter() { parameter_ = nullptr; }


private: private:
std::string name_{"Parameter"}; std::string name_{"Parameter"};


Loading…
Cancel
Save