Browse Source

fix bug with use trans_data to reduce print time in graph mode

tags/v0.6.0-beta
lvchangquan 5 years ago
parent
commit
a91f076e67
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc

+ 3
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc View File

@@ -482,7 +482,9 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
if (type_id_name_map.find(type_id_) != type_id_name_map.end()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode && type_id_name_map.find(type_id_) != type_id_name_map.end()) {
std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
if (use_trans_data.find(type_format) != use_trans_data.end()) {
sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, device_shape, size, type, host_ptr);


Loading…
Cancel
Save