Browse Source

add graph partition

tags/v1.1.0
kswang 5 years ago
parent
commit
b88e722e12
19 changed files with 605 additions and 505 deletions
  1. +5
    -3
      mindspore/ccsrc/backend/session/executor.cc
  2. +2
    -2
      mindspore/ccsrc/backend/session/executor.h
  3. +2
    -2
      mindspore/ccsrc/backend/session/session_basic.cc
  4. +3
    -3
      mindspore/ccsrc/backend/session/session_basic.h
  5. +1
    -1
      mindspore/ccsrc/pipeline/jit/action.cc
  6. +2
    -1
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  7. +13
    -16
      mindspore/ccsrc/vm/backend.cc
  8. +2
    -1
      mindspore/ccsrc/vm/backend.h
  9. +447
    -0
      mindspore/ccsrc/vm/graph_partition.cc
  10. +48
    -0
      mindspore/ccsrc/vm/graph_partition.h
  11. +5
    -11
      mindspore/ccsrc/vm/segment_runner.cc
  12. +3
    -9
      mindspore/ccsrc/vm/segment_runner.h
  13. +20
    -410
      mindspore/ccsrc/vm/transform.cc
  14. +6
    -7
      mindspore/ccsrc/vm/transform.h
  15. +16
    -0
      mindspore/core/ir/anf.cc
  16. +7
    -0
      mindspore/core/ir/anf.h
  17. +13
    -0
      mindspore/core/ir/func_graph.cc
  18. +1
    -0
      mindspore/core/ir/func_graph.h
  19. +9
    -39
      tests/ut/cpp/vm/segment_runner_test.cc

+ 5
- 3
mindspore/ccsrc/backend/session/executor.cc View File

@@ -89,7 +89,8 @@ bool TensorInVector(const VectorRef *outputs) {

void CompileNodesTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_);
MS_EXCEPTION_IF_NULL(segment_);
graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}

void CompileGraphTask::Run() {
@@ -226,10 +227,11 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
MsException::GetInstance().CheckException();
}

GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
const AnfNodePtrList &outputs) {
auto task = std::make_shared<CompileNodesTask>();
task->session_ = session;
task->nodes_ = lst;
task->segment_ = segment;
task->output_nodes_ = outputs;
SyncRunTask(task);
return task->graph_id_;


+ 2
- 2
mindspore/ccsrc/backend/session/executor.h View File

@@ -63,7 +63,7 @@ class CompileNodesTask : public Task {
CompileNodesTask() { type_ = kCompileNodes; }
~CompileNodesTask() override = default;
void Run() override;
AnfNodePtrList nodes_;
GraphSegmentPtr segment_;
AnfNodePtrList output_nodes_;
GraphId graph_id_{0};
};
@@ -151,7 +151,7 @@ class Executor {
~Executor();
void WorkerLoop();
void WorkerJoin();
GraphId CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
void BuildGraph(const SessionPtr &session, GraphId graphId);
void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,


+ 2
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1388,9 +1388,9 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve
return nullptr;
}

GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->CompileGraph(shared_from_this(), lst, outputs);
return executor_->CompileGraph(shared_from_this(), segment, outputs);
}

GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {


+ 3
- 3
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -68,7 +68,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {

virtual ~SessionBasic() { summary_callback_ = nullptr; }

GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
void BuildGraph(GraphId graphId);
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
@@ -102,6 +102,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {}
std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs);
// Get graph by graph id, if not exist return null ptr
KernelGraphPtr GetGraph(GraphId graph_id) const;
#ifdef ENABLE_DEBUGGER
// set debugger
void SetDebugger() {
@@ -147,8 +149,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
// Get graph by graph id ,if not exist return null ptr
KernelGraphPtr GetGraph(GraphId graph_id) const;

virtual void SetSummaryNodes(KernelGraph *graph);



+ 1
- 1
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -354,7 +354,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
auto context_ptr = MsContext::GetInstance();
std::string backend = MsContext::GetInstance()->backend_policy();
MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) {
if (func_graph->ContainMultiTarget()) {
bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);


+ 2
- 1
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -923,7 +923,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
MS_EXCEPTION_IF_NULL(convert_fn);
// Convert CNodeList to LinConvertResult.
ConfigManager::GetInstance().set_iter_num(1);
auto runner = convert_fn({app_init}, "");
auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
auto runner = convert_fn(segment, "");
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
backend->Link(runner.graph_id);
}


+ 13
- 16
mindspore/ccsrc/vm/backend.cc View File

@@ -34,30 +34,34 @@ namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
Backend::Backend(const std::string &name) : name_(name) {
MS_LOG(DEBUG) << "select backend:" << name;
convert_fn_ = MsVmConvert;
is_multi_graph_sink_ = false;
}

LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
MS_LOG(DEBUG) << "MsConvert";
MS_EXCEPTION_IF_NULL(segment);
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
auto cached = g_ConvertCache.find(lst);
auto cached = g_ConvertCache.find(segment);
if (cached != g_ConvertCache.end()) {
return cached->second;
}

LinConvertResult result;

FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;

std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
result.inputs = inputs;
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
GraphId graph_id = kInvalidGraphId;
if (target != target_device_ && !target.empty()) {
CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(lst, outputs);
graph_id = other_sess_->CompileGraph(segment, outputs);
} else {
graph_id = target_sess_->CompileGraph(lst, outputs);
graph_id = target_sess_->CompileGraph(segment, outputs);
}

if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
@@ -79,7 +83,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
result.graph_id = graph_id;

graph_id_map_[graph_id] = result;
(void)g_ConvertCache.emplace(lst, result);
(void)g_ConvertCache.emplace(segment, result);
return result;
}

@@ -154,12 +158,6 @@ void MsBackend::Link(GraphId graph_id) {
target_sess_->BuildGraph(graph_id);
}

Backend::Backend(const std::string &name) : name_(name) {
MS_LOG(DEBUG) << "select backend:" << name;
convert_fn_ = backends[name_];
is_multi_graph_sink_ = false;
}

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
target_sess_ = session::SessionFactory::Get().Create(target);
@@ -194,6 +192,5 @@ VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return
#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
#endif

} // namespace compile
} // namespace mindspore

+ 2
- 1
mindspore/ccsrc/vm/backend.h View File

@@ -25,6 +25,7 @@
#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/graph_partition.h"
#include "vm/vm.h"
#include "backend/session/session_basic.h"

@@ -63,7 +64,7 @@ class MsBackend : public Backend {
MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
~MsBackend() override = default;

LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = "");
LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");

VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args);


+ 447
- 0
mindspore/ccsrc/vm/graph_partition.cc View File

