|
|
|
@@ -55,6 +55,10 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; |
|
|
|
// it will be one item in map with key: C, and value: (B, i) |
|
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap; |
|
|
|
static void HandleNoUsedParameter(const FuncGraphPtr &root); |
|
|
|
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, |
|
|
|
const std::string &instance_name); |
|
|
|
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, |
|
|
|
const std::string &opt_shard_group); |
|
|
|
|
|
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { |
|
|
|
if (new_node_input.empty()) { |
|
|
|
@@ -125,6 +129,30 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An |
|
|
|
MS_LOG(INFO) << "Insert " << instance_name << " success"; |
|
|
|
} |
|
|
|
|
|
|
|
// Replace pre_node with pre_node->op |
|
|
|
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, |
|
|
|
const std::string &instance_name) { |
|
|
|
// insert new node before the node |
|
|
|
FuncGraphManagerPtr manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
ScopePtr scope = pre_node->scope(); |
|
|
|
MS_EXCEPTION_IF_NULL(scope); |
|
|
|
std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name); |
|
|
|
CNodePtr new_node = func_graph->NewCNode(node_input); |
|
|
|
MS_EXCEPTION_IF_NULL(new_node); |
|
|
|
if (instance_name.find(SPLIT_SENS) == std::string::npos) { |
|
|
|
new_node->set_in_forward_flag(true); // mark forward flag |
|
|
|
} |
|
|
|
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]); |
|
|
|
new_node_prim->set_instance_name(instance_name); |
|
|
|
new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); |
|
|
|
new_node->set_scope(scope); |
|
|
|
node_input[0]->set_scope(scope); |
|
|
|
manager->Replace(pre_node, new_node); |
|
|
|
MS_LOG(INFO) << "Insert " << instance_name << " success"; |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
|
|
|
|
std::string CreateInstanceName(const CNodePtr &node, size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!IsValueNode<Primitive>(node->input(0))) { |
|
|
|
@@ -1380,18 +1408,26 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int |
|
|
|
auto cnode = res.first->cast<CNodePtr>(); |
|
|
|
auto graph = cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); |
|
|
|
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode_prim); |
|
|
|
CNodePtr allgather; |
|
|
|
if (cnode_prim->name() == CAST) { |
|
|
|
allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); |
|
|
|
} else { |
|
|
|
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); |
|
|
|
allgather = cnode->input(res.second)->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
// add fusion flag |
|
|
|
auto allgather = cnode->input(res.second)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(allgather); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); |
|
|
|
auto attrs = prim->attrs(); |
|
|
|
// enable fusion flag later when it's supported in backend |
|
|
|
attrs["fusion"] = MakeValue(0); |
|
|
|
attrs["fusion"] = MakeValue<int64_t>(0); |
|
|
|
prim->SetAttrs(attrs); |
|
|
|
} |
|
|
|
|
|
|
|
void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, |
|
|
|
const std::string &opt_shard_group) { |
|
|
|
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, |
|
|
|
const std::string &opt_shard_group) { |
|
|
|
if (opt_shard_group.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|