/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ #define MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ #include #include #include #include #include #include #include "backend/session/kernel_graph.h" #include "debug/debugger/grpc_client.h" #include "debug/debug_services.h" using debugger::Chunk; using debugger::DataType; using debugger::EventReply; using debugger::GraphProto; using debugger::ModelProto; using debugger::TensorProto; using debugger::WatchCondition; using debugger::WatchNode; using debugger::WatchpointHit; template using ProtoVector = google::protobuf::RepeatedPtrField; namespace mindspore { // different types of command recieved by debugger // need to keep sync with client-side proto and server-side proto enum class DebuggerCommand { kExitCMD = 2, kRunCMD = 3, kSetCMD = 4, kViewCMD = 5, kUnknownCMD = -1 }; class Debugger : public std::enable_shared_from_this { public: static std::shared_ptr GetInstance() { std::lock_guard i_lock(instance_lock_); if (debugger_ == nullptr) { debugger_ = std::shared_ptr(new (std::nothrow) Debugger()); } return debugger_; } // deconstructor ~Debugger() = default; // init // only save device_id void Init(const uint32_t device_id, const std::string device_target); // reset debugger void Reset(); // enable debugger // send graph and wait for command // do nothing if graph is set already void PreExecute(const KernelGraphPtr &graph_ptr); // analyze tensors and wait for command // don't need a graph_ptr because it is saved during pre_execute void PostExecute(); bool ReadNodeDataRequired(); void PostExecuteNode(); // suspend the execution after a debug_op void PostDebugOp(); DebugServices *debug_services() const; bool debugger_enabled() const; bool partial_memory(); void SetCurNode(std::string cur_name); std::string run_level() const; void SetStepNum(int32_t cur_num_step); int32_t step_num() const; void SetStreamTaskToOpnameMap(const std::map, std::string> &mapping); // check if any feature that uses the debugger backend is enabled bool DebuggerBackendEnabled(); void SetTrainingDone(bool training_done); void SendMetadata(); void LoadParametersAndConst(); void UpdateStepNum(); void ClearCurrentData(); void LoadGraphOutputs(); private: // private constructor for singleton Debugger(); // enable debugger // instantiate class members // read env variable for grpc client void EnableDebugger(); // check if dump using debugger backend is enabled bool CheckDebuggerDumpEnabled(); // check if debugger enabled bool CheckDebuggerEnabled(); bool CheckDebuggerPartialMemoryEnabled(); // check and save graph pointer void CheckGraphPtr(const KernelGraphPtr &graph_ptr); // check if the graph is a dataset graph void CheckDatasetGraph(); // serialize graph and get proto GraphProto GetGraphProto() const; // send graph and enter command wait loop void SendGraphAndSuspend(const GraphProto &graph_proto); // wait for command and process command // send command request and process reply in a loop // break if RunCMD void CommandLoop(); // set what nodes and conditions to watch void SetWatchpoint(const ProtoVector &nodes, const WatchCondition &condition, const int32_t id); // remove watchpoint with id void RemoveWatchpoint(const int32_t id); // load tensor for view command std::list LoadTensors(const ProtoVector &tensors) const; // terminate training process void Exit(); // analyze tensors and check watchpoint conditions // return names of tensors and what condition they hit std::list CheckWatchpoints(const std::string &watchnode = std::string()); // send watchpoints that hit and enter command wait loop void SendWatchpointsAndSuspend(const std::list &points); // Find if any operation overflow happened and return their names std::vector CheckOpOverflow(); // Check if the port is valid bool CheckPort(const char *port); void LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index); // class members std::unique_ptr grpc_client_; std::unique_ptr debug_services_; KernelGraphPtr graph_ptr_; uint32_t device_id_; std::string device_target_; int32_t num_step_; bool debugger_enabled_; std::string run_level_; std::string node_name_; std::string cur_name_; bool training_done_; bool is_dataset_graph_; bool partial_memory_; std::mutex access_lock_; std::map, std::string> stream_task_to_opname_; double last_overflow_bin_; std::string overflow_bin_path_; // singleton static std::mutex instance_lock_; static std::shared_ptr debugger_; }; using DebuggerPtr = std::shared_ptr; // get debugger ModelProto std::string GetDebuggerFuncGraphProtoString(const FuncGraphPtr &func_graph); ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph); // for getting proto DataType from Type of Tensor DataType GetDebuggerNumberDataType(const TypePtr &type); // process reply and command type DebuggerCommand GetCommand(const EventReply &reply); // parse other data out of EventReply ProtoVector GetWatchnodes(const EventReply &reply); std::string GetNodeName(const EventReply &reply); std::string GetRunLevel(const EventReply &reply); WatchCondition GetWatchcondition(const EventReply &reply); int32_t GetWatchpointID(const EventReply &reply); bool GetWatchpointDelete(const EventReply &reply); ProtoVector GetTensors(const EventReply &reply); // get the full name of a tensor, which is the name used in TensorLoader std::string GetTensorFullName(const TensorProto &tensor); uint64_t BytestoInt64(const std::vector &buffer); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_