Browse Source

add async ops execute for pynative

tags/v0.5.0-beta
chujinjin 5 years ago
parent
commit
dde03ce944
6 changed files with 29 additions and 6 deletions
  1. +19
    -0
      mindspore/ccsrc/device/ascend/ascend_device_address.cc
  2. +1
    -0
      mindspore/ccsrc/device/ascend/ascend_device_address.h
  3. +1
    -1
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h
  4. +7
    -0
      mindspore/ccsrc/device/gpu/gpu_device_address.cc
  5. +0
    -4
      mindspore/ccsrc/device/kernel_runtime.cc
  6. +1
    -1
      mindspore/ccsrc/device/kernel_runtime.h

+ 19
- 0
mindspore/ccsrc/device/ascend/ascend_device_address.cc View File

@@ -92,10 +92,29 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
return true;
}

// Synchronizes the stream of the Ascend kernel runtime bound to the device id
// configured in the global MsContext.
//
// Raises through MS_EXCEPTION_IF_NULL when either the context or the runtime
// instance is null, and through MS_LOG(EXCEPTION) when the runtime's
// SyncStream() reports failure. Returns nothing; success is signalled only by
// the trailing "Finish!" log line.
// NOTE(review): presumably the runtime's SyncStream() blocks until queued
// device work has completed - confirm against AscendKernelRuntime.
void AscendDeviceAddress::SyncStream() const {
  MS_LOG(INFO) << "Start!";

  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  // Read the device id first, then resolve the runtime registered for it.
  auto target_device_id = context->device_id();
  auto kernel_runtime =
    device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, target_device_id);
  MS_EXCEPTION_IF_NULL(kernel_runtime);

  if (!kernel_runtime->SyncStream()) {
    MS_LOG(EXCEPTION) << "Sync stream error!";
  }

  MS_LOG(INFO) << "Finish!";
}

bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type,
void *host_ptr) const {
MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
SyncStream();
}
bool sync_ok = false;
std::vector<size_t> host_shape;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);


+ 1
- 0
mindspore/ccsrc/device/ascend/ascend_device_address.h View File

@@ -44,6 +44,7 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
const void *host_ptr) const;
void SyncStream() const;
};
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
} // namespace ascend


+ 1
- 1
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h View File

@@ -41,12 +41,12 @@ class AscendKernelRuntime : public KernelRuntime {
bool RunTask(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph) override;
void ClearGraphRuntimeResource(uint32_t graph_id) override;
bool SyncStream() override;

protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool SyncStream() override;

private:
bool InitDevice();


+ 7
- 0
mindspore/ccsrc/device/gpu/gpu_device_address.cc View File

@@ -28,6 +28,13 @@ namespace device {
namespace gpu {
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr);
auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream);
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
if (!ret) {
MS_LOG(ERROR) << "SyncStream failed";
return ret;
}
if (size != size_) {
MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
return true;


+ 0
- 4
mindspore/ccsrc/device/kernel_runtime.cc View File

@@ -680,10 +680,6 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
}
if (!SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed!";
return false;
}
return true;
}



+ 1
- 1
mindspore/ccsrc/device/kernel_runtime.h View File

@@ -55,6 +55,7 @@ class KernelRuntime {
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
virtual bool SyncStream() = 0;

#ifdef ENABLE_DUMP_E2E
DumpConfPtr GetDumpConf();
@@ -68,7 +69,6 @@ class KernelRuntime {
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool SyncStream() = 0;
void AssignStaticMemory(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph);
void ReuseAssignDynamicMemory(session::KernelGraph *graph);


Loading…
Cancel
Save