|
|
|
@@ -20,10 +20,12 @@ |
|
|
|
#include "runtime/device/kernel_runtime_manager.h" |
|
|
|
#include "utils/comm_manager.h" |
|
|
|
#include "utils/scoped_long_running.h" |
|
|
|
#include "pybind_api/ir/tensor_py.h" |
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) |
|
|
|
#include "ps/ps_cache/ps_cache_manager.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
using mindspore::tensor::TensorPy; |
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
namespace { |
|
|
|
@@ -399,12 +401,36 @@ void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const Gr |
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, |
|
|
|
const std::vector<int64_t> &tensors_mask) { |
|
|
|
MS_EXCEPTION_IF_NULL(session); |
|
|
|
for (auto &tensor : *input_tensors) { |
|
|
|
if (tensor->NeedWait()) { |
|
|
|
tensor->Wait(); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
if (target == kGPUDevice) { |
|
|
|
for (auto &tensor : *input_tensors) { |
|
|
|
if (tensor->NeedWait()) { |
|
|
|
tensor->Wait(); |
|
|
|
} |
|
|
|
} |
|
|
|
{ |
|
|
|
// Release GIL before calling into (potentially long-running) C++ code |
|
|
|
if (Py_IsInitialized()) { |
|
|
|
py::gil_scoped_release release; |
|
|
|
} |
|
|
|
session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask); |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto task = std::make_shared<RunOpTask>(); |
|
|
|
task->session_ = session; |
|
|
|
task->op_run_info_ = op_run_info; |
|
|
|
task->graph_info_ = graph_info; |
|
|
|
task->input_tensors_ = input_tensors; |
|
|
|
task->tensors_mask_ = tensors_mask; |
|
|
|
for (auto &tensor : *input_tensors) { |
|
|
|
if (tensor->NeedWait()) { |
|
|
|
tensor->Wait(); |
|
|
|
} |
|
|
|
} |
|
|
|
RunTask(task, true, true); |
|
|
|
*outputs = task->outputs_; |
|
|
|
} |
|
|
|
session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
|