@@ -0,0 +1,447 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "vm/graph_partition.h"
#include <string>
#include <functional>
#include <utility>
#include <map>
#include <queue>
#include <stack>
#include <set>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
namespace mindspore {
const char kMsConvert[] = "ms";
const char kMsVm[] = "vm";
const char kGeVm[] = "ge";
namespace compile {
namespace {
bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(behind_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
if (prior_node->isa<Parameter>()) {
for (auto &user : node_users[prior_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(prior_node);
} else {
return false;
}
if (behind_node->isa<Parameter>()) {
for (auto &user : node_users[behind_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(behind_node);
} else {
return false;
}
return true;
}
void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
std::map<AnfNodePtr, size_t> *nodes_ref) {
MS_EXCEPTION_IF_NULL(node);
auto input_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto prior_node = input_cnode->input(kControlDependPriorIndex);
auto depend_node = input_cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim_ptr);
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
int64_t depend_mode = 0;
if (mode_ptr != nullptr) {
depend_mode = GetValue<int64_t>(mode_ptr);
}
if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
return;
}
std::vector<AnfNodePtr> prior_nodes;
std::vector<AnfNodePtr> behind_nodes;
if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
return;
}
for (auto &first_node : prior_nodes) {
for (auto &second_node : behind_nodes) {
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
auto iter = control_edges->find(second_node);
if (iter == control_edges->end()) {
(void)control_edges->insert(
std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
} else {
iter->second.emplace_back(first_node);
}
auto ref_iter = nodes_ref->find(first_node);
if (ref_iter != nodes_ref->end()) {
ref_iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
}
}
}
}
void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
std::queue<AnfNodePtr> queue;
queue.push(graph->get_return());
std::set<AnfNodePtr> visited;
while (!queue.empty()) {
auto &node = queue.front();
queue.pop();
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) {
if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
AddControlEdge(graph, input, control_edges, nodes_ref);
}
auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end()) {
iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
}
if (visited.find(input) != visited.end()) {
continue;
}
visited.insert(input);
queue.push(input);
}
}
}
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "Invalid get item node";
}
auto &parent = inputs[1];
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.insert(
std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
}
continue;
}
}
result.emplace_back(node);
node_positions[node] = result.size();
}
size_t insert_num = 0;
for (auto &item : insert_positions) {
size_t position = item.first + insert_num;
(void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
insert_num += item.second.size();
}
return result;
}
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> to_visit;
std::stack<AnfNodePtr> next_to_visit;
std::map<AnfNodePtr, size_t> nodes_ref;
std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
CalcNodeRefCount(graph, &nodes_ref, &control_edges);
std::string handle_target = default_target;
std::string next_target = "";
to_visit.push(graph->get_return());
while (!to_visit.empty() || !next_to_visit.empty()) {
if (to_visit.empty()) {
to_visit.swap(next_to_visit);
handle_target = next_target;
}
auto &node = to_visit.top();
MS_EXCEPTION_IF_NULL(node);
to_visit.pop();
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
std::reverse(node_inputs.begin(), node_inputs.end());
auto ctrl_inputs = control_edges.find(node);
if (ctrl_inputs != control_edges.end()) {
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
}
for (auto &input : node_inputs) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
to_visit.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == handle_target) {
to_visit.push(input);
} else if (next_to_visit.empty() || input_target == next_target) {
next_to_visit.push(input);
next_target = input_target;
} else {
MS_LOG(EXCEPTION) << "Only support two different target";
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> handle_nodes;
std::stack<AnfNodePtr> next_handle_nodes;
std::stack<AnfNodePtr> wait_handle_nodes;
std::map<AnfNodePtr, size_t> nodes_ref;
std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
CalcNodeRefCount(graph, &nodes_ref, &control_edges);
std::string handle_target = default_target;
std::string next_target = "";
handle_nodes.push(graph->get_return());
while (!handle_nodes.empty() || !next_handle_nodes.empty() || !wait_handle_nodes.empty()) {
if (handle_nodes.empty()) {
handle_nodes.swap(next_handle_nodes);
handle_target.swap(next_target);
next_handle_nodes.swap(wait_handle_nodes);
}
auto &node = handle_nodes.top();
MS_EXCEPTION_IF_NULL(node);
handle_nodes.pop();
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
std::reverse(node_inputs.begin(), node_inputs.end());
auto ctrl_inputs = control_edges.find(node);
if (ctrl_inputs != control_edges.end()) {
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
}
std::vector<AnfNodePtr> same_target_ready_inputs;
std::vector<AnfNodePtr> diff_target_ready_inputs;
for (auto &input : node_inputs) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
handle_nodes.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == handle_target) {
same_target_ready_inputs.emplace_back(input);
} else {
if (next_target.empty()) {
next_target = input_target;
}
if (input_target != next_target) {
MS_LOG(EXCEPTION) << "Only support two different target";
}
diff_target_ready_inputs.emplace_back(input);
}
}
if (diff_target_ready_inputs.size() == 0) {
for (auto &input : same_target_ready_inputs) {
handle_nodes.push(input);
}
} else {
for (auto &input : same_target_ready_inputs) {
wait_handle_nodes.push(input);
}
for (auto &input : diff_target_ready_inputs) {
next_handle_nodes.push(input);
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
bool IsSubGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = inputs[0];
if (!IsValueNode<Primitive>(fn)) {
return false;
}
auto node_prim = GetValueNode<PrimitivePtr>(fn);
if (node_prim->name() == prim::kPrimPartial->name()) {
return true;
}
} else if (IsValueNode<FuncGraph>(node)) {
return true;
}
return false;
}
} // namespace
GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name)
: cut_list_(cut_list), backend_name_(backend_name) {}
bool GraphPartition::IsCut(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = inputs[0];
if (IsValueNode<FuncGraph>(fn)) {
auto fg = GetValueNode<FuncGraphPtr>(fn);
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
return false;
}
}
if (!IsValueNode<Primitive>(fn)) {
return true;
}
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
for (auto &prim : cut_list_) {
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == node_prim->name()) {
if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
}
if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
if (inputs.size() < 2) {
return false;
}
auto ret = IsSubGraph(inputs[1]);
return ret;
}
return true;
}
}
#ifdef ENABLE_GE
if (backend_name_ == kGeVm) {
auto name = GetCNodeFuncName(cnode);
auto adpt = transform::DfGraphConvertor::FindAdapter(name);
if (adpt == nullptr) {
return true;
}
}
#endif
}
return false;
}
std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
bool contain_multi_target = ContainMultiTarget(nodes);
if (contain_multi_target) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (graph != nullptr) {
nodes = SplitSort(graph, default_target);
} else {
nodes = ParallelSplitSort(graph, default_target);
}
nodes = OptimizeGetItemOrder(nodes);
}
std::vector<GraphSegmentPtr> segments;
std::vector<AnfNodePtr> segment_nodes;
std::string last_target;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (segment_nodes.size() != 0) {
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
segments.emplace_back(segment);
segment_nodes.clear();
}
segment_nodes.emplace_back(node);
auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
segments.push_back(segment);
segment_nodes.clear();
} else if (node->isa<CNode>()) {
if (contain_multi_target) {
std::string cur_target = GetCNodeTarget(node);
if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) {
auto segment = std::make_shared<GraphSegment>(segment_nodes, false);
segments.emplace_back(segment);
segment_nodes.clear();
}
last_target = cur_target;
}
segment_nodes.emplace_back(node);
}
}
MS_LOG(DEBUG) << "Segment size:" << segments.size();
return segments;
}
} // namespace compile
} // namespace mindspore

