Browse Source

fix_bug_of_memory_access_error_when_using_input_mask

pull/15529/head
lvliang 5 years ago
parent
commit
8904a07e55
1 changed files with 8 additions and 4 deletions
  1. +8
    -4
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 8
- 4
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -269,9 +269,13 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}

std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(op_exec_info);
if (input_tensors.size() != tensors_mask.size()) {
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " is not equal to the size of tensors mask "
<< tensors_mask.size();
}
std::string graph_info;
// get input tensor info
for (size_t index = 0; index < input_tensors.size(); ++index) {
@@ -290,7 +294,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format());
graph_info += "_";
}
if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) {
if (tensors_mask[index] == kValueNodeTensorMask) {
if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) {
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())));
graph_info += "_";
@@ -1499,7 +1503,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
ConvertAttrToUnifyMindIR(op_exec_info);
// get graph info for checking it whether existing in the cache
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask);
#if defined(__APPLE__)
session::OpRunInfo op_run_info = {op_exec_info->op_name,
op_exec_info->py_primitive,


Loading…
Cancel
Save