Browse Source

optimize multi graph controldepend input

tags/v1.1.0
kswang 5 years ago
parent
commit
841661b8aa
8 changed files with 39 additions and 29 deletions
  1. +21
    -8
      mindspore/ccsrc/backend/session/executor.cc
  2. +4
    -2
      mindspore/ccsrc/backend/session/executor.h
  3. +3
    -12
      mindspore/ccsrc/backend/session/executor_manager.cc
  4. +2
    -2
      mindspore/ccsrc/backend/session/executor_manager.h
  5. +1
    -1
      mindspore/ccsrc/backend/session/session_basic.cc
  6. +4
    -0
      mindspore/ccsrc/vm/segment_runner.cc
  7. +3
    -3
      mindspore/core/ir/tensor.h
  8. +1
    -1
      mindspore/core/utils/ms_exception.h

+ 21
- 8
mindspore/ccsrc/backend/session/executor.cc View File

@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include "backend/session/executor.h" #include "backend/session/executor.h"
#include "backend/session/executor_manager.h"
#include <algorithm> #include <algorithm>
#include <exception> #include <exception>

#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "utils/scoped_long_running.h" #include "utils/scoped_long_running.h"
@@ -120,14 +119,15 @@ void RunGraphTask::Run() {
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
UpdateOutputTensors(&outputs_, tensor_to_node_); UpdateOutputTensors(&outputs_, tensor_to_node_);
} catch (const std::exception &e) { } catch (const std::exception &e) {
MsException::GetInstance().SetException();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
MsException::Instance().SetException();
} }
graph->OnRunGraphFinished(); graph->OnRunGraphFinished();
for (auto &tensor : input_need_lock_tensors_) { for (auto &tensor : input_need_lock_tensors_) {
tensor->SetNeedWait(false); tensor->SetNeedWait(false);
} }
NotifyOutputTensors(&outputs_); NotifyOutputTensors(&outputs_);
ExecutorManager::Instance().OnRunGraphFinished();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
MS_LOG(INFO) << "End run graph " << graph_id_; MS_LOG(INFO) << "End run graph " << graph_id_;
} }


@@ -187,7 +187,8 @@ void Executor::WorkerLoop() {
try { try {
task->Run(); task->Run();
} catch (const std::exception &e) { } catch (const std::exception &e) {
MsException::GetInstance().SetException();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
MsException::Instance().SetException();
} }
{ {
std::unique_lock<std::mutex> lock(task_mutex_); std::unique_lock<std::mutex> lock(task_mutex_);
@@ -214,6 +215,18 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
return new_ready_tasks; return new_ready_tasks;
} }


void Executor::OnEvent(const ExecutorEvent &event) {
if (event == ExecutorEvent::kRunGraphFinished) {
OnRunGraphFinished();
} else if (event == ExecutorEvent::kException) {
std::unique_lock<std::mutex> lock(task_mutex_);
while (!ready_tasks_.empty()) {
done_tasks_.emplace_back(ready_tasks_.front());
ready_tasks_.pop();
}
}
}

void Executor::OnRunGraphFinished() { void Executor::OnRunGraphFinished() {
auto new_ready_tasks = GetNewReadyTasks(); auto new_ready_tasks = GetNewReadyTasks();
std::unique_lock<std::mutex> lock(task_mutex_); std::unique_lock<std::mutex> lock(task_mutex_);
@@ -249,7 +262,7 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
done_tasks_.clear(); done_tasks_.clear();
task_cond_var_.notify_all(); task_cond_var_.notify_all();
sync_cond_var_.wait(lock); sync_cond_var_.wait(lock);
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
} }


GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
@@ -311,7 +324,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
} }
} }
} }
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
for (auto &tensor : task->input_need_lock_tensors_) { for (auto &tensor : task->input_need_lock_tensors_) {
tensor->SetNeedWait(true); tensor->SetNeedWait(true);
} }
@@ -332,7 +345,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
mindspore::ScopedLongRunning long_running; mindspore::ScopedLongRunning long_running;
std::unique_lock<std::mutex> lock(reenter_mutex_); std::unique_lock<std::mutex> lock(reenter_mutex_);
reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
} }
} }




+ 4
- 2
mindspore/ccsrc/backend/session/executor.h View File

@@ -26,7 +26,6 @@
#include <thread> #include <thread>
#include <utility> #include <utility>
#include <vector> #include <vector>

#include "backend/session/session_basic.h" #include "backend/session/session_basic.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/tensor.h" #include "ir/tensor.h"
@@ -156,6 +155,8 @@ class ExitTask : public Task {
~ExitTask() override = default; ~ExitTask() override = default;
}; };


