Merge pull request !3007 from john_tzanakakis/master_ms1
@@ -279,6 +279,9 @@ checkopts()
   done
 }
 checkopts "$@"
+if [[ "X$ENABLE_GPU" = "Xon" ]] && [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then
+  ENABLE_DEBUGGER="on"
+fi
 echo "---------------- MindSpore: build start ----------------"
 mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
 git submodule update --init graphengine
@@ -37,6 +37,7 @@
 #include "common/trans.h"
 #include "utils/context/ms_context.h"
 #include "utils/base_ref_extends.h"
+#include "debug/tensor_load.h"
 namespace mindspore {
 namespace session {
@@ -164,7 +165,11 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
 void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
+#ifdef ENABLE_DEBUGGER
+  if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) {
+#else
   if (!runtime_instance->Run(kernel_graph.get())) {
+#endif
     MS_LOG(EXCEPTION) << "GPU execute graph failed!";
   }
 }
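Note on the gating above: the `#ifdef` selects which `Run` overload the call site compiles against, so only the condition line differs and the exception body is shared. A minimal standalone sketch of the same pattern (hypothetical names, not the MindSpore API):

```cpp
#include <iostream>

// Hypothetical stand-ins for the two KernelRuntime::Run overloads.
bool Run(int *graph) { return graph != nullptr; }
bool Run(int *graph, void *debugger) { return graph != nullptr && debugger != nullptr; }

int main() {
  int graph = 0;
#ifdef ENABLE_DEBUGGER
  if (!Run(&graph, nullptr)) {  // debugger-aware overload
#else
  if (!Run(&graph)) {           // plain overload
#endif
    std::cout << "execute graph failed\n";
  }
  return 0;
}
```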
@@ -229,6 +234,9 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
 void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   auto &kernel_graph = graphs_[graph_id];
+#ifdef ENABLE_DEBUGGER
+  PreIterationDbg(kernel_graph);
+#endif
   // Load input data from user input
   LoadInputData(kernel_graph, inputs);
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
@@ -245,6 +253,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
     // Run graph on GPU
     Execute(kernel_graph);
   }
+#ifdef ENABLE_DEBUGGER
+  PostLoadTensor(kernel_graph);
+#endif
   // Get result from GPU
   UpdateOutputs(kernel_graph, outputs, inputs);
   // Summary
@@ -253,6 +264,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
   if (context_ptr->enable_gpu_summary()) {
     Summary(kernel_graph.get());
   }
+#ifdef ENABLE_DEBUGGER
+  PostIterationDbg(kernel_graph);
+#endif
 }
 void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
@@ -296,6 +310,70 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
   RunOpClearMemory(kernel_graph.get());
   return tuple_tensors;
 }
+#ifdef ENABLE_DEBUGGER
+void GPUSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+#ifdef ENABLE_DUMP_E2E
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  (void)runtime_instance->DumpData(kernel_graph.get(), debugger_.get());
+#endif
+}
+bool GPUSession::DumpDataEnabledIteration() const {
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  return runtime_instance->DumpDataEnabledIteration();
+}
+void GPUSession::PreIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  if (debugger_) {
+    debugger_->PreExecute(kernel_graph);
+  }
+  PreLoadTensor(kernel_graph);
+}
+void GPUSession::PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  bool dump_enabled = DumpDataEnabledIteration();
+  // debug used for dump
+  if (debugger_ && dump_enabled) {
+    Dump(kernel_graph);
+  }
+  if (debugger_) {
+    debugger_->PostExecute();
+  }
+}
+void GPUSession::PreLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  bool dump_enabled = DumpDataEnabledIteration();
+  if (!(debugger_ && (debugger_->debugger_enabled() || dump_enabled))) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  DebugServices *debug_services = debugger_->debug_services();
+  TensorLoader *tensor_loader = debug_services->tensor_loader();
+  tensor_loader->EmptyTensor();
+  uint32_t iter_num = tensor_loader->GetIterNum();
+  tensor_loader->set_iter_num(++iter_num);
+}
+void GPUSession::PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  bool dump_enabled = DumpDataEnabledIteration();
+  if (!(debugger_ && (debugger_->debugger_enabled() || dump_enabled))) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  DebugServices *debug_services = debugger_->debug_services();
+  TensorLoader *tensor_loader = debug_services->tensor_loader();
+  tensor_loader->EmptyPrevTensor();
+}
+#endif
 }  // namespace gpu
 }  // namespace session
 }  // namespace mindspore
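Taken together, the additions to RunGraph and the helpers above give each iteration a fixed debug bracket: PreIterationDbg (notify the debugger, reset the tensor loader, bump the iteration counter), Execute, PostLoadTensor (drop the previous iteration's tensors), then PostIterationDbg (dump if this iteration is selected, then PostExecute). A compilable mock of that ordering (simplified; the real logic lives on GPUSession):

```cpp
#include <iostream>

// Simplified mock of the per-iteration debug bracket added to GPUSession::RunGraph.
struct MockGpuSession {
  void PreIterationDbg() { std::cout << "PreExecute + EmptyTensor + iter_num++\n"; }
  void Execute() { std::cout << "launch kernels, LoadKernelData per kernel\n"; }
  void PostLoadTensor() { std::cout << "EmptyPrevTensor\n"; }
  void PostIterationDbg() { std::cout << "Dump (if enabled) + PostExecute\n"; }
  void RunGraph() {
    PreIterationDbg();
    Execute();
    PostLoadTensor();
    PostIterationDbg();
  }
};

int main() {
  MockGpuSession{}.RunGraph();
  return 0;
}
```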
@@ -67,6 +67,20 @@ class GPUSession : public SessionBasic {
                  const std::vector<tensor::TensorPtr> &inputs_const) const override;
   void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+#ifdef ENABLE_DEBUGGER
+  void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  bool DumpDataEnabledIteration() const;
+  void PreIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void PreLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+#endif
 };
 using GPUSessionPtr = std::shared_ptr<GPUSession>;
 MS_REG_SESSION(kGPUDevice, GPUSession);
@@ -24,7 +24,6 @@
 #include "backend/kernel_compiler/common_utils.h"
 #include "frontend/operator/ops.h"
 #include "common/trans.h"
-#include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/oplib/oplib.h"
@@ -32,6 +32,7 @@
 #include "utils/contract.h"
 #include "pipeline/pynative/pynative_execute.h"
 #include "runtime/device/kernel_info.h"
+#include "utils/context/ms_context.h"
 #ifdef ENABLE_DEBUGGER
 #include "debug/debugger/debugger.h"
 #endif
@@ -112,7 +113,9 @@ class SessionBasic {
   // set debugger
   void SetDebugger() {
     debugger_ = Debugger::GetInstance();
-    debugger_->Init(device_id_);
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    debugger_->Init(device_id_, ms_context->device_target());
   }
 #endif
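SetDebugger now forwards the device target from MsContext, so the one Debugger singleton can behave differently per backend (the GPU branch added to EnableDebugger below relies on it). A small sketch of the extended two-argument contract (mock type; the real signature is in debugger.h further down):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Mock of the extended Debugger::Init contract: identity is now
// (device_id, device_target) instead of device_id alone.
class MockDebugger {
 public:
  void Init(uint32_t device_id, std::string device_target) {
    device_id_ = device_id;
    device_target_ = std::move(device_target);
  }
  const std::string &device_target() const { return device_target_; }

 private:
  uint32_t device_id_ = 0;
  std::string device_target_;
};

int main() {
  MockDebugger debugger;
  debugger.Init(0, "GPU");  // a session would pass ms_context->device_target()
  std::cout << "debugger target: " << debugger.device_target() << '\n';
  return 0;
}
```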
@@ -16,6 +16,7 @@ if (ENABLE_DEBUGGER)
     "${CMAKE_CURRENT_SOURCE_DIR}/debugger/grpc_client.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/debugger/proto_exporter.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/debug_services.cc"
+    "${CMAKE_CURRENT_SOURCE_DIR}/common.cc"
   )
 endif (ENABLE_DEBUGGER)
@@ -21,6 +21,7 @@
 #include "debug/debugger/debugger.h"
 #include "pipeline/jit/pipeline.h"
 #include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_runtime_manager.h"
 using debugger::EventReply;
 using debugger::GraphProto;
@@ -41,17 +42,20 @@ Debugger::Debugger()
     : grpc_client_(nullptr),
       debug_services_(nullptr),
       device_id_(0),
+      device_target_(""),
       num_step_(0),
       debugger_enabled_(false),
       is_dataset_graph_(false),
       partial_memory_(false) {}
-void Debugger::Init(const uint32_t device_id) {
+void Debugger::Init(const uint32_t device_id, const std::string device_target) {
   // access lock for public method
   std::lock_guard<std::mutex> a_lock(access_lock_);
   // save device_id
   MS_LOG(INFO) << "Debugger got device_id: " << device_id;
   device_id_ = device_id;
+  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
+  device_target_ = device_target;
 }
 void Debugger::EnableDebugger() {
@@ -62,6 +66,14 @@ void Debugger::EnableDebugger() {
   grpc_client_ = nullptr;
   debug_services_ = nullptr;
+  // see if dump is enabled
+  bool dump_enabled = false;
+  if (device_target_ == kGPUDevice) {
+    auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+    MS_EXCEPTION_IF_NULL(runtime_instance);
+    dump_enabled = runtime_instance->DumpDataEnabled();
+  }
   // get env variables to configure debugger
   const char *env_enable_str = std::getenv("ENABLE_MS_DEBUGGER");
   if (env_enable_str != nullptr) {
@@ -70,7 +82,8 @@ void Debugger::EnableDebugger() {
       debugger_enabled_ = true;
     }
   }
-  if (!debugger_enabled_) {
+  if (!debugger_enabled_ && !dump_enabled) {
     MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
     return;
   }
@@ -118,7 +131,10 @@ void Debugger::EnableDebugger() {
   }
   // initialize grpc client
-  grpc_client_ = std::make_unique<GrpcClient>(host, port);
+  if (debugger_enabled_) {
+    grpc_client_ = std::make_unique<GrpcClient>(host, port);
+  }
   debug_services_ = std::make_unique<DebugServices>();
 }
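The net effect of the two changes above is a dump-only mode: when only GPU e2e dump is configured, EnableDebugger still constructs DebugServices (and with it the TensorLoader) but skips the gRPC channel. A reduced sketch of that gating (mock types, hard-coded flags for illustration):

```cpp
#include <iostream>
#include <memory>

// Mock stand-ins for GrpcClient and DebugServices.
struct MockGrpcClient {};
struct MockDebugServices {};

int main() {
  bool debugger_enabled = false;  // ENABLE_MS_DEBUGGER not set
  bool dump_enabled = true;       // GPU e2e dump configured
  if (!debugger_enabled && !dump_enabled) {
    return 0;  // neither feature wants the debug infrastructure
  }
  std::unique_ptr<MockGrpcClient> grpc_client;
  if (debugger_enabled) {
    grpc_client = std::make_unique<MockGrpcClient>();  // only for the interactive debugger
  }
  auto debug_services = std::make_unique<MockDebugServices>();  // always, for the tensor loader
  std::cout << "grpc: " << (grpc_client != nullptr) << ", services: " << (debug_services != nullptr) << '\n';
  return 0;
}
```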
@@ -127,6 +143,7 @@ void Debugger::Reset() {
   std::lock_guard<std::mutex> a_lock(access_lock_);
   // reset components
   device_id_ = 0;
+  device_target_ = "";
   num_step_ = 0;
   debugger_enabled_ = false;
   is_dataset_graph_ = false;
@@ -55,7 +55,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // init
   // only save device_id
-  void Init(const uint32_t device_id);
+  void Init(const uint32_t device_id, const std::string device_target);
   // reset debugger
   void Reset();
@@ -128,6 +128,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   std::unique_ptr<DebugServices> debug_services_;
   KernelGraphPtr graph_ptr_;
   uint32_t device_id_;
+  std::string device_target_;
   int32_t num_step_;
   bool debugger_enabled_;
   bool is_dataset_graph_;
@@ -24,6 +24,10 @@
 #include <string>
 #include <utility>
 #include "debug/tensor_data.h"
+#include "ir/dtype.h"
+#ifdef ENABLE_DUMP_E2E
+#include "debug/e2e_dump.h"
+#endif
 namespace mindspore {
 class TensorLoader {
  public:
@@ -72,8 +76,54 @@ class TensorLoader {
   void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
   void EmptyCurrentTensor() {
     tensor_list_map.clear();
+    tensor_list.clear();
   }
   void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; }
+#ifdef ENABLE_DUMP_E2E
+  bool DumpTensorToFile(std::string tensor_name, bool trans_flag, const std::string &filepath,
+                        const std::string &host_fmt, const std::vector<int> &host_shape, TypeId host_type,
+                        TypeId addr_type_id, std::string addr_format, size_t slot) const {
+    bool ret = false;
+    if (filepath.empty()) {
+      MS_LOG(ERROR) << "Dump file path is null!";
+      return ret;
+    }
+    std::string shape = "shape";
+    if (host_shape.size()) {
+      for (auto &value : host_shape) {
+        shape = shape + '_' + std::to_string(value);
+      }
+    } else {
+      shape = shape + "_0";
+    }
+    std::string file_extension = ".bin";
+    std::string path = "";
+    if (trans_flag) {
+      path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension;
+    } else {
+      path = filepath + '_' + shape + '_' + TypeIdToType(addr_type_id)->ToString() + '_' + addr_format + file_extension;
+    }
+    MS_LOG(INFO) << "Dump path is " << path;
+    std::string tensor_loader_name = tensor_name + ":" + std::to_string(slot);
+    auto iter = tensor_list_map.find(tensor_loader_name);
+    if (iter != tensor_list_map.end()) {
+      std::shared_ptr<TensorData> node = iter->second;
+      mindspore::tensor::TensorPtr out_tensor = node->GetTensor();
+      size_t host_size = out_tensor->data().nbytes();
+      ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size);
+    }
+    return ret;
+  }
+#endif
  private:
   std::vector<std::shared_ptr<TensorData>> tensor_list;
   std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
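For reference, the file name built by DumpTensorToFile encodes shape, data type, and format. A standalone reproduction of the path construction (hypothetical input values; the type and format labels are hard-coded stand-ins for TypeIdLabel and host_fmt):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Reproduces the dump path construction from DumpTensorToFile above.
int main() {
  std::string filepath = "/tmp/net/1/conv1_output_0";  // hypothetical dump_path + kernel name
  std::vector<int> host_shape = {32, 3, 224, 224};
  std::string shape = "shape";
  if (!host_shape.empty()) {
    for (auto &value : host_shape) {
      shape += '_' + std::to_string(value);
    }
  } else {
    shape += "_0";
  }
  std::string path = filepath + '_' + shape + "_Float32_NCHW.bin";
  std::cout << path << '\n';
  // prints: /tmp/net/1/conv1_output_0_shape_32_3_224_224_Float32_NCHW.bin
  return 0;
}
```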
@@ -275,7 +275,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
 }  // namespace
 #endif
-bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
+bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
   MS_EXCEPTION_IF_NULL(graph);
 #ifdef ENABLE_DUMP_E2E
   MS_LOG(INFO) << "Start dump step";
@@ -38,7 +38,7 @@ class AscendKernelRuntime : public KernelRuntime {
   AscendKernelRuntime() = default;
   ~AscendKernelRuntime() override;
   bool Init() override;
-  bool DumpData(session::KernelGraph *graph) override;
+  bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
   bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
   bool GenTask(const session::KernelGraph *graph) override;
   bool RunTask(const session::KernelGraph *graph) override;
@@ -270,7 +270,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput
   resource_manager_.DecreaseSummaryRefCount(summary_outputs);
 }
-bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) {
+bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   resource_manager_.IncreaseAddressRefCount(kernel_graph);
@@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime {
   ~CPUKernelRuntime() override = default;
   bool Init() override { return true; }
-  bool Run(session::KernelGraph *graph) override;
+  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
   void AssignKernelAddress(session::KernelGraph *kernel_graph);
   void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                        VectorRef *outputs, std::vector<tensor::TensorPtr> *need_sync_outputs);
@@ -16,9 +16,16 @@
 #include "runtime/device/gpu/gpu_device_address.h"
 #include <vector>
 #include <memory>
 #include "runtime/device/gpu/gpu_device_manager.h"
 #include "utils/log_adapter.h"
 #include "runtime/device/gpu/gpu_memory_allocator.h"
+#include "ir/tensor.h"
+#ifdef ENABLE_DEBUGGER
+#include "debug/debug_services.h"
+#include "debug/tensor_load.h"
+#include "debug/debugger/debugger.h"
+#endif
 namespace mindspore {
 namespace device {
@@ -59,6 +66,36 @@ GPUDeviceAddress::~GPUDeviceAddress() {
     ptr_ = nullptr;
   }
 }
+#ifdef ENABLE_DEBUGGER
+bool GPUDeviceAddress::LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
+                                     const std::vector<int> &host_shape, TypeId host_type, size_t slot,
+                                     Debugger *debugger, bool keep_prev) const {
+  bool ret = false;
+  if (size_ == 0) {
+    return true;
+  }
+  DebugServices *debug_services = debugger->debug_services();
+  TensorLoader *tensor_loader = debug_services->tensor_loader();
+  mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
+  size_t host_size = out_tensor->data().nbytes();
+  auto ret_rt_memcpy = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c());
+  if (!ret_rt_memcpy) {
+    MS_LOG(ERROR) << "Copy device mem to host failed";
+    return ret;
+  }
+  auto tensor_data = std::make_shared<mindspore::TensorData>();
+  tensor_data->SetName(tensor_name);
+  tensor_data->SetExecutionOrder(execution_order);
+  tensor_data->SetTensor(out_tensor);
+  tensor_data->SetSlot(slot);
+  ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
+  MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
+  return ret;
+}
+#endif
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
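LoadMemToHost is the device-to-host bridge: it allocates a host tensor, syncs device memory into it, and registers the result with the TensorLoader under a "name:slot" key. A reduced mock of that flow (no CUDA involved; memcpy stands in for SyncDeviceToHost and a plain map stands in for the loader):

```cpp
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Mock of the LoadMemToHost flow: copy device bytes into a host buffer,
// then register the buffer in a loader keyed by "tensor_name:slot".
struct MockTensorData {
  std::string name;
  int execution_order = 0;
  std::vector<float> host_data;
};

int main() {
  std::map<std::string, std::shared_ptr<MockTensorData>> loader;
  float device_buf[4] = {1.f, 2.f, 3.f, 4.f};  // pretend this lives on the GPU

  auto data = std::make_shared<MockTensorData>();
  data->name = "conv1:0";  // kernel fullname_with_scope + output slot, as in the diff
  data->execution_order = 1;
  data->host_data.resize(4);
  std::memcpy(data->host_data.data(), device_buf, sizeof(device_buf));  // SyncDeviceToHost stand-in

  loader[data->name] = data;  // LoadNewTensor stand-in
  std::cout << "loaded " << loader.size() << " tensor(s), first value " << data->host_data[0] << '\n';
  return 0;
}
```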
@@ -22,6 +22,9 @@
 #include "runtime/device/device_address.h"
 namespace mindspore {
+#ifdef ENABLE_DEBUGGER
+class Debugger;
+#endif
 namespace device {
 namespace gpu {
 class GPUDeviceAddress : public DeviceAddress {
@@ -37,6 +40,11 @@ class GPUDeviceAddress : public DeviceAddress {
   DeviceAddressStatus status() const { return status_; }
   DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; }
+#ifdef ENABLE_DEBUGGER
+  bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
+                     const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger,
+                     bool keep_prev) const;
+#endif
  private:
   DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
 };
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include "runtime/device/gpu/gpu_kernel_runtime.h"
 #include <algorithm>
 #include "runtime/device/gpu/gpu_device_address.h"
 #include "runtime/device/gpu/cuda_driver.h"
 #include "runtime/device/gpu/gpu_buffer_mgr.h"
@@ -29,6 +29,8 @@
 #include "runtime/device/gpu/gpu_memory_manager.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "runtime/device/gpu/gpu_memory_copy_manager.h"
+#include "common/trans.h"
+#include "ir/dtype.h"
 namespace mindspore {
 namespace device {
@@ -36,6 +38,7 @@ namespace gpu {
 using mindspore::device::memswap::MemSwapInfoSet;
 using mindspore::device::memswap::MemSwapManager;
 using mindspore::device::memswap::SwapKind;
+static const size_t PARAMETER_OUTPUT_INDEX = 0;
 bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); }
 bool GPUKernelRuntime::Init() {
@@ -43,7 +46,15 @@ bool GPUKernelRuntime::Init() {
     GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
     return true;
   }
-  auto ret = InitDevice();
+  bool ret = false;
+#ifdef ENABLE_DUMP_E2E
+  ret = SetDumpConf();
+  if (!ret) {
+    MS_LOG(INFO) << "No dump conf to set!";
+  }
+#endif
+  ret = InitDevice();
   if (!ret) {
     MS_LOG(ERROR) << "InitDevice error.";
     return ret;
@@ -63,6 +74,216 @@ bool GPUKernelRuntime::Init() {
   return ret;
 }
+#ifdef ENABLE_DUMP_E2E
+namespace {
+void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf,
+                Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  bool trans_flag = dump_conf->trans_flag();
+  const auto &apply_kernels = graph->execution_order();
+  for (const auto &node : apply_kernels) {
+    MS_EXCEPTION_IF_NULL(node);
+    auto node_name = AnfAlgo::GetCNodeName(node);
+    std::string kernel_name = node->fullname_with_scope();
+    if (!dump_conf->IsKernelNeedDump(kernel_name)) {
+      continue;
+    }
+    const std::string strsrc = "/";
+    const std::string strdst = "--";
+    std::string::size_type pos = 0;
+    std::string::size_type srclen = strsrc.size();
+    std::string::size_type dstlen = strdst.size();
+    while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) {
+      kernel_name.replace(pos, srclen, strdst);
+      pos += dstlen;
+    }
+    auto output_size = AnfAlgo::GetOutputTensorNum(node);
+    for (size_t j = 0; j < output_size; ++j) {
+      auto addr = AnfAlgo::GetOutputAddr(node, j);
+      TypeId addr_type_id = addr->type_id();
+      std::string addr_format = addr->format();
+      std::vector<int> int_shapes;
+      if (trans_flag) {
+        int_shapes = trans::GetRuntimePaddingShape(node, j);
+      } else {
+        auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
+        (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                             [](size_t inner_item) { return SizeToInt(inner_item); });
+      }
+      auto type = AnfAlgo::GetOutputInferDataType(node, j);
+      auto format = kOpFormat_DEFAULT;
+      string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
+      DebugServices *debug_services = debugger->debug_services();
+      TensorLoader *tensor_loader = debug_services->tensor_loader();
+      std::string original_kernel_name = node->fullname_with_scope();
+      size_t slot = j;
+      auto ret = tensor_loader->DumpTensorToFile(original_kernel_name, trans_flag, filepath, format, int_shapes, type,
+                                                 addr_type_id, addr_format, slot);
+      if (!ret) {
+        std::string error = "DumpTensorToFile Failed: flag:" + std::to_string(trans_flag) + ", path:" + filepath +
+                            ", host_format:" + format + ".!";
+      }
+    }
+  }
+}
+void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf,
+                    Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  bool trans_flag = dump_conf->trans_flag();
+  const auto &parameters = graph->inputs();
+  for (auto &item : parameters) {
+    if (!item->isa<Parameter>()) {
+      continue;
+    }
+    std::string parameter_name = item->fullname_with_scope();
+    if (!dump_conf->IsKernelNeedDump(parameter_name)) {
+      continue;
+    }
+    auto addr = AnfAlgo::GetOutputAddr(item, PARAMETER_OUTPUT_INDEX);
+    TypeId addr_type_id = addr->type_id();
+    std::string addr_format = addr->format();
+    std::vector<int> int_shapes;
+    if (trans_flag) {
+      int_shapes = trans::GetRuntimePaddingShape(item, PARAMETER_OUTPUT_INDEX);
+    } else {
+      auto shape = AnfAlgo::GetOutputDeviceShape(item, PARAMETER_OUTPUT_INDEX);
+      (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                           [](size_t inner_item) { return SizeToInt(inner_item); });
+    }
+    auto type = AnfAlgo::GetOutputInferDataType(item, PARAMETER_OUTPUT_INDEX);
+    auto format = kOpFormat_DEFAULT;
+    string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
+    DebugServices *debug_services = debugger->debug_services();
+    TensorLoader *tensor_loader = debug_services->tensor_loader();
+    std::string original_kernel_name = parameter_name;
+    size_t slot = 0;
+    auto ret = tensor_loader->DumpTensorToFile(original_kernel_name, trans_flag, filepath, format, int_shapes, type,
+                                               addr_type_id, addr_format, slot);
+    if (!ret) {
+      std::string error = "DumpTensorToFile Failed: flag:" + std::to_string(trans_flag) + ", path:" + filepath +
+                          ", host_format:" + format + ".!";
+    }
+  }
+}
+}  // namespace
+bool GPUKernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "Start dump step";
+  DumpConfPtr dump_conf = GetDumpConf();
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  dump_conf->UpdataCurIter();
+  bool dump_flag = dump_conf->dump_enable();
+  if (!dump_flag) {
+    MS_LOG(INFO) << "Dump flag is disable, pass dump step";
+    return true;
+  }
+  uint32_t cur_iter = dump_conf->cur_iter();
+  if (dump_conf->dump_iter() != 0) {
+    if (cur_iter != dump_conf->dump_iter()) {
+      return true;
+    }
+  }
+  MS_LOG(INFO) << "Cur iter is " << cur_iter;
+  std::string net_name = dump_conf->dump_net_name();
+  std::string iterator = std::to_string(cur_iter);
+  std::string dump_path = dump_conf->dump_path();
+  if (dump_path.back() == '/') {
+    dump_path = dump_path + net_name + '/' + iterator;
+  } else {
+    dump_path = dump_path + '/' + net_name + '/' + iterator;
+  }
+  // dump output
+  DumpOutput(graph, dump_path, dump_conf, debugger);
+  // dump parameters
+  DumpParameters(graph, dump_path, dump_conf, debugger);
+  return true;
+}
+#endif
+#ifdef ENABLE_DEBUGGER
+namespace {
+void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
+                    const std::vector<mindspore::kernel::AddressPtr> &kernel_inputs,
+                    const std::vector<mindspore::kernel::AddressPtr> &kernel_workspaces,
+                    const std::vector<mindspore::kernel::AddressPtr> &kernel_outputs, int exec_order, void *stream_ptr,
+                    bool dump_enabled) {
+  if (!(debugger && (debugger->debugger_enabled() || dump_enabled))) {
+    return;
+  }
+  std::string kernel_name = kernel->fullname_with_scope();
+  auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
+  for (size_t j = 0; j < output_size; ++j) {
+    auto addr = kernel_outputs[j];
+    auto type = AnfAlgo::GetOutputInferDataType(kernel, j);
+    auto format = kOpFormat_DEFAULT;
+    auto gpu_addr = std::make_unique<GPUDeviceAddress>(addr->addr, addr->size, format, type);
+    string tensor_name = kernel_name + ':' + std::to_string(j);
+    std::vector<int> int_shapes;
+    auto shape = AnfAlgo::GetOutputDeviceShape(kernel, j);
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                         [](size_t inner_item) { return SizeToInt(inner_item); });
+    auto ret = gpu_addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, debugger, false);
+    if (!ret) {
+      MS_LOG(ERROR) << "LoadMemToHost:"
+                    << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
+    }
+  }
+}
+void LoadParameters(const session::KernelGraph *graph, Debugger *debugger, bool dump_enabled) {
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!(debugger && (debugger->debugger_enabled() || dump_enabled))) {
+    return;
+  }
+  const auto &parameters = graph->inputs();
+  // for parameters, set its execution order to be 0;
+  int exec_order = 0;
+  for (auto &item : parameters) {
+    if (!item->isa<Parameter>()) {
+      continue;
+    }
+    std::string parameter_name = item->fullname_with_scope();
+    auto addr = AnfAlgo::GetOutputAddr(item, PARAMETER_OUTPUT_INDEX);
+    auto type = AnfAlgo::GetOutputInferDataType(item, PARAMETER_OUTPUT_INDEX);
+    auto format = kOpFormat_DEFAULT;
+    string tensor_name = parameter_name + ':' + "0";
+    auto gpu_addr = dynamic_cast<const mindspore::device::gpu::GPUDeviceAddress *>(addr);
+    std::vector<int> int_shapes;
+    auto shape = AnfAlgo::GetOutputDeviceShape(item, PARAMETER_OUTPUT_INDEX);
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                         [](size_t inner_item) { return SizeToInt(inner_item); });
+    auto ret = gpu_addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, debugger, true);
+    if (!ret) {
+      MS_LOG(ERROR) << "LoadMemToHost:"
+                    << ", tensor_name:" << tensor_name << ", host_format:" << format << ".!";
+    }
+  }
+}
+void ClearCurrentData(Debugger *debugger, bool dump_enabled) {
+  if (debugger && (debugger->debugger_enabled() || dump_enabled)) {
+    DebugServices *debug_services = debugger->debug_services();
+    TensorLoader *tensor_loader = debug_services->tensor_loader();
+    tensor_loader->EmptyCurrentTensor();
+  }
+}
+}  // namespace
+#endif
 DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                        TypeId type_id) {
   return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
@@ -147,7 +368,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
   }
 }
-bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
+bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
   struct timeval start_time, end_time;
   (void)gettimeofday(&start_time, nullptr);
   bool ret = true;
@@ -170,7 +391,7 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
     mem_reuse_util_ = mem_reuse_iter->second;
     MS_EXCEPTION_IF_NULL(mem_reuse_util_);
-    ret = RunOneStep(graph);
+    ret = RunOneStep(graph, debugger);
   } else {
     ret = LaunchKernel(graph);
   }
@@ -182,28 +403,28 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
   return ret;
 }
-bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) {
+bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph, Debugger *debugger) {
   bool ret = true;
   auto graph_id = graph->graph_id();
   if (!is_first_step_map_[graph_id]) {
     // Normally run graph
-    ret = LaunchKernelDynamic(graph);
+    ret = LaunchKernelDynamic(graph, debugger);
   } else {
     // Mock run first step
-    ret = LaunchKernelDynamic(graph, true, false);
+    ret = LaunchKernelDynamic(graph, debugger, true, false);
     if (ret) {
       // Normally run graph
-      ret = LaunchKernelDynamic(graph);
+      ret = LaunchKernelDynamic(graph, debugger);
     } else {
       // Trigger memory swap
-      ret = SearchMemSwapScheme(graph);
+      ret = SearchMemSwapScheme(graph, debugger);
     }
     is_first_step_map_[graph_id] = false;
   }
   return ret;
 }
-bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
+bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger) {
   MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment.";
   bool ret = false;
   ClearKernelOldOutputAndWorkspace(graph);
@@ -217,7 +438,7 @@ bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
     if (!mem_swap_manager_->RetreatSwapInfo()) {
       return false;
     }
-    ret = LaunchKernelDynamic(graph, true, false);
+    ret = LaunchKernelDynamic(graph, debugger, true, false);
     if (!ret) {
       ClearKernelOldOutputAndWorkspace(graph);
     }
@@ -225,14 +446,14 @@ bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
   mem_swap_manager_->AssignHostMemory();
   // Time profiling
-  ret = LaunchKernelDynamic(graph, false, true);
+  ret = LaunchKernelDynamic(graph, debugger, false, true);
   if (!ret) {
     return ret;
   }
-  return RefineMemSwapScheme(graph);
+  return RefineMemSwapScheme(graph, debugger);
 }
-bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) {
+bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger) {
   MS_LOG(WARNING) << "Refine memory swap scheme, it may take some time, please wait a moment.";
   auto &kernels = graph->execution_order();
   for (const auto &kernel : kernels) {
@@ -245,7 +466,7 @@ bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) {
       bool ret = false;
       while (!ret) {
         mem_swap_manager_->AdjustSwapInPos(kernel, swap_in_task_idx);
-        ret = LaunchKernelDynamic(graph, true, false);
+        ret = LaunchKernelDynamic(graph, debugger, true, false);
         if (!ret) {
           ClearKernelOldOutputAndWorkspace(graph);
           ClearSwapInfo(true);
@@ -384,14 +605,24 @@ void GPUKernelRuntime::ClearKernelWorkspaceAddress(const session::KernelGraph *g
   }
 }
-bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bool mock, bool profiling) {
+bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger, bool mock,
+                                           bool profiling) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(mem_reuse_util_);
   // Reset the reference count.
   mem_reuse_util_->ResetDynamicUsedRefCount();
   // The inputs and outputs memory of communication kernel need be continuous, so separate processing.
   AllocCommunicationOpDynamicRes(graph);
+#ifdef ENABLE_DEBUGGER
+  bool dump_enabled = GPUKernelRuntime::DumpDataEnabledIteration();
+  if (!mock) {
+    // collect weights and bias
+    LoadParameters(graph, debugger, dump_enabled);
+  }
+#endif
   auto &kernels = graph->execution_order();
+  int exec_order = 1;
   for (const auto &kernel : kernels) {
     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
     MS_EXCEPTION_IF_NULL(kernel_mod);
@@ -400,6 +631,12 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
     AddressPtrList kernel_outputs;
     auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs, mock);
     if (!ret) {
+#ifdef ENABLE_DEBUGGER
+      if (!mock) {
+        // invalidate current data collected by the debugger
+        ClearCurrentData(debugger, dump_enabled);
+      }
+#endif
       return false;
     }
     if (!mock) {
@@ -409,9 +646,21 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
       } else {
         LaunchKernelWithTimeProfiling(kernel, kernel_inputs, kernel_workspaces, kernel_outputs);
       }
+#ifdef ENABLE_DEBUGGER
+      // called once per kernel to collect the outputs to the kernel (does a SyncDeviceToHost)
+      LoadKernelData(debugger, kernel, kernel_inputs, kernel_workspaces, kernel_outputs, exec_order, stream_,
+                     dump_enabled);
+#endif
     }
+    exec_order = exec_order + 1;
     FreeKernelDynamicRes(kernel);
     if (!UpdateMemorySwapTask(kernel, mock, profiling)) {
+#ifdef ENABLE_DEBUGGER
+      if (!mock) {
+        // invalidate current data collected by the debugger
+        ClearCurrentData(debugger, dump_enabled);
+      }
+#endif
       return false;
     }
   }
@@ -38,7 +38,10 @@ class GPUKernelRuntime : public KernelRuntime {
   bool Init() override;
   void ReleaseDeviceRes() override;
   void AssignMemory(session::KernelGraph *graph) override;
-  bool Run(session::KernelGraph *graph) override;
+  bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
+#ifdef ENABLE_DUMP_E2E
+  bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
+#endif
  protected:
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@@ -61,10 +64,11 @@ class GPUKernelRuntime : public KernelRuntime {
   void ClearKernelOutputAddress(const session::KernelGraph *graph);
   void ClearKernelWorkspaceAddress(const session::KernelGraph *graph);
   void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph);
-  bool RunOneStep(const session::KernelGraph *graph);
-  bool SearchMemSwapScheme(const session::KernelGraph *graph);
-  bool RefineMemSwapScheme(const session::KernelGraph *graph);
-  bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false);
+  bool RunOneStep(const session::KernelGraph *graph, Debugger *debugger = nullptr);
+  bool SearchMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr);
+  bool RefineMemSwapScheme(const session::KernelGraph *graph, Debugger *debugger = nullptr);
+  bool LaunchKernelDynamic(const session::KernelGraph *graph, Debugger *debugger = nullptr, bool mock = false,
+                           bool profiling = false);
   void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
                                      const AddressPtrList &workspace, const AddressPtrList &outputs);
   bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock);
@@ -41,7 +41,7 @@ KernelRuntime::~KernelRuntime() {
 #endif
 }
-bool KernelRuntime::Run(session::KernelGraph *graph) {
+bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
   bool ret = false;
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -72,7 +72,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph) {
 }
 // for D to impl
-bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
+bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
   if (graph != nullptr) {
     return true;
   }
@@ -190,6 +190,39 @@ void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
   }
 }
+bool KernelRuntime::DumpDataEnabled() {
+  bool ret = false;
+#ifdef ENABLE_DUMP_E2E
+  DumpConfPtr dump_conf = GetDumpConf();
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  bool dump_flag = dump_conf->dump_enable();
+  if (!dump_flag) {
+    return ret;
+  }
+  ret = true;
+#endif
+  return ret;
+}
+bool KernelRuntime::DumpDataEnabledIteration() {
+  bool ret = false;
+#ifdef ENABLE_DUMP_E2E
+  if (!DumpDataEnabled()) {
+    return ret;
+  }
+  DumpConfPtr dump_conf = GetDumpConf();
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  uint32_t cur_iter = dump_conf->cur_iter() + 1;
+  if (dump_conf->dump_iter() != 0) {
+    if (cur_iter != dump_conf->dump_iter()) {
+      return ret;
+    }
+  }
+  ret = true;
+#endif
+  return ret;
+}
 void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) {
   AssignStaticMemoryInput(graph);
   AssignStaticMemoryValueNode(graph);
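DumpDataEnabledIteration peeks one step ahead (cur_iter() + 1) because DumpData itself advances the counter via UpdataCurIter before checking, and dump_iter == 0 means every iteration dumps. A standalone restatement of the gate:

```cpp
#include <cstdint>
#include <iostream>

// Restates the gate in KernelRuntime::DumpDataEnabledIteration:
// dump is configured, and either dump_iter == 0 (every iteration)
// or the upcoming iteration matches dump_iter.
bool DumpThisIteration(bool dump_enabled, uint32_t cur_iter, uint32_t dump_iter) {
  if (!dump_enabled) {
    return false;
  }
  uint32_t next_iter = cur_iter + 1;  // counter advances inside DumpData
  return dump_iter == 0 || next_iter == dump_iter;
}

int main() {
  std::cout << DumpThisIteration(true, 4, 5) << '\n';  // 1: iteration 5 will dump
  std::cout << DumpThisIteration(true, 5, 5) << '\n';  // 0: already past it
  std::cout << DumpThisIteration(true, 7, 0) << '\n';  // 1: dump every iteration
  return 0;
}
```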
@@ -55,8 +55,10 @@ class KernelRuntime {
   virtual void AssignMemory(session::KernelGraph *graph);
   void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
   void RunOpClearMemory(const session::KernelGraph *graph);
-  virtual bool Run(session::KernelGraph *graph);
-  virtual bool DumpData(session::KernelGraph *graph);
+  bool DumpDataEnabled();
+  bool DumpDataEnabledIteration();
+  virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr);
+  virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
   virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger);
   virtual bool RunTask(const session::KernelGraph *graph);
   virtual bool GenTask(const session::KernelGraph *graph);