Merge pull request !705 from chujinjin/add_pynative_cachetags/v0.3.0-alpha
| @@ -22,7 +22,7 @@ namespace mindspore { | |||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| const uint64_t kAscendDeviceMemGB = 20; | const uint64_t kAscendDeviceMemGB = 20; | ||||
| const uint64_t kAscendMemPoolGB = 5; | |||||
| const uint64_t kAscendMemPoolGB = 10; | |||||
| const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30); | const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30); | ||||
| const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30); | const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30); | ||||
| @@ -38,6 +38,7 @@ | |||||
| #include "parallel/graph_util/get_parallel_info.h" | #include "parallel/graph_util/get_parallel_info.h" | ||||
| #include "device/kernel_runtime_manager.h" | #include "device/kernel_runtime_manager.h" | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "pynative/pynative_execute.h" | |||||
| #if (ENABLE_GE || ENABLE_D) | #if (ENABLE_GE || ENABLE_D) | ||||
| #include "pipeline/pipeline_ge.h" | #include "pipeline/pipeline_ge.h" | ||||
| @@ -829,6 +830,7 @@ void FinalizeBackend() { | |||||
| void ClearResAtexit() { | void ClearResAtexit() { | ||||
| MS_LOG(DEBUG) << "Pipeline clear all resource"; | MS_LOG(DEBUG) << "Pipeline clear all resource"; | ||||
| pynative::ClearPyNativeSession(); | |||||
| device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | ||||
| ad::g_k_prims.clear(); | ad::g_k_prims.clear(); | ||||
| @@ -44,6 +44,7 @@ const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "ze | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pynative { | namespace pynative { | ||||
| static std::shared_ptr<session::SessionBasic> session = nullptr; | |||||
| inline ValuePtr PyAttrValue(const py::object &obj) { | inline ValuePtr PyAttrValue(const py::object &obj) { | ||||
| ValuePtr converted_ret = nullptr; | ValuePtr converted_ret = nullptr; | ||||
| bool converted = parse::ConvertData(obj, &converted_ret); | bool converted = parse::ConvertData(obj, &converted_ret); | ||||
| @@ -310,7 +311,11 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| if (device_target != kAscendDevice && device_target != kGPUDevice) { | if (device_target != kAscendDevice && device_target != kGPUDevice) { | ||||
| MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; | MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; | ||||
| } | } | ||||
| std::shared_ptr<session::SessionBasic> session = session::SessionFactory::Get().Create(device_target); | |||||
| if (session == nullptr) { | |||||
| session = session::SessionFactory::Get().Create(device_target); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(session); | MS_EXCEPTION_IF_NULL(session); | ||||
| session->Init(ms_context->device_id()); | session->Init(ms_context->device_id()); | ||||
| @@ -407,5 +412,7 @@ py::tuple RunOp(const py::args &args) { | |||||
| MS_LOG(INFO) << "RunOp end"; | MS_LOG(INFO) << "RunOp end"; | ||||
| return result; | return result; | ||||
| } | } | ||||
| // Release the cached PyNative session; RunOpInMs creates a new one when it finds the pointer null. | |||||
| void ClearPyNativeSession() { | |||||
| session.reset(); | |||||
| } | |||||
| } // namespace pynative | } // namespace pynative | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,9 @@ namespace py = pybind11; | |||||
| py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); | py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); | ||||
| py::tuple RunOp(const py::args &args); | py::tuple RunOp(const py::args &args); | ||||
| void ClearPyNativeSession(); | |||||
| } // namespace pynative | } // namespace pynative | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -249,10 +249,23 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| // Returns true when run_op_graphs_ already holds an entry keyed by graph_info. | |||||
| bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { | |||||
| const auto cached_graph = run_op_graphs_.find(graph_info); | |||||
| return cached_graph != run_op_graphs_.end(); | |||||
| } | |||||
| void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<bool> &tensors_mask) { | const std::vector<bool> &tensors_mask) { | ||||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; | MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; | ||||
| if (GraphCacheExist(graph_info)) { | |||||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; | |||||
| return; | |||||
| } | |||||
| // construct graph include one op | // construct graph include one op | ||||
| auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| @@ -267,6 +280,7 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph | |||||
| RunOpAdjustKernel(graph); | RunOpAdjustKernel(graph); | ||||
| BuildKernel(graph); | BuildKernel(graph); | ||||
| run_op_graphs_[graph_info] = graph; | run_op_graphs_[graph_info] = graph; | ||||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; | |||||
| } | } | ||||
| py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| @@ -291,7 +305,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr | |||||
| } | } | ||||
| py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; | py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; | ||||
| py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj); | py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj); | ||||
| run_op_graphs_.clear(); | |||||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; | MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; | ||||
| return tuple_tensors; | return tuple_tensors; | ||||
| } | } | ||||
| @@ -111,6 +111,8 @@ class AscendSession : public SessionBasic { | |||||
| std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id); | std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id); | ||||
| // copy output of if and else | // copy output of if and else | ||||
| void CopyOutputOfIf(GraphId false_graph_id); | void CopyOutputOfIf(GraphId false_graph_id); | ||||
| // check whether a cached graph exists for graph_info | |||||
| bool GraphCacheExist(const GraphInfo &graph_info) const; | |||||
| // member variables | // member variables | ||||
| // key is final_graph_id,value is child graph execute order of final graph | // key is final_graph_id,value is child graph execute order of final graph | ||||
| @@ -125,7 +125,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||||
| // if in pynative mode, data is only copied to host when the user wants to print it | // if in pynative mode, data is only copied to host when the user wants to print it | ||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| if (ms_context->enable_pynative_infer()) { | |||||
| if (ms_context->execution_mode() == kPynativeMode) { | |||||
| tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | ||||
| } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | ||||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | LongToSize(tensor->data().nbytes()), tensor->data_type(), | ||||