Browse Source

optimize depend for mix target

tags/v0.6.0-beta
kswang 5 years ago
parent
commit
a3843659b4
4 changed files with 55 additions and 41 deletions
  1. +11
    -2
      mindspore/ccsrc/session/session_basic.cc
  2. +8
    -10
      mindspore/ccsrc/vm/segment_runner.cc
  3. +35
    -29
      mindspore/ccsrc/vm/transform.cc
  4. +1
    -0
      mindspore/ccsrc/vm/transform.h

+ 11
- 2
mindspore/ccsrc/session/session_basic.cc View File

@@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
auto new_fg = BasicClone(fg);
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
}
auto origin_inputs = cnode->inputs();
bool optimize_depend = false;
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
optimize_depend = true;
}
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx];
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
auto anf = origin_inputs[input_idx];
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
@@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
(*other_graph_cnode)[anf] = new_parameter;
}
continue;
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
continue;
} else if (anf->isa<AnfNode>()) {
*from_other_graph = true;
// the input node is a cnode from other graph


+ 8
- 10
mindspore/ccsrc/vm/segment_runner.cc View File

@@ -28,6 +28,7 @@
#include <string>

#include "utils/log_adapter.h"
#include "utils/utils.h"
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
@@ -85,7 +86,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}

auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
eqv[a] = a;
@@ -95,17 +95,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
eqv[a]->set_abstract(a->abstract());
eqv[a]->set_kernel_info(a->kernel_info_ptr());
}

return eqv[a];
};

// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) {
if (!n->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inst is not CNode";
}
auto &inps = n->cast<CNodePtr>()->inputs();

if (inps.empty()) {
MS_LOG(EXCEPTION) << "Input is empty";
}
@@ -114,21 +111,22 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode";
}

auto fn = inps[0];

std::vector<AnfNodePtr> args{fn};
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);

if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() &&
eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
args.emplace_back(inps[kRealInputIndexInDepend]);
args.emplace_back(inps[kRealInputIndexInDepend]);
} else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
}
eqv[n] = fg->NewCNode(args);
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}

std::vector<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys),
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });

auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys);
AnfNodePtr fg_output;
if (outputs.size() > 1) {


+ 35
- 29
mindspore/ccsrc/vm/transform.cc View File

@@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}

bool IsGetItemNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
if (!IsValueNode<Primitive>(inputs[0])) {
return true;
}
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
return node_prim->name() == prim::kPrimTupleGetItem->name();
}
return false;
}

std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (IsGetItemNode(node)) {
if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
@@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
}
}
std::reverse(result.begin(), result.end());
return ReorderGetItemNode(result);
return result;
}
} // namespace

@@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
return false;
}

VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = OptimizeGetItemOrder(input_nodes);
VectorRef splits;
VectorRef split;
auto nodes = TopoSort(graph->get_return());
if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
nodes = SplitSort(graph, default_target);
}
std::string last_target;
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
@@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
return splits;
}

VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();

if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
nodes = SplitSort(graph, default_target);
return SplitNodesWithTarget(nodes, graph);
}

VectorRef splits;
VectorRef split;
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (IsCut(node)) {
if (split.size() != 0) {
splits.push_back(split);
}
splits.push_back(node);
split.clear();
} else if (node->isa<CNode>()) {
split.push_back(node);
}
}
return splits;
}

// Push the value node on the stack.
void CompileGraph::Push(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);


+ 1
- 0
mindspore/ccsrc/vm/transform.h View File

@@ -78,6 +78,7 @@ class CompileGraph {
}

private:
VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph);
void PushParameters(const FuncGraphPtr &func_graph);
bool SplitGraph(const FuncGraphPtr &func_graph);
int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");


Loading…
Cancel
Save