Format some executor function names

pull/14426/head
kswang 4 years ago
parent
commit
54ef8520ab
9 changed files with 120 additions and 175 deletions
  1. +0  -5   mindspore/ccsrc/backend/session/ascend_session.cc
  2. +0  -1   mindspore/ccsrc/backend/session/ascend_session.h
  3. +52 -53  mindspore/ccsrc/backend/session/executor.cc
  4. +1  -3   mindspore/ccsrc/backend/session/executor.h
  5. +54 -93  mindspore/ccsrc/backend/session/session_basic.cc
  6. +0  -5   mindspore/ccsrc/backend/session/session_basic.h
  7. +12 -12  mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
  8. +1  -1   mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
  9. +0  -2   mindspore/ccsrc/runtime/device/kernel_runtime.h

+0 -5   mindspore/ccsrc/backend/session/ascend_session.cc

@@ -1517,11 +1517,6 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
   }
 }
 
-GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
-  RunInfer(func_graph, inputs);
-  return CompileGraphImpl(func_graph);
-}
-
 void AscendSession::SyncStream() {
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);


+0 -1   mindspore/ccsrc/backend/session/ascend_session.h

@@ -49,7 +49,6 @@ class AscendSession : public SessionBasic {
   void UnifyMindIR(const KernelGraphPtr &graph) override;
   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
-  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
   bool IsSupportSummary() override;
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildGraphImpl(GraphId) override;


+52 -53  mindspore/ccsrc/backend/session/executor.cc

@@ -32,7 +32,7 @@ namespace {
 void UpdateOutputTensors(const VectorRef *outputs,
                          const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
       UpdateOutputTensors(&vector_ref, tensor_to_node);
@@ -45,7 +45,6 @@ void UpdateOutputTensors(const VectorRef *outputs,
       auto &output_index = iter->second.second;
       auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
       tensor->set_device_address(address);
-
       if (AnfAlgo::IsDynamicShape(node)) {
         auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
         ShapeVector int_shape;
@@ -62,12 +61,12 @@ void UpdateOutputTensors(const VectorRef *outputs,
   }
 }
 
-void NotifyOutputTensors(const VectorRef *outputs) {
+void SetOutputTensorsWaitStatus(const VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
-      NotifyOutputTensors(&vector_ref);
+      SetOutputTensorsWaitStatus(&vector_ref);
     } else if (utils::isa<tensor::TensorPtr>(item)) {
       auto tensor = utils::cast<tensor::TensorPtr>(item);
       MS_EXCEPTION_IF_NULL(tensor);
@@ -78,7 +77,7 @@ void NotifyOutputTensors(const VectorRef *outputs) {
 
 bool TensorInVector(const VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
       if (TensorInVector(&vector_ref)) {
@@ -90,6 +89,50 @@ bool TensorInVector(const VectorRef *outputs) {
   }
   return false;
 }
+
+bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
+  MS_EXCEPTION_IF_NULL(task);
+  for (auto &input : task->input_need_wait_tensors_) {
+    MS_EXCEPTION_IF_NULL(input);
+    if (input->NeedWait()) {
+      return false;
+    }
+  }
+  auto session = task->session_;
+  MS_EXCEPTION_IF_NULL(session);
+  auto graph = session->GetGraph(task->graph_id_);
+  if (graph != nullptr) {
+    return graph->IsPreGraphFinished();
+  }
+  return true;
+}
+
+void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
+  bool need_lock = false;
+  for (auto &tensor : task->input_tensors_) {
+    if (tensor->NeedWait()) {
+      if (tensor->IsGraphOutput()) {
+        task->input_need_wait_tensors_.emplace_back(tensor);
+      } else {
+        need_lock = true;
+      }
+    }
+  }
+  if (need_lock) {
+    mindspore::ScopedLongRunning long_running;
+    for (auto &tensor : task->input_tensors_) {
+      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
+        MsException::Instance().CheckException();
+        tensor->Wait();
+      }
+    }
+    MsException::Instance().CheckException();
+  }
+  // need lock input parameters for optimizer
+  for (auto &tensor : task->input_need_lock_tensors_) {
+    tensor->SetNeedWait(true);
+  }
+}
 }  // namespace
 
 void CompileNodesTask::Run() {
@@ -129,7 +172,7 @@ void RunGraphTask::Run() {
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
   }
-  NotifyOutputTensors(&outputs_);
+  SetOutputTensorsWaitStatus(&outputs_);
   ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
 }
 
@@ -198,7 +241,7 @@ void Executor::WorkerLoop() {
   }
 }
 
-std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
+std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
   std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
   std::lock_guard<std::mutex> lock(pending_task_mutex_);
   for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
@@ -249,7 +292,7 @@ void Executor::OnException() {
 }
 
 void Executor::OnRunGraphFinished() {
-  auto new_ready_tasks = GetNewReadyTasks();
+  auto new_ready_tasks = GetReadyTasksFromPendingList();
   std::lock_guard<std::mutex> lock(task_mutex_);
   for (auto &task : new_ready_tasks) {
     ready_tasks_.push(task);
@@ -260,23 +303,6 @@ void Executor::OnRunGraphFinished() {
   reenter_cond_var_.notify_all();
 }
 
-bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
-  MS_EXCEPTION_IF_NULL(task);
-  for (auto &input : task->input_need_wait_tensors_) {
-    MS_EXCEPTION_IF_NULL(input);
-    if (input->NeedWait()) {
-      return false;
-    }
-  }
-  auto session = task->session_;
-  MS_EXCEPTION_IF_NULL(session);
-  auto graph = session->GetGraph(task->graph_id_);
-  if (graph != nullptr) {
-    return graph->IsPreGraphFinished();
-  }
-  return true;
-}
-
 void Executor::ClearDoneTasks() {
   std::lock_guard<std::mutex> lock(done_task_mutex_);
   done_tasks_.clear();
@@ -341,33 +367,6 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
   RunTask(task, true, true);
 }
 
-void Executor::WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
-  bool need_lock = false;
-  for (auto &tensor : task->input_tensors_) {
-    if (tensor->NeedWait()) {
-      if (tensor->IsGraphOutput()) {
-        task->input_need_wait_tensors_.emplace_back(tensor);
-      } else {
-        need_lock = true;
-      }
-    }
-  }
-  if (need_lock) {
-    mindspore::ScopedLongRunning long_running;
-    for (auto &tensor : task->input_tensors_) {
-      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
-        MsException::Instance().CheckException();
-        tensor->Wait();
-      }
-    }
-    MsException::Instance().CheckException();
-  }
-  // need lock input parameters for optimizer
-  for (auto &tensor : task->input_need_lock_tensors_) {
-    tensor->SetNeedWait(true);
-  }
-}
-
 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(session);


+1 -3   mindspore/ccsrc/backend/session/executor.h

@@ -172,9 +172,7 @@ class Executor {
 
  private:
   void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
-  std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
-  bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
-  void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task);
+  std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
   void OnWorkerExit();
   void OnClear();
   void OnRunGraphFinished();


+54 -93  mindspore/ccsrc/backend/session/session_basic.cc

@@ -118,68 +118,16 @@ ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
   return parameter->param_info();
 }
 
-tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
-                                          const KernelGraphPtr &graph) {
-  auto &node = node_output_pair.first;
-  auto &output_index = node_output_pair.second;
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(graph);
-  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
-  if (type_id == kTypeUnknown) {
-    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
-  }
-  tensor::TensorPtr tensor = nullptr;
-  std::vector<int64_t> temp_shape;
-  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
-    temp_shape.emplace_back(1);
-    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-    tensor->set_sync_status(kNoNeedSync);
-    tensor->SetNeedWait(true);
-    tensor->SetIsGraphOutput();
-    return tensor;
-  }
-
-  tensor = graph->GetInternalOutputTensor(node, output_index);
-  if (tensor == nullptr) {
-    auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
-    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
-    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-    bool is_internal_output = graph->IsInternalOutput(node, output_index);
-    if (is_internal_output) {
-      graph->AddInternalOutputTensor(node, output_index, tensor);
-    }
-  }
-  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-  // if in pynative mode,data only copied to host when user want to print data
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
-      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
-    tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
-  } else {
-    tensor->set_sync_status(kNeedSyncDeviceToHost);
-  }
-  tensor->SetNeedWait(true);
-  tensor->SetIsGraphOutput();
-  return tensor;
-}
-
 static bool IsPynativeMode() {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
 }
 
-BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
-                               const std::vector<tensor::TensorPtr> &input_tensors,
-                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
+BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
+                                      const std::vector<tensor::TensorPtr> &input_tensors) {
   auto &node = node_output_pair.first;
-  auto &output_index = node_output_pair.second;
   MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(tensor_to_node);
-  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
   if (HasAbstractMonad(node)) {
     return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
   }
@@ -189,7 +137,8 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
     MS_EXCEPTION_IF_NULL(value_node);
     return value_node->value();
   }
-  bool output_addr_exist = AnfAlgo::OutputAddrExist(node, output_index);
+  MS_EXCEPTION_IF_NULL(graph);
+  bool output_addr_exist = AnfAlgo::OutputAddrExist(node, node_output_pair.second);
   if (!output_addr_exist || (CheckIfNeedCreateOutputTensor(node) && !IsPynativeMode())) {
     if (node->isa<Parameter>()) {
       for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
@@ -205,7 +154,56 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
       }
     }
   }
-  auto tensor = CreateCNodeOutputTensor(node_output_pair, graph);
+  return nullptr;
+}
+
+BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
+                               const std::vector<tensor::TensorPtr> &input_tensors,
+                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
+  auto &node = node_output_pair.first;
+  auto &output_index = node_output_pair.second;
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
+  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
+  if (tensor_from_input != nullptr) {
+    return tensor_from_input;
+  }
+  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
+  if (type_id == kTypeUnknown) {
+    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
+  }
+  tensor::TensorPtr tensor = nullptr;
+  std::vector<int64_t> temp_shape;
+  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
+    temp_shape.emplace_back(1);
+    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
+    tensor->set_sync_status(kNoNeedSync);
+  } else {
+    tensor = graph->GetInternalOutputTensor(node, output_index);
+    if (tensor == nullptr) {
+      auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
+      (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
+      tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+      bool is_internal_output = graph->IsInternalOutput(node, output_index);
+      if (is_internal_output) {
+        graph->AddInternalOutputTensor(node, output_index, tensor);
+      }
+    }
+    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
+    // if in pynative mode,data only copied to host when user want to print data
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
+        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
+      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
+    } else {
+      tensor->set_sync_status(kNeedSyncDeviceToHost);
+    }
+  }
+  tensor->SetNeedWait(true);
+  tensor->SetIsGraphOutput();
   (*tensor_to_node)[tensor] = node_output_pair;
   return tensor;
 }
