Browse Source

!14416 fix run on device bug

From: @zhoufeng54
Reviewed-by: @xu-yfei,@kisnwang
Signed-off-by: @xu-yfei
pull/14416/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
a6bc4b74fd
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/ccsrc/cxx_api/graph/acl/model_process.cc

+ 5
- 1
mindspore/ccsrc/cxx_api/graph/acl/model_process.cc View File

@@ -101,6 +101,10 @@ Status ModelProcess::ConstructTensors(const std::vector<AclTensorInfo> &acl_tens
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]);
if (acl_tensor_list[i].cur_device_data == nullptr) {
// when run on device, cur_device_data is nullptr before first execute
continue;
}
auto ret = aclrtMemcpy((*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(),
acl_tensor_list[i].cur_device_data, acl_tensor_list[i].buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
@@ -401,7 +405,7 @@ Status ModelProcess::CheckAndInitInput(const std::vector<MSTensor> &inputs) {
input_buffer = info.cur_device_data;
}
} else {
input_buffer = const_cast<void *>(data);
input_buffer = data;
}
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
if (data_buffer == nullptr) {


Loading…
Cancel
Save