Browse Source

fix summary bug of gradient

tags/v0.3.0-alpha
lizhenyu 5 years ago
parent
commit
da55cf6d50
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/session/session_basic.cc

+ 3
- 0
mindspore/ccsrc/session/session_basic.cc View File

@@ -785,6 +785,9 @@ void SessionBasic::Summary(KernelGraph *graph) {
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(address);
if (!address->GetPtr()) {
continue;
}
if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
tensor->data_type(), tensor->data_c(true))) {
MS_LOG(ERROR) << "Failed to sync output from device to host.";


Loading…
Cancel
Save