Browse Source

fix pynative sync error

tags/v0.5.0-beta
chujinjin 5 years ago
parent
commit
22d55adaa8
1 changed files with 6 additions and 5 deletions
  1. +6
    -5
      mindspore/ccsrc/device/ascend/ascend_device_address.cc

+ 6
- 5
mindspore/ccsrc/device/ascend/ascend_device_address.cc View File

@@ -96,6 +96,10 @@ void AscendDeviceAddress::SyncStream() const {
MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() != kPynativeMode) {
MS_LOG(INFO) << "Finish!";
return;
}
auto device_id = ms_context->device_id();
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
@@ -110,11 +114,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
void *host_ptr) const {
MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
SyncStream();
}
SyncStream();
bool sync_ok = false;
std::vector<size_t> host_shape;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
@@ -205,6 +205,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
const void *host_ptr) const {
MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
SyncStream();
bool sync_ok = false;
std::vector<size_t> host_shape;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);


Loading…
Cancel
Save