+ 48
- 0
mindspore/ccsrc/vm/graph_partition.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_
#define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_
#include <vector>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/graph_utils.h"
#include "base/base_ref.h"
namespace mindspore {
extern const char kMsVm[];
extern const char kGeVm[];
extern const char kMsConvert[];
namespace compile {
class GraphPartition {
public:
explicit GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name);
~GraphPartition() = default;
std::vector<GraphSegmentPtr> Partition(const FuncGraphPtr &func_graph);
private:
bool IsCut(const AnfNodePtr &node);
std::vector<PrimitivePtr> cut_list_;
std::string backend_name_;
};
using GraphPartitionPtr = std::shared_ptr<GraphPartition>;
} // namespace compile
} // namespace mindspore
#endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_

+ 5
- 11
mindspore/ccsrc/vm/segment_runner.cc View File

@@ -34,10 +34,6 @@
#include "frontend/operator/ops.h"

namespace mindspore {
const char kMsConvert[] = "ms";
const char kMsVm[] = "vm";
const char kGeVm[] = "ge";

namespace compile {
// cached conversion
ConvertCache g_ConvertCache;
@@ -194,8 +190,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
// This implementation will convert the nodes into a subgraph
// that will run using the MsVM.
template <typename T>
LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
auto cached = g_ConvertCache.find(lst);
LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) {
MS_EXCEPTION_IF_NULL(segment);
auto cached = g_ConvertCache.find(segment);
if (cached != g_ConvertCache.end()) {
return cached->second;
}
@@ -206,7 +203,7 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
AnfNodePtrList inputs;
AnfNodePtrList outputs;

std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);

// Clone in case g contains subgraphs that have a different manager
fg = BasicClone(fg);
@@ -219,18 +216,15 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) {
result.outputs = outputs;
result.graph_id = UINT32_MAX;

(void)g_ConvertCache.emplace(lst, result);
(void)g_ConvertCache.emplace(segment, result);
return result;
}