@@ -1778,43 +1776,6 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
   summary_callback_ = callback;
 }
 
-void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
-  auto node_list = TopoSort(func_graph->get_return());
-  size_t tensor_index = 0;
-  for (const auto &node : node_list) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<CNode>()) {
-      AbstractBasePtrList input_abstracts;
-      size_t input_num = AnfAlgo::GetInputTensorNum(node);
-      for (size_t index = 0; index < input_num; ++index) {
-        auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
-        MS_EXCEPTION_IF_NULL(input_node);
-        auto abstract = input_node->abstract();
-        MS_EXCEPTION_IF_NULL(abstract);
-        input_abstracts.emplace_back(abstract);
-      }
-      auto prim = AnfAlgo::GetCNodePrimitive(node);
-      if (prim->isa<ops::PrimitiveC>()) {
-        auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>();
-        MS_EXCEPTION_IF_NULL(prim_c);
-        auto abstract = prim_c->Infer(input_abstracts);
-        node->set_abstract(abstract);
-      }
-    } else if (node->isa<Parameter>()) {
-      if (tensor_index > inputs.size()) {
-        MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size";
-      }
-      node->set_abstract(inputs[tensor_index++]->ToAbstract());
-    } else {
-      auto value_node = node->cast<ValueNodePtr>();
-      MS_EXCEPTION_IF_NULL(value_node);
-      auto value = value_node->value();
-      MS_EXCEPTION_IF_NULL(value);
-      value_node->set_abstract(value->ToAbstract());
-    }
-  }
-}
-
 void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
   MS_LOG(DEBUG) << "Update summary Start";
   MS_EXCEPTION_IF_NULL(graph);


