Browse Source

!9314 code review for pipeline parallel

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
bd6ee9c72f
3 changed files with 19 additions and 16 deletions
  1. +5
    -5
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  2. +6
    -6
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
  3. +8
    -5
      mindspore/ccsrc/pipeline/jit/pipeline_split.cc

+ 5
- 5
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -197,8 +197,8 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node) {
return std::make_pair(shape_list, dtype); return std::make_pair(shape_list, dtype);
} }


SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter,
const int &user_node_stage, const int &node_stage) {
SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage,
int node_stage) {
Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag));
send_tag += 1; send_tag += 1;
auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
@@ -221,7 +221,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
} }


void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node,
const int &index, const int &user_node_stage, const int &node_stage) {
int index, int user_node_stage, int node_stage) {
Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag));
recv_tag += 1; recv_tag += 1;
auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
@@ -369,9 +369,9 @@ void PipelineTransformer::ElimGraphStage() {
} }


bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
auto anf_node = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
auto prim = anf_node->value()->cast<PrimitivePtr>();
return (prim->name() == name); return (prim->name() == name);
} }




+ 6
- 6
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h View File

@@ -34,13 +34,14 @@ typedef struct {


class PipelineTransformer { class PipelineTransformer {
public: public:
PipelineTransformer(const FuncGraphManagerPtr &manager, const int &stage, const FuncGraphPtr &root,
const int64_t &global_rank, const int64_t &per_stage_rank_num)
PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank,
int64_t per_stage_rank_num)
: manager_(manager), : manager_(manager),
stage_(stage), stage_(stage),
root_(root), root_(root),
global_rank_(global_rank), global_rank_(global_rank),
per_stage_rank_num_(per_stage_rank_num) {} per_stage_rank_num_(per_stage_rank_num) {}
virtual ~PipelineTransformer() = default;
void Coloring(); void Coloring();
void BroadCastColoring(); void BroadCastColoring();
void HandleSharedParameter(); void HandleSharedParameter();
@@ -54,10 +55,9 @@ class PipelineTransformer {
std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
void DoBroadCast(const FuncGraphPtr &func); void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage,
const int &node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, const int &index,
const int &user_node_stage, const int &node_stage);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, int node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int user_node_stage, int node_stage);
void CutBorder(const FuncGraphPtr &graph); void CutBorder(const FuncGraphPtr &graph);
void ElimRootParameter(); void ElimRootParameter();
bool IsStageNode(const CNodePtr &node); bool IsStageNode(const CNodePtr &node);


+ 8
- 5
mindspore/ccsrc/pipeline/jit/pipeline_split.cc View File

@@ -24,9 +24,6 @@


namespace mindspore { namespace mindspore {
namespace pipeline { namespace pipeline {

static int64_t GetRank();
static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num);
static int64_t GetRank() { static int64_t GetRank() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
@@ -40,7 +37,7 @@ static int64_t GetRank() {
MS_LOG(EXCEPTION) << "Invalid backend: " << backend; MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
} }
int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
uint32_t rank_id;
uint32_t rank_id = 0;
if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
MS_LOG(EXCEPTION) << "Get rank id failed."; MS_LOG(EXCEPTION) << "Get rank id failed.";
@@ -50,7 +47,7 @@ static int64_t GetRank() {
return global_rank; return global_rank;
} }


static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num) {
static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) {
if (device_num % stage_num != 0) { if (device_num % stage_num != 0) {
MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
<< "stage_num: " << stage_num; << "stage_num: " << stage_num;
@@ -75,6 +72,12 @@ bool PipelineSplit(const ResourcePtr &res) {
auto root = res->func_graph(); auto root = res->func_graph();
auto global_rank = GetRank(); auto global_rank = GetRank();
auto device_num = parallel::ParallelContext::GetInstance()->device_num(); auto device_num = parallel::ParallelContext::GetInstance()->device_num();
if (device_num < 1) {
MS_LOG(EXCEPTION) << "Invalid device num: " << device_num;
}
if (global_rank < 0) {
MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank;
}
auto stage = InferStage(global_rank, stage_num, device_num); auto stage = InferStage(global_rank, stage_num, device_num);
auto per_stage_rank_num = device_num / stage_num; auto per_stage_rank_num = device_num / stage_num;
auto transformer = auto transformer =


Loading…
Cancel
Save