Merge pull request !2604 from lichen_101010/lichen_tmptags/v0.6.0-beta
@@ -13,13 +13,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
 #include "backend/optimizer/mem_reuse/mem_reuse.h"
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
 #ifdef ENABLE_D
 #include "runtime/device/ascend/ascend_stream_assign.h"
 #endif
+#ifdef ENABLE_DEBUGGER
+#include "debug/debugger/debugger.h"
+#include "debug/debug_services.h"
+#endif
 namespace mindspore {
 namespace memreuse {
@@ -75,6 +78,15 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   MS_EXCEPTION_IF_NULL(mem_buf);
   auto kernel_prev = mem_buf->used_kernel_;
   MS_EXCEPTION_IF_NULL(kernel_prev);
+#ifdef ENABLE_DEBUGGER
+  auto debugger_ = mindspore::Debugger::GetInstance();
+  DebugServices *debug_services = debugger_->debug_services();
+  auto watchpoint_table = debug_services->GetWatchpointTable();
+  std::string current_kernel_name = kernel_curr->scope_full_name();
+  if (debug_services->IsWatchPoint(current_kernel_name, watchpoint_table)) {
+    return false;
+  }
+#endif
   auto curr_stream_id = kernel_curr->stream_id();
   auto prev_stream_id = kernel_prev->stream_id();
   if (curr_stream_id == prev_stream_id) {
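The new guard short-circuits buffer reuse for any kernel that a watchpoint matches: that kernel gets memory of its own, so its output is still intact when the debugger reads it back after execution. A compilable sketch of the control flow, with hypothetical stand-ins for the real kernel and buffer types:

#include <iostream>
#include <string>

// Stand-in for the watchpoint lookup; the real check goes through
// DebugServices::IsWatchPoint (see the debug_services.cc hunk below).
bool IsWatched(const std::string &kernel) { return kernel == "Default/conv1"; }

// Sketch of the gate added to BestFitMemReuse::IsUsable: a membuf is declared
// unusable for any watched kernel, forcing a fresh allocation for it.
bool IsUsableSketch(const std::string &kernel_curr /*, buffer state ... */) {
  if (IsWatched(kernel_curr)) {
    return false;  // never reuse an existing buffer for a watched kernel
  }
  return true;  // fall through to the usual stream/lifetime checks
}

int main() {
  std::cout << IsUsableSketch("Default/conv1") << "\n";  // 0: fresh allocation forced
  std::cout << IsUsableSketch("Default/relu1") << "\n";  // 1: reuse allowed
}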
@@ -331,6 +331,11 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
   // build kernel
   BuildKernel(root_graph);
+#ifdef ENABLE_DEBUGGER
+  if (debugger_) {
+    debugger_->PreExecute(root_graph);
+  }
+#endif
   // alloc mem
   MemoryAlloc(root_graph.get());
   // task generate
@@ -407,6 +412,11 @@ void AscendSession::BuildGraph(GraphId graph_id) {
   BuildKernel(graph);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
+#ifdef ENABLE_DEBUGGER
+  if (debugger_) {
+    debugger_->PreExecute(graph);
+  }
+#endif
   if (ms_context->precompile_only()) {
     MS_LOG(INFO) << "Precompile only, stop in build kernel step";
   } else {
@@ -475,12 +485,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
   LoadInputData(kernel_graph, inputs);
   // convert inputs to model
   predictmodel::StepConvertWeight(inputs);
-#ifdef ENABLE_DEBUGGER
-  // debugger pre-execution processing
-  if (debugger_) {
-    debugger_->PreExecute(kernel_graph);
-  }
-#endif
   {
     py::gil_scoped_release release;
     // run task on device
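Taken together, the three session hunks relocate the debugger handshake: PreExecute now runs in CompileGraph and BuildGraph, before MemoryAlloc, instead of in RunGraph. The ordering matters because the watchpoint list received in PreExecute is exactly what BestFitMemReuse::IsUsable consults during allocation. A compilable sketch of the resulting order (stand-in types, not the real session API):

struct Graph {};
struct DebuggerStub {
  void PreExecute(Graph *) { /* receive watchpoint list from the client */ }
};
void BuildKernel(Graph *) {}
void MemoryAlloc(Graph *) { /* reuse pass may now exclude watched kernels */ }

void CompileGraphSketch(Graph *g, DebuggerStub *debugger) {
  BuildKernel(g);
  if (debugger != nullptr) {
    debugger->PreExecute(g);  // moved here from RunGraph
  }
  MemoryAlloc(g);  // runs with the watchpoint table already populated
}

int main() {
  Graph g;
  DebuggerStub d;
  CompileGraphSketch(&g, &d);
}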
@@ -791,7 +795,8 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
   DebugServices *debug_services = debugger_->debug_services();
-  TensorLoader *tensor_loader = debug_services->get_tensor_loader();
+  TensorLoader *tensor_loader = debug_services->tensor_loader();
+  // TensorData will be freed up here
   tensor_loader->EmptyTensor();
   uint32_t iter_num = tensor_loader->GetIterNum();
   tensor_loader->set_iter_num(++iter_num);
@@ -37,8 +37,8 @@ DebugServices &DebugServices::operator=(const DebugServices &other) {
 DebugServices::~DebugServices() { delete tensor_loader_; }
-void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition,
-                                   const std::vector<std::tuple<std::string, bool>> &check_node_list) {
+void DebugServices::AddWatchpoint(unsigned int id, unsigned int watch_condition,
+                                  const std::vector<std::tuple<std::string, bool>> &check_node_list) {
   std::lock_guard<std::mutex> lg(lock_);
   watchpoint_t watchpoint_item;
@@ -57,14 +57,14 @@ void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition
   watchpoint_table[id] = watchpoint_item;
 }
-void DebugServices::remove_watchpoint(unsigned int id) {
+void DebugServices::RemoveWatchpoint(unsigned int id) {
   std::lock_guard<std::mutex> lg(lock_);
   watchpoint_table.erase(id);
 }
-void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot,
-                                      std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
-                                      std::vector<int> *condition, std::vector<unsigned int> *wacthpoint_id) {
+void DebugServices::CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot,
+                                     std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
+                                     std::vector<int> *condition, std::vector<unsigned int> *wacthpoint_id) {
   std::lock_guard<std::mutex> lg(lock_);
   std::vector<std::shared_ptr<TensorData>> tensor_list = tensor_loader_->GetTensor();
@@ -171,9 +171,9 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto
   }
 }
-void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
-                                       std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
-                                       std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape) {
+void DebugServices::ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
+                                     std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
+                                     std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape) {
   std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
   tensor_loader_->SearchTensors(name, &result_list);
@@ -189,6 +189,28 @@ void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vecto
   }
 }
-TensorLoader *DebugServices::get_tensor_loader() const { return tensor_loader_; }
+bool DebugServices::IsWatchPoint(std::string kernel_name,
+                                 std::unordered_map<unsigned int, watchpoint_t> watchpoint_table) {
+  bool ret = false;
+  for (auto w_table_item : watchpoint_table) {
+    auto check_node_list = std::get<1>(w_table_item).check_node_list;
+    for (auto check_node : check_node_list) {
+      std::string w_name = std::get<0>(check_node);
+      bool w_type = std::get<1>(check_node);
+      if ((w_type == true &&
+           ((kernel_name.find(w_name) != string::npos && kernel_name.rfind(w_name, 0) == 0) || w_name == "*")) ||
+          (w_type == false && kernel_name == w_name)) {
+        ret = true;
+        return ret;
+      }
+    }
+  }
+  return ret;
+}
+TensorLoader *DebugServices::tensor_loader() const { return tensor_loader_; }
+std::unordered_map<unsigned int, DebugServices::watchpoint_t> DebugServices::GetWatchpointTable() {
+  return watchpoint_table;
+}
 } // namespace mindspore
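IsWatchPoint applies two rules: a node registered with w_type == true is a scope watch and hits when its name is a prefix of the kernel's full scope name (or is the wildcard "*"); a node with w_type == false must match the kernel name exactly. Two details a reviewer might flag: kernel_name and the whole table are passed by value on every call, and the find() != npos test is already implied by the rfind(w_name, 0) == 0 prefix test. A standalone sketch of just the matching rule (sample names are made up):

#include <iostream>
#include <string>

// Mirror of the predicate inside DebugServices::IsWatchPoint, reduced to one
// watch node. The redundant find() check is dropped here.
bool MatchesWatchNode(const std::string &kernel_name, const std::string &w_name, bool w_type) {
  if (w_type) {
    // scope watch: wildcard, or w_name is a prefix of the kernel's scope name
    return w_name == "*" || kernel_name.rfind(w_name, 0) == 0;
  }
  return kernel_name == w_name;  // leaf watch: exact match only
}

int main() {
  std::cout << MatchesWatchNode("Default/network/conv1", "Default/network", true) << "\n";         // 1: prefix
  std::cout << MatchesWatchNode("Default/network/conv1", "*", true) << "\n";                       // 1: wildcard
  std::cout << MatchesWatchNode("Default/network/conv1", "Default/network/conv1", false) << "\n";  // 1: exact
  std::cout << MatchesWatchNode("Default/network/conv1", "Default/other", true) << "\n";           // 0: no match
}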
@@ -37,22 +37,6 @@ class DebugServices {
   ~DebugServices();
-  void add_watchpoint(unsigned int id, unsigned int watch_condition,
-                      const std::vector<std::tuple<std::string, bool>> &check_node_list);
-  void remove_watchpoint(unsigned int id);
-  void check_watchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<char *> *data_ptr,
-                         std::vector<unsigned int> *data_size, std::vector<int> *condition,
-                         std::vector<unsigned int> *wacthpoint_id);
-  void read_nodes_tensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
-                          std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
-                          std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape);
-  TensorLoader *get_tensor_loader() const;
-
- private:
   typedef struct condition_no_param {
     bool enabled = false;
   } condition_no_param_t;
@@ -84,6 +68,26 @@ class DebugServices {
     std::vector<std::tuple<std::string, bool>> check_node_list;
   } watchpoint_t;
+  void AddWatchpoint(unsigned int id, unsigned int watch_condition,
+                     const std::vector<std::tuple<std::string, bool>> &check_node_list);
+  void RemoveWatchpoint(unsigned int id);
+  void CheckWatchpoints(std::vector<std::string> *name, std::vector<std::string> *slot, std::vector<char *> *data_ptr,
+                        std::vector<unsigned int> *data_size, std::vector<int> *condition,
+                        std::vector<unsigned int> *wacthpoint_id);
+  void ReadNodesTensors(std::vector<std::string> name, std::vector<std::string> *ret_name,
+                        std::vector<char *> *data_ptr, std::vector<unsigned int> *data_size,
+                        std::vector<TypePtr> *dtype, std::vector<std::vector<int>> *shape);
+  bool IsWatchPoint(std::string kernel_name, std::unordered_map<unsigned int, watchpoint_t> watchpoint_table);
+  TensorLoader *tensor_loader() const;
+  std::unordered_map<unsigned int, watchpoint_t> GetWatchpointTable();
+
+ private:
   std::mutex lock_;
   std::unordered_map<unsigned int, watchpoint_t> watchpoint_table;
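The declarations move below the watchpoint_t definition because IsWatchPoint and GetWatchpointTable name that type, which now also becomes visible to callers. A hedged usage sketch of the renamed API; it assumes the include path shown in the diff, a default-constructible DebugServices, and uses 0 as a placeholder watch condition:

#include <string>
#include <tuple>
#include <vector>
#include "debug/debug_services.h"

void Example() {
  mindspore::DebugServices ds;
  // register watchpoint 1 as a scope watch on "Default/network"
  std::vector<std::tuple<std::string, bool>> nodes = {std::make_tuple(std::string("Default/network"), true)};
  ds.AddWatchpoint(1, 0, nodes);
  // a kernel inside that scope now counts as watched
  bool hit = ds.IsWatchPoint("Default/network/conv1", ds.GetWatchpointTable());
  (void)hit;
  mindspore::TensorLoader *loader = ds.tensor_loader();  // accessor-style getter
  (void)loader;
  ds.RemoveWatchpoint(1);
}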
@@ -43,7 +43,8 @@ Debugger::Debugger()
       device_id_(0),
       num_step_(0),
       debugger_enabled_(false),
-      is_dataset_graph_(false) {}
+      is_dataset_graph_(false),
+      partial_memory_(false) {}

 void Debugger::Init(const uint32_t device_id) {
   // access lock for public method
@@ -57,6 +58,7 @@ void Debugger::EnableDebugger() {
   // reset some of the class members
   num_step_ = 0;
   debugger_enabled_ = false;
+  partial_memory_ = false;
   grpc_client_ = nullptr;
   debug_services_ = nullptr;
@@ -72,7 +74,8 @@ void Debugger::EnableDebugger() {
     MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
     return;
   }
-  // configure host
+  // configure grpc host
   const char *env_host_str = std::getenv("MS_DEBUGGER_HOST");
   std::string host;
   if (env_host_str != nullptr) {
@@ -82,7 +85,7 @@ void Debugger::EnableDebugger() {
     MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
     host = "localhost";
   }
-  // configure port
+  // configure grpc port
   const char *env_port_str = std::getenv("MS_DEBUGGER_PORT");
   std::string port;
   if (env_port_str != nullptr) {
@@ -93,6 +96,27 @@ void Debugger::EnableDebugger() {
     port = "50051";
   }
+  // configure partial memory reuse
+  const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM");
+  if (env_partial_mem_str != nullptr) {
+    MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
+    if (std::strcmp(env_partial_mem_str, "1") == 0) {
+      partial_memory_ = true;
+    }
+  }
+  // switch memory reuse on or off
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  context_ptr->set_enable_mem_reuse(partial_memory_);
+  // print some message about memory reuse to user
+  if (partial_memory_) {
+    MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
+                       "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
+  } else {
+    MS_LOG(WARNING) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
+                       "usage for large models.";
+  }
   // initialize grpc client
   grpc_client_ = std::make_unique<GrpcClient>(host, port);
   debug_services_ = std::make_unique<DebugServices>();
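So with the debugger enabled, memory reuse is now off unless MS_DEBUGGER_PARTIAL_MEM=1 opts back in, in which case reuse applies only to unwatched nodes (hence "partial"). The flag-parsing pattern above in isolation, as a runnable sketch with a hypothetical helper name:

#include <cstdlib>
#include <cstring>

// A feature is enabled only when its environment variable exists and is
// exactly "1"; any other value (or absence) leaves it off.
bool EnvFlagEnabled(const char *name) {
  const char *value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

int main() {
  bool partial_memory = EnvFlagEnabled("MS_DEBUGGER_PARTIAL_MEM");
  return partial_memory ? 0 : 1;
}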
@@ -106,6 +130,7 @@ void Debugger::Reset() {
   num_step_ = 0;
   debugger_enabled_ = false;
   is_dataset_graph_ = false;
+  partial_memory_ = false;
   graph_ptr_ = nullptr;
   grpc_client_ = nullptr;
   debug_services_ = nullptr;
@@ -317,11 +342,10 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
                  [](WatchNode node) -> std::tuple<std::string, bool> {
                    return make_tuple(node.node_name(), node.node_type() == "scope");
                  });
-  debug_services_->add_watchpoint(id, condition.condition(), check_node_list);
+  debug_services_->AddWatchpoint(id, condition.condition(), check_node_list);
 }

-void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); }
+void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }

 std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
   std::vector<std::string> name;
@@ -335,7 +359,7 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
   // ret_name will contain tensor names that are found in TensorLoader
   // items in ret_name will be in the same order with tensors if found
-  debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
+  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
   std::list<TensorProto> tensor_list;
   unsigned int result_index = 0;
@@ -384,8 +408,7 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
   std::vector<int> condition;
   std::vector<unsigned int> watchpoint_id;
-  debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
+  debug_services_->CheckWatchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
   std::list<WatchpointHit> hits;
   for (unsigned int i = 0; i < name.size(); i++) {
     WatchpointHit hit;
@@ -494,4 +517,6 @@ std::string GetTensorFullName(const TensorProto &tensor) {
   return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
 }

+bool Debugger::partial_memory() { return partial_memory_; }
+
 } // namespace mindspore
@@ -76,6 +76,8 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   bool debugger_enabled() const;

+  bool partial_memory();
+
  private:
   // private constructor for singleton
   Debugger();
@@ -129,6 +131,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   int32_t num_step_;
   bool debugger_enabled_;
   bool is_dataset_graph_;
+  bool partial_memory_;
   std::mutex access_lock_;
   // singleton
@@ -51,25 +51,13 @@ class TensorData {
   int GetExecutionOrder() { return this->execution_order; }
-  int SetExecutionOrder(int execution_order) {
-    this->execution_order = execution_order;
-    return true;
-  }
+  void SetExecutionOrder(int execution_order) { this->execution_order = execution_order; }
-  int SetName(const std::string &name) {
-    this->name = name;
-    return true;
-  }
+  void SetName(const std::string &name) { this->name = name; }
-  bool SetTensor(mindspore::tensor::TensorPtr out_tensor) {
-    this->tensor_ptr = out_tensor;
-    return true;
-  }
+  void SetTensor(mindspore::tensor::TensorPtr out_tensor) { this->tensor_ptr = out_tensor; }
-  bool SetSlot(size_t slot) {
-    this->slot = slot;
-    return true;
-  }
+  void SetSlot(size_t slot) { this->slot = slot; }
 };
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_
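These setters always succeeded, so their int/bool results were dead values every caller ignored; returning void documents that. A minimal stand-in (hypothetical class, not the real TensorData):

#include <string>

class TensorDataSketch {
 public:
  // no return value to check or to ignore
  void SetName(const std::string &name) { name_ = name; }
  const std::string &GetName() const { return name_; }

 private:
  std::string name_;
};

int main() {
  TensorDataSketch t;
  t.SetName("Default/conv1:0");
  return t.GetName().empty() ? 1 : 0;
}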
@@ -19,6 +19,7 @@
 #include <memory>
 #include <vector>
 #include <map>
+#include <mutex>
 #include <tuple>
 #include <string>
 #include <utility>
@@ -28,9 +29,10 @@ class TensorLoader {
  public:
   TensorLoader() : iter_num(-1) {}
-  ~TensorLoader() {}
+  ~TensorLoader() { EmptyTensor(); }
   bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
+    std::lock_guard<std::mutex> lg(lock_);
     if (keep_prev) {
       // add prev step tensor into current step map with ":prev" suffix
       auto handle = prev_tensor_list_map.extract(tensor->GetName());
@@ -61,11 +63,11 @@ class TensorLoader {
     }
   }
-  bool EmptyTensor() {
+  void EmptyTensor() {
+    std::lock_guard<std::mutex> lg(lock_);
     prev_tensor_list_map.clear();
     tensor_list_map.swap(prev_tensor_list_map);
     tensor_list.clear();
-    return true;
   }
   void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
@@ -77,6 +79,7 @@ class TensorLoader {
   std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
   std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map;
   uint32_t iter_num;
+  std::mutex lock_;
 };
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
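With the destructor now calling EmptyTensor() and LoadNewTensor() reachable while tensors are being cleared, both paths take the new lock_, serializing all mutation of the shared maps. The locking pattern reduced to a runnable sketch:

#include <map>
#include <memory>
#include <mutex>
#include <string>

// Every method that mutates the shared map takes the same mutex, so Empty()
// (now also invoked from the destructor) cannot race Load().
class LoaderSketch {
 public:
  void Load(const std::string &name, std::shared_ptr<int> t) {
    std::lock_guard<std::mutex> lg(lock_);
    tensors_[name] = std::move(t);
  }
  void Empty() {
    std::lock_guard<std::mutex> lg(lock_);
    tensors_.clear();
  }
  ~LoaderSketch() { Empty(); }

 private:
  std::map<std::string, std::shared_ptr<int>> tensors_;
  std::mutex lock_;
};

int main() {
  LoaderSketch loader;
  loader.Load("Default/conv1:0", std::make_shared<int>(42));
  loader.Empty();
}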
@@ -372,10 +372,13 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
                                         const std::string &host_fmt, const std::vector<int> &host_shape,
                                         TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const {
   bool ret = false;
   DebugServices *debug_services = debugger->debug_services();
-  TensorLoader *tensor_loader = debug_services->get_tensor_loader();
+  TensorLoader *tensor_loader = debug_services->tensor_loader();
+  // TensorData is freed up in AscendSession class
+  auto tensor_data = std::make_shared<mindspore::TensorData>();
+  tensor_data->SetName(tensor_name);
+  tensor_data->SetExecutionOrder(execution_order);
+  tensor_data->SetSlot(slot);
   if (trans_flag) {
     MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
     mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
@@ -385,28 +388,18 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
       MS_LOG(ERROR) << "Copy device mem to host failed";
       return ret;
     }
-    auto tensor_data = std::make_shared<mindspore::TensorData>();
-    tensor_data->SetName(tensor_name);
-    tensor_data->SetExecutionOrder(execution_order);
     tensor_data->SetTensor(out_tensor);
-    tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
   } else {
     mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
     size_t host_size = out_tensor->data().nbytes();
     auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto tensor_data = std::make_shared<mindspore::TensorData>();
-    tensor_data->SetName(tensor_name);
-    tensor_data->SetExecutionOrder(execution_order);
-    tensor_data->SetTensor(out_tensor);
-    tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
     if (ret_rt_memcpy != RT_ERROR_NONE) {
       MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]";
     }
     MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
+    tensor_data->SetTensor(out_tensor);
   }
+  ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
   return ret;
 }
 #endif
@@ -311,15 +311,24 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
 namespace {
 void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
   MS_EXCEPTION_IF_NULL(graph);
+  // trans_flag: "true" means tensor values will be transferred to host format, otherwise not.
   bool trans_flag = false;
   const auto &apply_kernels = graph->execution_order();
   // for kernels, execution order starts from 1
   int exec_order = 1;
+  auto debugger_ = mindspore::Debugger::GetInstance();
+  DebugServices *debug_services = debugger_->debug_services();
+  auto watchpoint_table = debug_services->GetWatchpointTable();
   for (const auto &node : apply_kernels) {
     MS_EXCEPTION_IF_NULL(node);
     auto node_name = AnfAlgo::GetCNodeName(node);
     std::string kernel_name = node->fullname_with_scope();
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
+    if (debugger_->partial_memory()) {
+      if (!debug_services->IsWatchPoint(kernel_name, watchpoint_table)) {
+        continue;
+      }
+    }
     for (size_t j = 0; j < output_size; ++j) {
       auto addr = AnfAlgo::GetOutputAddr(node, j);
       auto type = AnfAlgo::GetOutputInferDataType(node, j);
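Under partial memory reuse, an unwatched kernel's output buffer may already have been handed to a later kernel by the time LoadOutput runs, so the loop now skips such kernels up front; this is also why the enable-message earlier warns that tensor values are only available for watched nodes. A runnable sketch of the gating (names made up):

#include <iostream>
#include <string>
#include <vector>

int main() {
  bool partial_memory = true;
  std::vector<std::string> kernels = {"Default/conv1", "Default/relu1"};
  auto is_watched = [](const std::string &k) { return k == "Default/conv1"; };
  for (const auto &k : kernels) {
    if (partial_memory && !is_watched(k)) {
      continue;  // buffer not preserved; nothing trustworthy to load
    }
    std::cout << "load outputs of " << k << "\n";  // only watched kernels reach here
  }
}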
@@ -347,6 +356,7 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
 void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) {
   MS_EXCEPTION_IF_NULL(graph);
+  // trans_flag: "true" means tensor values will be transferred to host format, otherwise not.
   bool trans_flag = false;
   const auto &parameters = graph->inputs();
   // for parameters, set its execution order to be 0;