|
|
|
@@ -38,6 +38,16 @@ namespace parallel { |
|
|
|
static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map; |
|
|
|
static int send_tag = 0; |
|
|
|
static int recv_tag = 0; |
|
|
|
const std::set<PrimitivePtr> WHITE_LIST = {prim::kPrimCast, prim::kPrimTupleGetItem}; |
|
|
|
|
|
|
|
static bool IsInWhiteList(const CNodePtr &cnode) { |
|
|
|
for (auto &prim : WHITE_LIST) { |
|
|
|
if (IsPrimitiveCNode(cnode, prim)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void PipelineTransformer::Coloring() { |
|
|
|
auto need_coloring = true; |
|
|
|
@@ -85,7 +95,7 @@ void PipelineTransformer::BroadCastColoring() { |
|
|
|
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
if (prim == nullptr) { |
|
|
|
if (IsInWhiteList(cnode)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (IsInBlackList(prim)) { |
|
|
|
@@ -138,42 +148,21 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { |
|
|
|
|
|
|
|
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// handle send/recv a parameter |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv."; |
|
|
|
return std::make_pair(nullptr, nullptr); |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
OperatorInfoPtr op_info = nullptr; |
|
|
|
TensorInfo tensor_info; |
|
|
|
// op1(stage1)->op2(stage2) |
|
|
|
if (IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
op_info = CreateOpInfo(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
tensor_info = op_info->outputs_tensor_info()[0]; |
|
|
|
} else if (IsValueNode<FuncGraph>(cnode->input(0))) { |
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto output = graph->output(); |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
auto output_cnode = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(output_cnode->input(0)); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (prim->name() == TUPLE_GETITEM) { |
|
|
|
auto index = GetTupleGetItemIndex(output_cnode); |
|
|
|
auto pre_getitem_node = output_cnode->input(1)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_getitem_node); |
|
|
|
op_info = CreateOpInfo(pre_getitem_node); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
tensor_info = op_info->outputs_tensor_info()[index]; |
|
|
|
} else { |
|
|
|
op_info = CreateOpInfo(output_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
tensor_info = op_info->outputs_tensor_info()[0]; |
|
|
|
} |
|
|
|
} |
|
|
|
// Handle Cast and TupleGetitem situation |
|
|
|
size_t tensor_info_index = 0; |
|
|
|
if (IsPrimitiveCNode(cnode, prim::kPrimCast)) { |
|
|
|
cnode = cnode->input(1)->cast<CNodePtr>(); |
|
|
|
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { |
|
|
|
tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode)); |
|
|
|
cnode = cnode->input(1)->cast<CNodePtr>(); |
|
|
|
} |
|
|
|
// Create OperatorInfo to get slice_shape for send/recv |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto op_info = CreateOpInfo(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(op_info); |
|
|
|
auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index]; |
|
|
|
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info)); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -316,6 +305,29 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con |
|
|
|
return std::make_pair(shape_list, dtype); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (IsValueNode<FuncGraph>(cnode->input(0))) { |
|
|
|
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); |
|
|
|
auto output = graph->output(); |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (output->isa<Parameter>()) { |
|
|
|
return output; |
|
|
|
} |
|
|
|
cnode = output->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
} |
|
|
|
if (IsInWhiteList(cnode)) { |
|
|
|
return cnode->cast<AnfNodePtr>(); |
|
|
|
} |
|
|
|
if (!IsPipelineCareNode(cnode)) { |
|
|
|
MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."; |
|
|
|
} |
|
|
|
return cnode->cast<AnfNodePtr>(); |
|
|
|
} |
|
|
|
|
|
|
|
SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int user_node_stage, |
|
|
|
int node_stage) { |
|
|
|
Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); |
|
|
|
@@ -330,7 +342,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod |
|
|
|
if (parameter->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(parameter); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(parameter); |
|
|
|
auto care_node = FindPipelineCareNode(parameter); |
|
|
|
if (care_node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(care_node); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(care_node); |
|
|
|
} |
|
|
|
} |
|
|
|
auto tensor_info = op_info_pair.second; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
|
@@ -360,7 +377,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(node); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(node); |
|
|
|
auto care_node = FindPipelineCareNode(node); |
|
|
|
if (care_node->isa<Parameter>()) { |
|
|
|
op_info_pair = GetParameterPair(care_node); |
|
|
|
} else { |
|
|
|
op_info_pair = GetOpInfo(care_node); |
|
|
|
} |
|
|
|
} |
|
|
|
auto tensor_info = op_info_pair.second; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
|
|