LinkFuncType MsVmConvert = Convert<VM>;

std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}};

std::set<std::string> backend_list = {
kMsConvert,
kMsVm,
};

} // namespace compile
} // namespace mindspore

+ 3
- 9
mindspore/ccsrc/vm/segment_runner.h View File

@@ -27,14 +27,10 @@

#include "ir/anf.h"
#include "vm/vmimpl.h"
#include "vm/graph_partition.h"

namespace mindspore {
extern const char kMsVm[];
extern const char kGeVm[];
extern const char kMsConvert[];

namespace compile {

struct LinConvertResult {
RunFuncPtr run;
RunFuncPtr simu_run;
@@ -43,11 +39,9 @@ struct LinConvertResult {
uint32_t graph_id;
};

using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>;
using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>;
using LinkFuncType = std::function<LinConvertResult(const GraphSegmentPtr &, const std::string &)>;
using ConvertCache = std::unordered_map<GraphSegmentPtr, LinConvertResult>;
extern LinkFuncType MsVmConvert;
extern LinkFuncType GeVmConvert;
extern std::unordered_map<std::string, LinkFuncType> backends;
extern ConvertCache g_ConvertCache;
extern std::set<std::string> backend_list;



+ 20
- 410
mindspore/ccsrc/vm/transform.cc View File

@@ -21,8 +21,6 @@
#include <algorithm>
#include <map>
#include <queue>
#include <stack>
#include <set>
#include <string>
#include <vector>

@@ -52,386 +50,13 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
return ms_nonlinear_ops;
}

namespace {
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &node : nodes) {
if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
if (last_target != cur_target) {
return true;
}
last_target = cur_target;
}
}
return false;
}

bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(behind_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
if (prior_node->isa<Parameter>()) {
for (auto &user : node_users[prior_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(prior_node);
} else {
return false;
}
if (behind_node->isa<Parameter>()) {
for (auto &user : node_users[behind_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(behind_node);
} else {
return false;
}
return true;
}

void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
std::map<AnfNodePtr, size_t> *nodes_ref) {
MS_EXCEPTION_IF_NULL(node);
auto input_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto prior_node = input_cnode->input(kControlDependPriorIndex);
auto depend_node = input_cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim_ptr);
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
int64_t depend_mode = 0;
if (mode_ptr != nullptr) {
depend_mode = GetValue<int64_t>(mode_ptr);
}
if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
return;
}
std::vector<AnfNodePtr> prior_nodes;
std::vector<AnfNodePtr> behind_nodes;
if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
return;
}
for (auto &first_node : prior_nodes) {
for (auto &second_node : behind_nodes) {
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
auto iter = control_edges->find(second_node);
if (iter == control_edges->end()) {
(void)control_edges->insert(
std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
} else {
iter->second.emplace_back(first_node);
}
auto ref_iter = nodes_ref->find(first_node);
if (ref_iter != nodes_ref->end()) {
ref_iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
}
}
}
}

void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
std::queue<AnfNodePtr> queue;
queue.push(graph->get_return());
std::set<AnfNodePtr> visited;
while (!queue.empty()) {
auto &node = queue.front();
queue.pop();
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) {
if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
AddControlEdge(graph, input, control_edges, nodes_ref);
}
auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end()) {
iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1));
}
if (visited.find(input) != visited.end()) {
continue;
}
visited.insert(input);
queue.push(input);
}
}
}

