|
|
|
@@ -149,23 +149,19 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { |
|
|
|
return op_exec_info; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) { |
|
|
|
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, |
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors) { |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info); |
|
|
|
std::string graph_info; |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info->abstract); |
|
|
|
// get input tensor info |
|
|
|
size_t input_num = op_exec_info->op_inputs.size(); |
|
|
|
for (size_t index = 0; index < input_num; ++index) { |
|
|
|
if (py::isinstance<tensor::Tensor>(op_exec_info->op_inputs[index])) { |
|
|
|
auto tensor_ptr = py::cast<tensor::TensorPtr>(op_exec_info->op_inputs[index]); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_ptr); |
|
|
|
(void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); |
|
|
|
} |
|
|
|
for (const auto &input_tensor : input_tensors) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_tensor); |
|
|
|
(void)graph_info.append(input_tensor->GetShapeAndDataTypeInfo() + "_"); |
|
|
|
} |
|
|
|
// get prim and abstract info |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info->abstract); |
|
|
|
(void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + |
|
|
|
op_exec_info->abstract->ToString()); |
|
|
|
MS_LOG(INFO) << "Graph info [" << graph_info << "]"; |
|
|
|
return graph_info; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -337,14 +333,14 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat |
|
|
|
if (session == nullptr) { |
|
|
|
session = session::SessionFactory::Get().Create(device_target); |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session); |
|
|
|
session->Init(ms_context->device_id()); |
|
|
|
|
|
|
|
std::string graph_info = GetSingleOpGraphInfo(op_exec_info); |
|
|
|
std::vector<tensor::TensorPtr> input_tensors; |
|
|
|
std::vector<int> tensors_mask; |
|
|
|
ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); |
|
|
|
// get graph info for checking it whether existing in the cache |
|
|
|
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); |
|
|
|
session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); |
|
|
|
EraseValueNodeTensor(tensors_mask, &input_tensors); |
|
|
|
py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); |
|
|
|
|