Browse Source

session code review

tags/v0.6.0-beta
Margaret_wangrui 5 years ago
parent
commit
9c784a6c74
1 changed files with 14 additions and 16 deletions
  1. +14
    -16
      mindspore/ccsrc/session/session_basic.cc

+ 14
- 16
mindspore/ccsrc/session/session_basic.cc View File

@@ -38,8 +38,8 @@


namespace mindspore { namespace mindspore {
namespace session { namespace session {
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras_;
void ClearPythonParasMap() { python_paras_ = nullptr; }
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; }
namespace { namespace {
const int kSummaryGetItem = 2; const int kSummaryGetItem = 2;


@@ -387,17 +387,17 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
MS_EXCEPTION_IF_NULL(graph_inputs); MS_EXCEPTION_IF_NULL(graph_inputs);
ParameterPtr new_parameter = nullptr; ParameterPtr new_parameter = nullptr;
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
} }
auto iter = python_paras_->find(param_value);
if (iter != python_paras_->end()) {
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {
new_parameter = iter->second; new_parameter = iter->second;
} else { } else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (param_value != nullptr) { if (param_value != nullptr) {
(*python_paras_)[param_value] = new_parameter;
(*python_paras)[param_value] = new_parameter;
} }
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
@@ -469,7 +469,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
cnode_inputs.emplace_back(new_value_node); cnode_inputs.emplace_back(new_value_node);
} }
continue; continue;
} else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
cnode_inputs.push_back(new_parameter); cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) { if (GetGraphIdByNode(anf) == kInvalidGraphId) {
@@ -481,15 +481,13 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) { } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
continue; continue;
} else if (anf->isa<AnfNode>()) {
} else {
*from_other_graph = true; *from_other_graph = true;
// the input node is a cnode from other graph // the input node is a cnode from other graph
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
cnode_inputs.push_back(parameter_from_cnode); cnode_inputs.push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode; (*other_graph_cnode)[anf] = parameter_from_cnode;
continue;
} }
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
} }
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs); auto new_cnode = graph->NewCNode(cnode_inputs);
@@ -660,17 +658,17 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph


auto param_value = GetParamDefaultValue(anf); auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr; ParameterPtr new_parameter = nullptr;
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
} }
auto iter = python_paras_->find(param_value);
if (iter != python_paras_->end()) {
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {
new_parameter = iter->second; new_parameter = iter->second;
} else { } else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (param_value != nullptr) { if (param_value != nullptr) {
(*python_paras_)[param_value] = new_parameter;
(*python_paras)[param_value] = new_parameter;
} }
TraceManager::EndTrace(); TraceManager::EndTrace();
} }


Loading…
Cancel
Save