Browse Source

!7941 fix pynative grad error

From: @kisnwang
Reviewed-by: @jjfeing,@chujinjin
Signed-off-by: @jjfeing
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
beb86391fe
2 changed files with 6 additions and 5 deletions
  1. +4
    -4
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +2
    -1
      mindspore/core/ir/tensor.cc

+ 4
- 4
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -999,16 +999,16 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
}
tensor->set_sync_status(kNoNeedSync);
}


+ 2
- 1
mindspore/core/ir/tensor.cc View File

@@ -581,7 +581,8 @@ void Tensor::data_sync(bool need_wait) const {
if (device_sync_ == nullptr) {
return;
}
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
auto address = device_sync_;
if (!address->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
sync_status_ = kNeedSyncHostToDevice;


Loading…
Cancel
Save