diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index d16e0ace82..c6f7adde30 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -269,9 +269,13 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString(); } -std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, - const std::vector &input_tensors) { +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector &input_tensors, + const std::vector &tensors_mask) { MS_EXCEPTION_IF_NULL(op_exec_info); + if (input_tensors.size() != tensors_mask.size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " is not equal to the size of tensors mask " + << tensors_mask.size(); + } std::string graph_info; // get input tensor info for (size_t index = 0; index < input_tensors.size(); ++index) { @@ -290,7 +294,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, (void)graph_info.append(std::dynamic_pointer_cast(tensor_addr)->format()); graph_info += "_"; } - if (static_cast(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) { + if (tensors_mask[index] == kValueNodeTensorMask) { if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) { (void)graph_info.append(std::to_string(*reinterpret_cast(input_tensors[index]->data_c()))); graph_info += "_"; @@ -1499,7 +1503,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); ConvertAttrToUnifyMindIR(op_exec_info); // get graph info for checking it whether existing in the cache - std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); + std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors, tensors_mask); #if defined(__APPLE__) session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive,