|
|
|
@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy; |
|
|
|
|
|
|
|
const char SINGLE_OP_GRAPH[] = "single_op_graph"; |
|
|
|
// primitive unable to infer value for constant input in PyNative mode |
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; |
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"}; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace pynative { |
|
|
|
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } |
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { |
|
|
|
auto cell_id = GetId(cell); |
|
|
|
if (cell_graph_map_.count(cell_id) != 0) { |
|
|
|
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { |
|
|
|
resource_ = cell_resource_map_[cell_id]; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Newgraph already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
|
|
|
|
if (top_g_ == nullptr) { |
|
|
|
top_g_ = curr_g_ = g; |
|
|
|
resource_ = std::make_shared<pipeline::Resource>(); |
|
|
|
cell_resource_map_[cell_id] = resource_; |
|
|
|
df_builder_ = std::make_shared<FuncGraph>(); |
|
|
|
MS_LOG(DEBUG) << "First new graph" << top_g_.get(); |
|
|
|
Pushp(); |
|
|
|
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
MS_LOG(DEBUG) << "Clear res"; |
|
|
|
(void)graph_map_.erase(flag); |
|
|
|
(void)cell_graph_map_.erase(flag); |
|
|
|
(void)cell_resource_map_.erase(flag); |
|
|
|
Clean(); |
|
|
|
// Maybe exit in the pynative runing op, so need reset pynative flag. |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Clear"; |
|
|
|
top_g_ = nullptr; |
|
|
|
df_builder_ = nullptr; |
|
|
|
curr_g_ = nullptr; |
|
|
|
graph_info_map_.clear(); |
|
|
|
op_id_map_.clear(); |
|
|
|
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() { |
|
|
|
Clear(); |
|
|
|
grad_flag_ = false; |
|
|
|
op_forward_map_.clear(); |
|
|
|
df_builder_ = nullptr; |
|
|
|
ad::CleanRes(); |
|
|
|
pipeline::ReclaimOptimizer(); |
|
|
|
} |
|
|
|
|