From: @HulkTang Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjintags/v1.2.0-rc1
| @@ -321,15 +321,6 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | |||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | |||||
| auto address = AnfAlgo::GetOutputAddr(kernel, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| return address->DeviceType() == DeviceAddressType::kAscend; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { | bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { | ||||
| bool need_dump = false; | bool need_dump = false; | ||||
| auto &dump_json_parser = DumpJsonParser::GetInstance(); | auto &dump_json_parser = DumpJsonParser::GetInstance(); | ||||
| @@ -57,13 +57,13 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| void *context() const override { return rt_context_; } | void *context() const override { return rt_context_; } | ||||
| void PreInit() override; | void PreInit() override; | ||||
| uint64_t GetAvailableMemMaxSize() const override; | uint64_t GetAvailableMemMaxSize() const override; | ||||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; | |||||
| void *compute_stream() const override { return stream_; } | void *compute_stream() const override { return stream_; } | ||||
| void *communication_stream() const override { return communication_stream_; } | void *communication_stream() const override { return communication_stream_; } | ||||
| protected: | protected: | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| TypeId type_id) override; | TypeId type_id) override; | ||||
| bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; | |||||
| bool KernelMemNotReuse(const AnfNodePtr &node) override; | bool KernelMemNotReuse(const AnfNodePtr &node) override; | ||||
| void KernelLaunchProfiling(const std::string &kernel_name) override; | void KernelLaunchProfiling(const std::string &kernel_name) override; | ||||
| @@ -46,6 +46,7 @@ class CPUKernelRuntime : public KernelRuntime { | |||||
| void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | ||||
| bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } | bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } | ||||
| bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } | bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } | ||||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kCPU; }; | |||||
| protected: | protected: | ||||
| bool SyncStream() override { return true; }; | bool SyncStream() override { return true; }; | ||||
| @@ -47,6 +47,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| bool Run(session::KernelGraph *graph, bool is_task_sink) override; | bool Run(session::KernelGraph *graph, bool is_task_sink) override; | ||||
| bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } | bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } | ||||
| bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } | bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } | ||||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; | |||||
| void *compute_stream() const override { return stream_; } | void *compute_stream() const override { return stream_; } | ||||
| void *communication_stream() const override { return communication_stream_; } | void *communication_stream() const override { return communication_stream_; } | ||||
| @@ -49,7 +49,9 @@ bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; } | |||||
| bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| if (AnfAlgo::OutputAddrExist(kernel, index)) { | if (AnfAlgo::OutputAddrExist(kernel, index)) { | ||||
| return true; | |||||
| const auto &address = AnfAlgo::GetOutputAddr(kernel, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| return address->DeviceType() == GetTargetDeviceAddressType(); | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -173,7 +175,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> | |||||
| MS_EXCEPTION_IF_NULL(input_tensors[input_index]); | MS_EXCEPTION_IF_NULL(input_tensors[input_index]); | ||||
| auto output_address = | auto output_address = | ||||
| std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address()); | std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address()); | ||||
| if (output_address != nullptr) { | |||||
| if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) { | |||||
| AnfAlgo::SetOutputAddr(output_address, index, item.get()); | AnfAlgo::SetOutputAddr(output_address, index, item.get()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -637,7 +639,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||||
| MS_LOG(WARNING) << "Tensor is null"; | MS_LOG(WARNING) << "Tensor is null"; | ||||
| return; | return; | ||||
| } | } | ||||
| if (tensor->device_address() != nullptr) { | |||||
| auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||||
| if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) { | |||||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, | AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++, | ||||
| value_node.get()); | value_node.get()); | ||||
| continue; | continue; | ||||
| @@ -108,6 +108,7 @@ class KernelRuntime { | |||||
| virtual uint64_t GetAvailableMemMaxSize() const { return 0; } | virtual uint64_t GetAvailableMemMaxSize() const { return 0; } | ||||
| void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); } | void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); } | ||||
| void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); } | void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); } | ||||
| virtual DeviceAddressType GetTargetDeviceAddressType() const = 0; | |||||
| virtual void *compute_stream() const { return nullptr; } | virtual void *compute_stream() const { return nullptr; } | ||||
| virtual void *communication_stream() const { return nullptr; } | virtual void *communication_stream() const { return nullptr; } | ||||