|
|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include <algorithm> |
|
|
|
#include <map> |
|
|
|
#include <queue> |
|
|
|
#include <stack> |
|
|
|
#include <set> |
|
|
|
#include <string> |
|
|
|
#include <vector> |
|
|
|
@@ -75,7 +76,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
auto primitive = value->cast<PrimitivePtr>(); |
|
|
|
ValuePtr att_target = primitive->GetAttr("target"); |
|
|
|
ValuePtr att_target = primitive->GetAttr("primitive_target"); |
|
|
|
if (att_target != nullptr) { |
|
|
|
std::string target = GetValue<std::string>(att_target); |
|
|
|
return target; |
|
|
|
@@ -127,6 +128,58 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { |
|
|
|
std::vector<AnfNodePtr> result; |
|
|
|
std::stack<AnfNodePtr> to_visit; |
|
|
|
std::stack<AnfNodePtr> next_to_visit; |
|
|
|
std::map<AnfNodePtr, size_t> nodes_ref; |
|
|
|
CalcNodeRefCount(graph, &nodes_ref); |
|
|
|
std::string handle_target = default_target; |
|
|
|
std::string next_target = ""; |
|
|
|
to_visit.push(graph->get_return()); |
|
|
|
while (!to_visit.empty() || !next_to_visit.empty()) { |
|
|
|
if (to_visit.empty()) { |
|
|
|
to_visit.swap(next_to_visit); |
|
|
|
handle_target = next_target; |
|
|
|
} |
|
|
|
auto &node = to_visit.top(); |
|
|
|
to_visit.pop(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
result.emplace_back(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto node_inputs = cnode->inputs(); |
|
|
|
std::reverse(node_inputs.begin(), node_inputs.end()); |
|
|
|
for (auto &input : node_inputs) { |
|
|
|
auto iter = nodes_ref.find(input); |
|
|
|
if (iter != nodes_ref.end()) { |
|
|
|
iter->second--; |
|
|
|
if (iter->second != 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!input->isa<CNode>()) { |
|
|
|
to_visit.push(input); |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::string input_target = GetCNodeTarget(input); |
|
|
|
if (input_target == handle_target) { |
|
|
|
to_visit.push(input); |
|
|
|
} else if (next_to_visit.empty() || input_target == next_target) { |
|
|
|
next_to_visit.push(input); |
|
|
|
next_target = input_target; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "only support two different target"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::reverse(result.begin(), result.end()); |
|
|
|
return result; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) |
|
|
|
@@ -180,65 +233,16 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> CompileGraph::SplitSort(const FuncGraphPtr &graph) { |
|
|
|
std::vector<AnfNodePtr> result; |
|
|
|
std::queue<AnfNodePtr> queue; |
|
|
|
std::queue<AnfNodePtr> next_queue; |
|
|
|
std::map<AnfNodePtr, size_t> nodes_ref; |
|
|
|
CalcNodeRefCount(graph, &nodes_ref); |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::string queue_target = context_ptr->device_target(); |
|
|
|
std::string next_target = ""; |
|
|
|
queue.push(graph->get_return()); |
|
|
|
while (!queue.empty() || !next_queue.empty()) { |
|
|
|
if (queue.empty()) { |
|
|
|
queue.swap(next_queue); |
|
|
|
queue_target = next_target; |
|
|
|
} |
|
|
|
auto &node = queue.front(); |
|
|
|
queue.pop(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
result.emplace_back(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
for (auto &input : cnode->inputs()) { |
|
|
|
auto iter = nodes_ref.find(input); |
|
|
|
if (iter != nodes_ref.end()) { |
|
|
|
iter->second--; |
|
|
|
if (iter->second != 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!input->isa<CNode>()) { |
|
|
|
queue.push(input); |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::string input_target = GetCNodeTarget(input); |
|
|
|
if (input_target == queue_target) { |
|
|
|
queue.push(input); |
|
|
|
} else if (next_queue.empty() || input_target == next_target) { |
|
|
|
next_queue.push(input); |
|
|
|
next_target = input_target; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "only support two different target"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::reverse(result.begin(), result.end()); |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
VectorRef splits; |
|
|
|
VectorRef split; |
|
|
|
auto nodes = TopoSort(graph->get_return()); |
|
|
|
if (ContainMultiTarget(nodes)) { |
|
|
|
nodes = SplitSort(graph); |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::string default_target = context_ptr->device_target(); |
|
|
|
nodes = SplitSort(graph, default_target); |
|
|
|
} |
|
|
|
std::string last_target; |
|
|
|
MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); |
|
|
|
|