std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "Invalid get item node";
}
auto &parent = inputs[1];
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.insert(
std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
}
continue;
}
}
result.emplace_back(node);
node_positions[node] = result.size();
}

size_t insert_num = 0;
for (auto &item : insert_positions) {
size_t position = item.first + insert_num;
(void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
insert_num += item.second.size();
}
return result;
}

std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> to_visit;
std::stack<AnfNodePtr> next_to_visit;
std::map<AnfNodePtr, size_t> nodes_ref;
std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
CalcNodeRefCount(graph, &nodes_ref, &control_edges);
std::string handle_target = default_target;
std::string next_target = "";
to_visit.push(graph->get_return());
while (!to_visit.empty() || !next_to_visit.empty()) {
if (to_visit.empty()) {
to_visit.swap(next_to_visit);
handle_target = next_target;
}
auto &node = to_visit.top();
MS_EXCEPTION_IF_NULL(node);
to_visit.pop();
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
std::reverse(node_inputs.begin(), node_inputs.end());
auto ctrl_inputs = control_edges.find(node);
if (ctrl_inputs != control_edges.end()) {
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
}
for (auto &input : node_inputs) {
auto iter = nodes_ref.find(input);
if (iter != nodes_ref.end()) {
iter->second--;
if (iter->second != 0) {
continue;
}
}
if (!input->isa<CNode>()) {
to_visit.push(input);
continue;
}
std::string input_target = GetCNodeTarget(input);
if (input_target == handle_target) {
to_visit.push(input);
} else if (next_to_visit.empty() || input_target == next_target) {
next_to_visit.push(input);
next_target = input_target;
} else {
MS_LOG(EXCEPTION) << "only support two different target";
}
}
}
std::reverse(result.begin(), result.end());
return result;
}

bool IsSubGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}

AnfNodePtr fn = inputs[0];
if (!IsValueNode<Primitive>(fn)) {
return false;
}
auto node_prim = GetValueNode<PrimitivePtr>(fn);
if (node_prim->name() == prim::kPrimPartial->name()) {
return true;
}
} else if (IsValueNode<FuncGraph>(node)) {
return true;
}
return false;
}
} // namespace

CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
: backend_(backend), cut_list_(cut_list) {
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
MS_EXCEPTION_IF_NULL(backend_);
lin_convert_ = backend_->convert_fn();
if (lin_convert_ == nullptr) {
MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
}

is_gevm_convert_ = false;
if (backend->name() == kGeVm) {
MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true";
is_gevm_convert_ = true;
}
}

bool CompileGraph::IsCut(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}

AnfNodePtr fn = inputs[0];
if (IsValueNode<FuncGraph>(fn)) {
auto fg = GetValueNode<FuncGraphPtr>(fn);
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
return false;
}
}

if (!IsValueNode<Primitive>(fn)) {
return true;
}

PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn);
for (auto &prim : cut_list_) {
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == node_prim->name()) {
if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
}

if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
if (inputs.size() < 2) {
return false;
}
auto ret = IsSubGraph(inputs[1]);
return ret;
}

return true;
}
}

#ifdef ENABLE_GE
if (is_gevm_convert_) {
auto name = GetCNodeFuncName(cnode);
auto adpt = transform::DfGraphConvertor::FindAdapter(name);
if (adpt == nullptr) {
return true;
}
}
#endif
}

return false;
}

VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = OptimizeGetItemOrder(input_nodes);
VectorRef splits;
VectorRef split;
std::string last_target;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (split.size() != 0) {
splits.push_back(split);
}
splits.push_back(node);
split.clear();
} else if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
if (cur_target != last_target && !last_target.empty() && split.size() != 0) {
splits.push_back(split);
split.clear();
}
last_target = cur_target;
split.push_back(node);
}
}
return splits;
}

VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();

if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
nodes = SplitSort(graph, default_target);
return SplitNodesWithTarget(nodes, graph);
}

VectorRef splits;
VectorRef split;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (split.size() != 0) {
splits.push_back(split);
}
splits.push_back(node);
split.clear();
} else if (node->isa<CNode>()) {
split.push_back(node);
}
}
return splits;
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
}

// Push the value node on the stack.
@@ -512,12 +137,12 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
}
}

int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list,
const std::string &target) {
int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
MS_EXCEPTION_IF_NULL(segment);
MS_LOG(DEBUG) << "LinConvert start";
LinConvertResult result;

result = lin_convert_(node_list, target);
result = lin_convert_(segment, target);

if (result.run == nullptr) {
MS_LOG(ERROR) << "LinConvert failed";
@@ -583,25 +208,23 @@ int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &n
return RET_SUCCESS;
}

bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
bool CompileGraph::Compile(const FuncGraphPtr &graph) {
MS_LOG(DEBUG) << "Start split graph";
MS_EXCEPTION_IF_NULL(graph);
VectorRef splits = SplitNodes(graph);
MS_EXCEPTION_IF_NULL(graph_partition_);
auto segments = graph_partition_->Partition(graph);

MS_LOG(DEBUG) << "Split nodes size:" << splits.size();
for (auto &split : splits) {
MS_LOG(DEBUG) << "Split nodes size:" << segments.size();
for (auto &segment : segments) {
MS_EXCEPTION_IF_NULL(segment);
int64_t ret = RET_SUCCESS;
if (utils::isa<VectorRef>(split)) {
if (!segment->is_cut_) {
MS_LOG(DEBUG) << "Start a extern LinConvert";
std::vector<AnfNodePtr> args;
auto vec_ref = utils::cast<VectorRef>(split);
(void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args),
[](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); });
if (args.size() > 0) {
std::string cur_target = GetCNodeTarget(args[0]);
ret = LinConvert(graph, args, cur_target);
if (segment->nodes_.size() > 0) {
std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
ret = LinConvert(graph, segment, cur_target);
} else {
ret = LinConvert(graph, args);
ret = LinConvert(graph, segment);
}
MS_LOG(DEBUG) << "End a extern LinConvert";
if (ret == RET_FAILED) {
@@ -612,10 +235,11 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) {
}
} else {
MS_LOG(DEBUG) << "Start a cut node";
if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) {
auto &cut_node = segment->nodes_[0];
if (!cut_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
}
CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>();
CNodePtr node = cut_node->cast<CNodePtr>();
ret = InterpretNode(graph, node);
MS_LOG(DEBUG) << "End a cut node";
if (ret == RET_BREAK) {
@@ -635,7 +259,7 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
int64_t param_height = height_;
MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);

if (!SplitGraph(graph)) {
if (!Compile(graph)) {
return inst_;
}

@@ -897,20 +521,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
return rt;
}

bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto graph_manager = graph->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
FuncGraphSet graphs = graph_manager->func_graphs();
for (auto &g : graphs) {
auto nodes = TopoSort(g->get_return());
if (ContainMultiTarget(nodes)) {
return true;
}
}
return false;
}

BackendPtr CreateBackend() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);


+ 6
- 7
mindspore/ccsrc/vm/transform.h View File

@@ -31,6 +31,7 @@
#include "frontend/operator/ops.h"
#include "vm/segment_runner.h"
#include "vm/backend.h"
#include "vm/graph_partition.h"

// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
@@ -59,7 +60,6 @@ class CompileGraph {
void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
void Ret(int64_t nargs);
int64_t Ref(const AnfNodePtr &node);
VectorRef SplitNodes(const FuncGraphPtr &func_graph);

void set_height(int64_t h) {
height_ = h;
@@ -76,10 +76,9 @@ class CompileGraph {
}

private:
VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph);
void PushParameters(const FuncGraphPtr &func_graph);
bool SplitGraph(const FuncGraphPtr &func_graph);
int64_t LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");
bool Compile(const FuncGraphPtr &func_graph);
int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
void AddPadStack(int64_t param_height);
@@ -97,11 +96,12 @@ class CompileGraph {
void AddInst(const Instruction &inst, const VectorRef &args);

BackendPtr backend_;
GraphPartitionPtr graph_partition_;
LinkFuncType lin_convert_;
bool is_gevm_convert_;
int64_t height_{0};
int64_t max_height_{0};
std::vector<PrimitivePtr> cut_list_;
std::unordered_map<AnfNodePtr, int64_t> slots_;
InstSet inst_;
};
@@ -123,7 +123,6 @@ class CompileGraphs {
void Compile(const FuncGraphPtr &func_graph);
FinalVMPtr Link(const FuncGraphPtr &func_graph);
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
static bool ContainMixedTarget(const FuncGraphPtr &graph);

private:
InstSet insts_;


+ 16
- 0
mindspore/core/ir/anf.cc View File

@@ -301,4 +301,20 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
}
return default_target;
}

bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &node : nodes) {
if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
if (last_target != cur_target) {
return true;
}
last_target = cur_target;
}
}
return false;
}
} // namespace mindspore

