Browse Source

remove multiple circles

tags/v1.1.0
lingyunli63 5 years ago
parent
commit
dc95c63c03
1 changed files with 39 additions and 12 deletions
  1. +39
    -12
      mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc

+ 39
- 12
mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc View File

@@ -39,6 +39,27 @@


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) {
std::vector<AnfNodePtr> inputs;
for (auto &root : roots) {
auto tmp = DeepLinkedGraphSearch(root, include);
inputs.insert(inputs.end(), tmp.begin(), tmp.end());
}
return inputs;
}

std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include,
const FuncGraphManagerPtr &mng) {
std::vector<AnfNodePtr> users;
for (auto &root : roots) {
auto tmp = DeepUsersSearch(root, include, mng);
users.insert(users.end(), tmp.begin(), tmp.end());
}
return users;
}
} // namespace

bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
#if ENABLE_D #if ENABLE_D
std::vector<PrimitivePtr> basic_ops = { std::vector<PrimitivePtr> basic_ops = {
@@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
} }


bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) {
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false; return false;
} }
circle_nodes->clear();


std::set<AnfNodePtr> cached_done_set;
auto cnode = check_node->cast<CNodePtr>(); auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs(); const auto &inputs = cnode->inputs();
// there is a input not in fused_op_set, but the input depends on the fused_op_set // there is a input not in fused_op_set, but the input depends on the fused_op_set
bool has_circle = false;
for (auto input : inputs) { for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) { if (input->isa<CNode>() && !fused_op_set.count(input)) {
bool has_circle = false;
std::set<AnfNodePtr> done; std::set<AnfNodePtr> done;
std::vector<AnfNodePtr> todos = {input}; std::vector<AnfNodePtr> todos = {input};
while (!todos.empty()) { while (!todos.empty()) {
auto node = todos.back(); auto node = todos.back();
todos.pop_back(); todos.pop_back();
if (done.count(node) || cached_unconnected_set->count(node)) {
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) {
continue; continue;
} }


done.insert(node); done.insert(node);
if (fused_op_set.count(node)) { if (fused_op_set.count(node)) {
has_circle = true; has_circle = true;
*circle_node = node;
break;
circle_nodes->push_back(node);
continue;
} }


if (node->isa<CNode>()) { if (node->isa<CNode>()) {
@@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
} }


if (has_circle) { if (has_circle) {
return true;
cached_done_set.insert(done.begin(), done.end());
} else {
cached_unconnected_set->insert(done.begin(), done.end());
} }
cached_unconnected_set->insert(done.begin(), done.end());
done.clear();
} }
} }


return false;
return !circle_nodes->empty();
} }


std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
@@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
} }
return EXCLUDE; return EXCLUDE;
}; };

std::vector<AnfNodePtr> circle_nodes;
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
AnfNodePtr circle_node;
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node);
circle_nodes.clear();
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
// delete the circle node and the node which depend on the circle node in fused op // delete the circle node and the node which depend on the circle node in fused op
if (has_circle) { if (has_circle) {
auto mng = (*iter)->func_graph()->manager(); auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes; std::vector<AnfNodePtr> erase_nodes;
if (is_backward) { if (is_backward) {
erase_nodes = DeepUsersSearch(circle_node, include, mng);
erase_nodes = DeepUsersSearch(circle_nodes, include, mng);
} else { } else {
erase_nodes = DeepLinkedGraphSearch(circle_node, include);
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
} }
for (auto erase_node : erase_nodes) { for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node); fused_op_set.erase(erase_node);


Loading…
Cancel
Save