+0 -5   mindspore/ccsrc/backend/session/session_basic.h

@@ -166,9 +166,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
   virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
   virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
-  virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
-    MS_EXCEPTION(NotExistsError) << "Call an empty function";
-  }
   virtual void BuildGraphImpl(GraphId) {}
   virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                             VectorRef *outputs) = 0;
@@ -182,8 +179,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                                const std::vector<tensor::TensorPtr> &graph_inputs,
                                const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
-  void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
-
   virtual void SetSummaryNodes(KernelGraph *graph);
 
   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,


+12 -12  mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc

@@ -131,7 +131,7 @@ void AscendKernelRuntime::SetContext() {
   }
 }
 
-void AscendKernelRuntime::InnerSetContext() {
+void AscendKernelRuntime::SetCurrentContext() {
   if (rt_context_ == nullptr) {
     return;
   }
@@ -142,7 +142,7 @@ void AscendKernelRuntime::InnerSetContext() {
 }
 
 void AscendKernelRuntime::ClearGraphModelMap() {
-  InnerSetContext();
+  SetCurrentContext();
   for (auto &iter : graph_data_dumper_) {
     MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
     auto &data_dumper = iter.second;
@@ -168,7 +168,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
 void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
                                                     const std::unordered_set<ValueNodePtr> &,
                                                     const std::vector<CNodePtr> &) {
-  InnerSetContext();
+  SetCurrentContext();
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
   if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
     MS_LOG(DEBUG) << "Unload dump info " << graph_id;
@@ -247,7 +247,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
   if (!initialized_) {
     return;
   }
-  InnerSetContext();
+  SetCurrentContext();
   ReportProfilingData();
   // release ge runtime
   ClearGraphModelMap();
@@ -284,7 +284,7 @@ void AscendKernelRuntime::PreInit() {
 
 bool AscendKernelRuntime::Init() {
   if (initialized_) {
-    InnerSetContext();
+    SetCurrentContext();
     return true;
   }
   OpTilingCalculater::GetInstance().Init();
@@ -437,7 +437,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
 
 bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  InnerSetContext();
+  SetCurrentContext();
   if (graph->is_dynamic_shape()) {
     if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
       MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
@@ -498,7 +498,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
 
 bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  InnerSetContext();
+  SetCurrentContext();
   if (graph->is_dynamic_shape()) {
     MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
     return true;
@@ -716,7 +716,7 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
 
 bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
   current_graph_ = graph;
-  InnerSetContext();
+  SetCurrentContext();
   MS_EXCEPTION_IF_NULL(graph);
   if (graph->is_dynamic_shape()) {
     MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
@@ -761,7 +761,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
 }
 
 bool AscendKernelRuntime::SyncStream() {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ == nullptr) {
     MS_LOG(ERROR) << "SyncStream failed. stream_ is nullptr";
     return false;
@@ -779,7 +779,7 @@ bool AscendKernelRuntime::SyncStream() {
 }
 
 bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ == nullptr) {
     MS_LOG(ERROR) << "MemcpyAsync failed. stream_ is nullptr";
     return false;
@@ -803,7 +803,7 @@ void AscendKernelRuntime::CreateContext() {
      MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
    }
  }
-  InnerSetContext();
+  SetCurrentContext();
 }
 
 bool AscendKernelRuntime::InitDevice() {
@@ -850,7 +850,7 @@ bool AscendKernelRuntime::InitDevice() {
 }
 
 bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ != nullptr) {
     auto ret = rtStreamDestroy(stream_);
     if (ret != RT_ERROR_NONE) {


+1 -1   mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h

@@ -76,7 +76,7 @@ class AscendKernelRuntime : public KernelRuntime {
   static bool NeedDestroyHccl();
   static bool DestroyHccl();
   static bool DestroySingleOpHccl();
-  void InnerSetContext();
+  void SetCurrentContext();
 
   void ClearGraphModelMap();
   void ReleaseDeviceRes() override;


+0 -2   mindspore/ccsrc/runtime/device/kernel_runtime.h

@@ -121,10 +121,8 @@ class KernelRuntime {
 
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
-  void ReuseAssignDynamicMemory(session::KernelGraph *graph);
   void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
-  void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
 
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);



