|
|
|
@@ -32,15 +32,16 @@ |
|
|
|
#include "pre_activate/common/helper.h" |
|
|
|
#include "common/utils.h" |
|
|
|
#include "ir/dtype.h" |
|
|
|
#include "ir/anf.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
static std::shared_ptr<std::map<tensor::TensorPtr, ParameterPtr>> python_paras_; |
|
|
|
static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_; |
|
|
|
void ClearPythonParasMap() { python_paras_ = nullptr; } |
|
|
|
namespace { |
|
|
|
const int kSummaryGetItem = 2; |
|
|
|
|
|
|
|
tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { |
|
|
|
PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) { |
|
|
|
if (node == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -50,14 +51,7 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); |
|
|
|
auto py_param = param_value->value(); |
|
|
|
if (!py::hasattr(py_param, "default_input")) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto py_p_input = py_param.attr("default_input"); |
|
|
|
if (!py::hasattr(py_p_input, PYTHON_TENSOR_FLAG)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return py_p_input.cast<std::shared_ptr<tensor::Tensor>>(); |
|
|
|
return py_param.ptr(); |
|
|
|
} |
|
|
|
|
|
|
|
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { |
|
|
|
@@ -375,15 +369,17 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
|
ParameterPtr new_parameter = nullptr; |
|
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
|
if (python_paras_ == nullptr) { |
|
|
|
python_paras_ = std::make_shared<std::map<tensor::TensorPtr, ParameterPtr>>(); |
|
|
|
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>(); |
|
|
|
} |
|
|
|
if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) != kInvalidGraphId) { |
|
|
|
if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
|
new_parameter = (*python_paras_)[m_tensor]; |
|
|
|
} else { |
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
if (m_tensor != nullptr) { |
|
|
|
(*python_paras_)[m_tensor] = new_parameter; |
|
|
|
} |
|
|
|
TraceManager::EndTrace(); |
|
|
|
} |
|
|
|
graph_inputs->push_back(new_parameter); |
|
|
|
valid_inputs->push_back(valid_input); |
|
|
|
|