From 5c9f738bfc673da81027a9c7dedd530dfba539c4 Mon Sep 17 00:00:00 2001 From: kswang Date: Thu, 29 Oct 2020 10:34:51 +0800 Subject: [PATCH] fix pynative grad error --- mindspore/ccsrc/backend/session/session_basic.cc | 8 ++++---- mindspore/core/ir/tensor.cc | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index fd838b2c42..d7462d1ebc 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -988,16 +988,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); - if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode || - AnfAlgo::IsParameterWeight(input_node->cast())) { - tensor->set_device_address(device_address); - } MS_EXCEPTION_IF_NULL(device_address); if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode || + AnfAlgo::IsParameterWeight(input_node->cast())) { + tensor->set_device_address(device_address); + } } tensor->set_sync_status(kNoNeedSync); } diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 049d77c540..6711ea6587 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -580,7 +580,8 @@ void Tensor::data_sync(bool need_wait) const { if (device_sync_ == nullptr) { return; } - if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { + auto address = device_sync_; + if (!address->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; } sync_status_ = kNeedSyncHostToDevice;