+ 7
- 0
mindspore/core/ir/anf.h View File

@@ -482,6 +482,13 @@ void reset_id();
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
std::string GetCNodeTarget(const AnfNodePtr &node);
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
struct GraphSegment {
GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
std::vector<AnfNodePtr> nodes_;
bool is_cut_{false};
};
using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
} // namespace mindspore

#endif // MINDSPORE_CORE_IR_ANF_H_

+ 13
- 0
mindspore/core/ir/func_graph.cc View File

@@ -647,6 +647,19 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
return parameter;
}

bool FuncGraph::ContainMultiTarget() const {
auto graph_manager = manager();
MS_EXCEPTION_IF_NULL(graph_manager);
FuncGraphSet graphs = graph_manager->func_graphs();
for (auto &g : graphs) {
auto nodes = TopoSort(g->get_return());
if (mindspore::ContainMultiTarget(nodes)) {
return true;
}
}
return false;
}

size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;


+ 1
- 0
mindspore/core/ir/func_graph.h View File

@@ -354,6 +354,7 @@ class FuncGraph : public FuncGraphBase {
static void set_drawer(Drawer drawer) { drawer_ = drawer; }
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
bool ContainMultiTarget() const;

private:
// graph is manipulated by manager and others


+ 9
- 39
tests/ut/cpp/vm/segment_runner_test.cc View File

@@ -52,21 +52,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);

BackendPtr b = std::make_shared<Backend>("vm");
CompileGraph transform_(b);
auto splits = transform_.SplitNodes(g);
auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name());
auto segments = graph_partition->Partition(g);
VectorRef args({1.0, 2.0});

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list, "");
auto convertResult = MsVmConvert(segments[0], "");
auto runResult = (*(convertResult.run))(args);
ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0);
}
@@ -76,21 +66,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);

BackendPtr b = std::make_shared<Backend>("vm");
CompileGraph transform_(b);
auto splits = transform_.SplitNodes(g);
auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name());
auto segments = graph_partition->Partition(g);
VectorRef args({1.0, 2.0});

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list, "");
auto convertResult = MsVmConvert(segments[0], "");
auto runResult = (*(convertResult.run))(args);
ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0);
}
@@ -100,21 +80,11 @@ TEST_F(TestCompileSegmentRunner, test_if) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g);

BackendPtr b = std::make_shared<Backend>("vm");
CompileGraph transform_(b);
auto splits = transform_.SplitNodes(g);
auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name());
auto segments = graph_partition->Partition(g);
VectorRef args({1.0, 2.0});

std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);

AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
auto convertResult = MsVmConvert(anf_list, "");
auto convertResult = MsVmConvert(segments[0], "");
auto runResult = (*(convertResult.run))(args);

auto result = py::cast<bool>(BaseRefToPyData(runResult[0]));


Loading…
Cancel
Save