|
|
|
@@ -34,8 +34,30 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
static std::shared_ptr<std::map<tensor::TensorPtr, ParameterPtr>> python_paras_; |
|
|
|
void ClearPythonParasMap() { python_paras_ = nullptr; } |
|
|
|
namespace { |
|
|
|
const int kSummaryGetItem = 2; |
|
|
|
|
|
|
|
tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { |
|
|
|
if (node == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto parameter = node->cast<ParameterPtr>(); |
|
|
|
if (parameter == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto py_param = parameter->default_param(); |
|
|
|
if (!py::hasattr(py_param, "default_input")) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto py_p_input = py_param.attr("default_input"); |
|
|
|
if (!py::hasattr(py_p_input, PYTHON_TENSOR_FLAG)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return py_p_input.cast<std::shared_ptr<tensor::Tensor>>(); |
|
|
|
} |
|
|
|
|
|
|
|
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { |
|
|
|
MS_LOG(DEBUG) << "Update summary Start"; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
@@ -195,21 +217,6 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
return new_value_node; |
|
|
|
} |
|
|
|
|
|
|
|
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
} |
|
|
|
auto graph_inputs = graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
auto valid_inputs = graph->MutableValidInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs); |
|
|
|
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
valid_inputs->push_back(valid_input); |
|
|
|
return new_parameter; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
@@ -358,6 +365,35 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { |
|
|
|
} // namespace |
|
|
|
|
|
|
|
GraphId SessionBasic::graph_sum_ = 0; |
|
|
|
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, |
|
|
|
KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
} |
|
|
|
|
|
|
|
auto m_tensor = GetParamDefaultInputTensor(anf); |
|
|
|
auto valid_inputs = graph->MutableValidInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs); |
|
|
|
auto graph_inputs = graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_inputs); |
|
|
|
ParameterPtr new_parameter = nullptr; |
|
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
|
if (python_paras_ == nullptr) { |
|
|
|
python_paras_ = std::make_shared<std::map<tensor::TensorPtr, ParameterPtr>>(); |
|
|
|
} |
|
|
|
if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) != kInvalidGraphId) { |
|
|
|
new_parameter = (*python_paras_)[m_tensor]; |
|
|
|
} else { |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
if (m_tensor != nullptr) { |
|
|
|
(*python_paras_)[m_tensor] = new_parameter; |
|
|
|
} |
|
|
|
} |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
valid_inputs->push_back(valid_input); |
|
|
|
return new_parameter; |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, |
|
|
|
bool *from_other_graph, |
|
|
|
@@ -391,7 +427,6 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
} |
|
|
|
continue; |
|
|
|
} else if (anf->isa<Parameter>()) { |
|
|
|
// if anf is a parameter |
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); |
|
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
|
|