Browse Source

optimize multi graph controldepend input

tags/v1.1.0
kswang 5 years ago
parent
commit
841661b8aa
8 changed files with 39 additions and 29 deletions
  1. +21
    -8
      mindspore/ccsrc/backend/session/executor.cc
  2. +4
    -2
      mindspore/ccsrc/backend/session/executor.h
  3. +3
    -12
      mindspore/ccsrc/backend/session/executor_manager.cc
  4. +2
    -2
      mindspore/ccsrc/backend/session/executor_manager.h
  5. +1
    -1
      mindspore/ccsrc/backend/session/session_basic.cc
  6. +4
    -0
      mindspore/ccsrc/vm/segment_runner.cc
  7. +3
    -3
      mindspore/core/ir/tensor.h
  8. +1
    -1
      mindspore/core/utils/ms_exception.h

+ 21
- 8
mindspore/ccsrc/backend/session/executor.cc View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include "backend/session/executor.h"
#include "backend/session/executor_manager.h"
#include <algorithm>
#include <exception>

#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
@@ -120,14 +119,15 @@ void RunGraphTask::Run() {
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
UpdateOutputTensors(&outputs_, tensor_to_node_);
} catch (const std::exception &e) {
MsException::GetInstance().SetException();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
MsException::Instance().SetException();
}
graph->OnRunGraphFinished();
for (auto &tensor : input_need_lock_tensors_) {
tensor->SetNeedWait(false);
}
NotifyOutputTensors(&outputs_);
ExecutorManager::Instance().OnRunGraphFinished();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
MS_LOG(INFO) << "End run graph " << graph_id_;
}

@@ -187,7 +187,8 @@ void Executor::WorkerLoop() {
try {
task->Run();
} catch (const std::exception &e) {
MsException::GetInstance().SetException();
ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
MsException::Instance().SetException();
}
{
std::unique_lock<std::mutex> lock(task_mutex_);
@@ -214,6 +215,18 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
return new_ready_tasks;
}

void Executor::OnEvent(const ExecutorEvent &event) {
if (event == ExecutorEvent::kRunGraphFinished) {
OnRunGraphFinished();
} else if (event == ExecutorEvent::kException) {
std::unique_lock<std::mutex> lock(task_mutex_);
while (!ready_tasks_.empty()) {
done_tasks_.emplace_back(ready_tasks_.front());
ready_tasks_.pop();
}
}
}

void Executor::OnRunGraphFinished() {
auto new_ready_tasks = GetNewReadyTasks();
std::unique_lock<std::mutex> lock(task_mutex_);
@@ -249,7 +262,7 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
done_tasks_.clear();
task_cond_var_.notify_all();
sync_cond_var_.wait(lock);
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
}

GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
@@ -311,7 +324,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
}
}
}
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
for (auto &tensor : task->input_need_lock_tensors_) {
tensor->SetNeedWait(true);
}
@@ -332,7 +345,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
mindspore::ScopedLongRunning long_running;
std::unique_lock<std::mutex> lock(reenter_mutex_);
reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
MsException::GetInstance().CheckException();
MsException::Instance().CheckException();
}
}



+ 4
- 2
mindspore/ccsrc/backend/session/executor.h View File

@@ -26,7 +26,6 @@
#include <thread>
#include <utility>
#include <vector>

#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "ir/tensor.h"
@@ -156,6 +155,8 @@ class ExitTask : public Task {
~ExitTask() override = default;
};

enum class ExecutorEvent { kClear, kRunGraphFinished, kException };

class Executor {
public:
Executor(const std::string &device_name, uint32_t device_id);
@@ -176,9 +177,9 @@ class Executor {
VectorRef *outputs);
void CleanUselessTensors(const SessionPtr &session,
const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors);
void OnRunGraphFinished();
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name);
void OnEvent(const ExecutorEvent &event);

private:
void SyncRunTask(const std::shared_ptr<Task> &task);
@@ -188,6 +189,7 @@ class Executor {
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
void CheckException();
void OnWorkerExit();
void OnRunGraphFinished();

uint32_t device_id_;
std::string device_name_;


+ 3
- 12
mindspore/ccsrc/backend/session/executor_manager.cc View File

@@ -28,26 +28,17 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device
return executor;
}

void ExecutorManager::OnRunGraphFinished() {
void ExecutorManager::OnEvent(const ExecutorEvent &event) {
for (auto &item : executors_) {
auto &executor = item.second;
if (executor != nullptr) {
executor->OnRunGraphFinished();
}
}
}

void ExecutorManager::JoinExecutorWorkers() {
for (auto &item : executors_) {
auto &executor = item.second;
if (executor != nullptr) {
executor->WorkerJoin();
executor->OnEvent(event);
}
}
}

void ExecutorManager::Clear() {
JoinExecutorWorkers();
OnEvent(ExecutorEvent::kClear);
executors_.clear();
}
} // namespace session


+ 2
- 2
mindspore/ccsrc/backend/session/executor_manager.h View File

@@ -30,14 +30,14 @@ class ExecutorManager {
return instance;
}
std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id);
void OnRunGraphFinished();
void OnEvent(const ExecutorEvent &event);
void Clear();
private:
ExecutorManager() = default;
~ExecutorManager() = default;
DISABLE_COPY_AND_ASSIGN(ExecutorManager)
void JoinExecutorWorkers();
std::map<std::string, std::shared_ptr<Executor>> executors_;
};
} // namespace session


+ 1
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -637,7 +637,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
(*other_graph_cnode)[anf] = new_parameter;
}
continue;
} else if (optimize_control_depend) {
} else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx))));
} else {
// the input node is a cnode from other graph


+ 4
- 0
mindspore/ccsrc/vm/segment_runner.cc View File

@@ -87,6 +87,10 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) {
if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
eqv[node] = NewValueNode(MakeValue(0));
return eqv[node];
}
bool ignore_make_tuple = false;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
ignore_make_tuple = true;


+ 3
- 3
mindspore/core/ir/tensor.h View File

@@ -87,10 +87,10 @@ class WaitEvent : public ExceptionListener {
if (!need_wait_) {
return;
}
MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this));
MsException::Instance().AddExceptionListener(const_cast<WaitEvent *>(this));
cond_var_.wait(lock, [this] { return !need_wait_; });
MsException::GetInstance().CheckException();
MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
MsException::Instance().CheckException();
MsException::Instance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
}

void set_need_wait(bool need_wait) {


+ 1
- 1
mindspore/core/utils/ms_exception.h View File

@@ -27,7 +27,7 @@ class ExceptionListener {

class MsException {
public:
static MsException &GetInstance() {
static MsException &Instance() {
static MsException instance;
return instance;
}


Loading…
Cancel
Save