Browse Source

Only weight tensor need to bind with Parameter device_address

tags/v0.5.0-beta
caifubi 5 years ago
parent
commit
88ac2ae514
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      mindspore/ccsrc/session/session_basic.cc

+ 3
- 1
mindspore/ccsrc/session/session_basic.cc View File

@@ -676,7 +676,9 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
} }
} }
if (need_sync) { if (need_sync) {
tensor->set_device_address(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(), LongToSize(tensor->data().nbytes()), tensor->data_type(),


Loading…
Cancel
Save