|
|
|
@@ -36,6 +36,7 @@ constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; |
|
|
|
// Named constant for the value 5; presumably the rank of a 5-D shape - TODO confirm at the use sites.
constexpr size_t k5dDims = 5; |
|
|
|
// Set holding the primitive names of Assign, AssignAdd and AssignSub.
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), |
|
|
|
prim::kPrimAssignSub->name()}; |
|
|
|
|
|
|
|
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -129,7 +130,34 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std:: |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetNodeGroup(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) { |
|
|
|
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup); |
|
|
|
} |
|
|
|
return ""; |
|
|
|
} |
|
|
|
|
|
|
|
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) { |
|
|
|
MS_EXCEPTION_IF_NULL(optimized_comm_group); |
|
|
|
auto node_group = GetNodeGroup(node); |
|
|
|
if (node_group.find(kSyncBnGroup) != string::npos) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto node_name = AnfAlgo::GetCNodeName(node); |
|
|
|
auto iter = optimized_comm_group->find(node_name); |
|
|
|
if (iter == optimized_comm_group->end()) { |
|
|
|
(*optimized_comm_group)[node_name] = node_group; |
|
|
|
return true; |
|
|
|
} else if (iter->second == node_group) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { |
|
|
|
auto value_node = node->cast<ValueNodePtr>(); |
|
|
|
if (value_node == nullptr) { |
|
|
|
@@ -153,7 +181,7 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const { |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) { |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) { |
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue); |
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes); |
|
|
|
auto it = node_output_edges_.find(node); |
|
|
|
@@ -184,7 +212,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod |
|
|
|
// allreduce first |
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { |
|
|
|
(void)visited_nodes->insert(next_node); |
|
|
|
if (AnfAlgo::IsCommunicationOp(next_node)) { |
|
|
|
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node); |
|
|
|
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { |
|
|
|
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); |
|
|
|
visit_queue->push(next_node); |
|
|
|
} else { |
|
|
|
@@ -206,18 +235,19 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
execution_order_.clear(); |
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes; |
|
|
|
std::queue<AnfNodePtr> zero_input_nodes; |
|
|
|
AnfNodePtr last_communication_node = nullptr; |
|
|
|
std::stack<AnfNodePtr> delay_comm_stack; |
|
|
|
std::queue<AnfNodePtr> communication_descendants; |
|
|
|
while (!seed_nodes.empty() || last_communication_node != nullptr) { |
|
|
|
// seed nodes first, then visit last all reduce node descendant |
|
|
|
std::map<std::string, std::string> optimized_comm_group; |
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) { |
|
|
|
// seed nodes first, then delay comm nodes |
|
|
|
if (seed_nodes.empty()) { |
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); |
|
|
|
last_communication_node = nullptr; |
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
delay_comm_stack.pop(); |
|
|
|
} else { |
|
|
|
zero_input_nodes.push(seed_nodes.front()); |
|
|
|
seed_nodes.pop(); |
|
|
|
} |
|
|
|
// all reduce node descendant first, then common queue |
|
|
|
// comm descendant first, then common queue |
|
|
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) { |
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
bool is_communication_descendant = false; |
|
|
|
@@ -234,12 +264,20 @@ void KernelGraph::SetExecOrderByDefault() { |
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) { |
|
|
|
execution_order_.push_back(node->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
// for all reduce node, visit last all reduce node descendant |
|
|
|
if (AnfAlgo::IsCommunicationOp(node)) { |
|
|
|
if (last_communication_node != nullptr) { |
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); |
|
|
|
// delay execute comm ops that need optimize |
|
|
|
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node); |
|
|
|
bool optimize_comm = is_fused_comm; |
|
|
|
if (optimize_comm) { |
|
|
|
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group); |
|
|
|
} |
|
|
|
if (optimize_comm) { |
|
|
|
while (!delay_comm_stack.empty()) { |
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); |
|
|
|
delay_comm_stack.pop(); |
|
|
|
} |
|
|
|
last_communication_node = node; |
|
|
|
delay_comm_stack.push(node); |
|
|
|
} else if (is_fused_comm) { |
|
|
|
delay_comm_stack.push(node); |
|
|
|
} else if (is_communication_descendant) { |
|
|
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes); |
|
|
|
} else { |
|
|
|
@@ -540,7 +578,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
auto parameter = node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(parameter); |
|
|
|
bool is_weight = AnfAlgo ::IsParameterWeight(parameter); |
|
|
|
bool is_weight = AnfAlgo::IsParameterWeight(parameter); |
|
|
|
kernel_info->set_feature_map_flag(!is_weight); |
|
|
|
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); |
|
|
|
} |
|
|
|
@@ -746,6 +784,7 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons |
|
|
|
// delete old kernel |
|
|
|
(void)backend_front_anf_map_.erase(old_backend_anf); |
|
|
|
} |
|
|
|
|
|
|
|
// get kernel by anf |
|
|
|
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { |
|
|
|
if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { |
|
|
|
|