enum class ExecutorEvent { kClear, kRunGraphFinished, kException };

class Executor { class Executor {
public: public:
Executor(const std::string &device_name, uint32_t device_id); Executor(const std::string &device_name, uint32_t device_id);
@@ -176,9 +177,9 @@ class Executor {
VectorRef *outputs); VectorRef *outputs);
void CleanUselessTensors(const SessionPtr &session, void CleanUselessTensors(const SessionPtr &session,
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
void OnRunGraphFinished();
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name); bool DestroyCommGroup(const std::string &group_name);
void OnEvent(const ExecutorEvent &event);


private: private:
void SyncRunTask(const std::shared_ptr<Task> &task); void SyncRunTask(const std::shared_ptr<Task> &task);
@@ -188,6 +189,7 @@ class Executor {
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
void CheckException(); void CheckException();
void OnWorkerExit(); void OnWorkerExit();
void OnRunGraphFinished();


uint32_t device_id_; uint32_t device_id_;
std::string device_name_; std::string device_name_;


+ 3
- 12
mindspore/ccsrc/backend/session/executor_manager.cc View File

@@ -28,26 +28,17 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device
return executor; return executor;
} }


void ExecutorManager::OnRunGraphFinished() {
void ExecutorManager::OnEvent(const ExecutorEvent &event) {
for (auto &item : executors_) { for (auto &item : executors_) {
auto &executor = item.second; auto &executor = item.second;
if (executor != nullptr) { if (executor != nullptr) {
executor->OnRunGraphFinished();
}
}
}

void ExecutorManager::JoinExecutorWorkers() {
for (auto &item : executors_) {
auto &executor = item.second;
if (executor != nullptr) {
executor->WorkerJoin();
executor->OnEvent(event);
} }
} }
} }


void ExecutorManager::Clear() { void ExecutorManager::Clear() {
JoinExecutorWorkers();
OnEvent(ExecutorEvent::kClear);
executors_.clear(); executors_.clear();
} }
} // namespace session } // namespace session


+ 2
- 2
mindspore/ccsrc/backend/session/executor_manager.h View File

@@ -30,14 +30,14 @@ class ExecutorManager {
return instance; return instance;
} }
std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id); std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id);
void OnRunGraphFinished();
void OnEvent(const ExecutorEvent &event);
void Clear(); void Clear();
private: private:
ExecutorManager() = default; ExecutorManager() = default;
~ExecutorManager() = default; ~ExecutorManager() = default;
DISABLE_COPY_AND_ASSIGN(ExecutorManager) DISABLE_COPY_AND_ASSIGN(ExecutorManager)
void JoinExecutorWorkers();
std::map<std::string, std::shared_ptr<Executor>> executors_; std::map<std::string, std::shared_ptr<Executor>> executors_;
}; };
} // namespace session } // namespace session


+ 1
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -637,7 +637,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
(*other_graph_cnode)[anf] = new_parameter; (*other_graph_cnode)[anf] = new_parameter;
} }
continue; continue;
} else if (optimize_control_depend) {
} else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx))));
} else { } else {
// the input node is a cnode from other graph // the input node is a cnode from other graph


+ 4
- 0
mindspore/ccsrc/vm/segment_runner.cc View File

@@ -87,6 +87,10 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node; eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) { } else if (eqv.find(node) == eqv.end()) {
if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
eqv[node] = NewValueNode(MakeValue(0));
return eqv[node];
}
bool ignore_make_tuple = false; bool ignore_make_tuple = false;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
ignore_make_tuple = true; ignore_make_tuple = true;


+ 3
- 3
mindspore/core/ir/tensor.h View File

@@ -87,10 +87,10 @@ class WaitEvent : public ExceptionListener {
if (!need_wait_) { if (!need_wait_) {
return; return;
} }
MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this));
MsException::Instance().AddExceptionListener(const_cast<WaitEvent *>(this));
cond_var_.wait(lock, [this] { return !need_wait_; }); cond_var_.wait(lock, [this] { return !need_wait_; });
MsException::GetInstance().CheckException();
MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
MsException::Instance().CheckException();
MsException::Instance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
} }


void set_need_wait(bool need_wait) { void set_need_wait(bool need_wait) {


+ 1
- 1
mindspore/core/utils/ms_exception.h View File

@@ -27,7 +27,7 @@ class ExceptionListener {


class MsException { class MsException {
public: public:
static MsException &GetInstance() {
static MsException &Instance() {
static MsException instance; static MsException instance;
return instance; return instance;
} }


Loading…
Cancel
Save