Browse Source

!13049 Support ms_function + heterogenous

From: @HulkTang
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
c262acbd8e
6 changed files with 10 additions and 13 deletions
  1. +0
    -9
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
  2. +1
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
  3. +1
    -0
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h
  4. +1
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h
  5. +6
    -3
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  6. +1
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.h

+ 0
- 9
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc View File

@@ -321,15 +321,6 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) {
return true;
}

bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
if (AnfAlgo::OutputAddrExist(kernel, index)) {
auto address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == DeviceAddressType::kAscend;
}
return false;
}

bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
bool need_dump = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();


+ 1
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h View File

@@ -57,13 +57,13 @@ class AscendKernelRuntime : public KernelRuntime {
void *context() const override { return rt_context_; }
void PreInit() override;
uint64_t GetAvailableMemMaxSize() const override;
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; };
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }

protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool KernelMemNotReuse(const AnfNodePtr &node) override;

void KernelLaunchProfiling(const std::string &kernel_name) override;


+ 1
- 0
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h View File

@@ -46,6 +46,7 @@ class CPUKernelRuntime : public KernelRuntime {
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kCPU; };

protected:
bool SyncStream() override { return true; };


+ 1
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h View File

@@ -47,6 +47,7 @@ class GPUKernelRuntime : public KernelRuntime {
bool Run(session::KernelGraph *graph, bool is_task_sink) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; };
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }



+ 6
- 3
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -49,7 +49,9 @@ bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; }
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::OutputAddrExist(kernel, index)) {
return true;
const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == GetTargetDeviceAddressType();
}
return false;
}
@@ -173,7 +175,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
auto output_address =
std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
if (output_address != nullptr) {
if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(output_address, index, item.get());
continue;
}
@@ -637,7 +639,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_LOG(WARNING) << "Tensor is null";
return;
}
if (tensor->device_address() != nullptr) {
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;


+ 1
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.h View File

@@ -108,6 +108,7 @@ class KernelRuntime {
virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); }
void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); }
virtual DeviceAddressType GetTargetDeviceAddressType() const = 0;
virtual void *compute_stream() const { return nullptr; }
virtual void *communication_stream() const { return nullptr; }



Loading…
Cancel
Save