|
|
|
@@ -44,6 +44,18 @@ Status GatherV2PInfo::GetAttrs() { |
|
|
|
} |
|
|
|
axis_ = axis; |
|
|
|
|
|
|
|
// get target |
|
|
|
auto target_iter = attrs_.find(TARGET); |
|
|
|
if (target_iter != attrs_.end()) { |
|
|
|
MS_EXCEPTION_IF_NULL(target_iter->second); |
|
|
|
if (target_iter->second->isa<StringImm>()) { |
|
|
|
target_ = target_iter->second->cast<StringImmPtr>()->value(); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << " : The value of target is not a string."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -61,8 +73,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
auto param_shape = inputs_shape_.at(0); |
|
|
|
auto param_strategy = strategy->GetInputDim().at(0); |
|
|
|
auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1); |
|
|
|
if (slice_shape % 8 != 0) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned."; |
|
|
|
if (slice_shape % 8 != 0 && slice_shape != 1) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Last dim of param slice shape need 32Byte aligned."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -74,20 +86,20 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
|
|
|
|
// don't support scalar index |
|
|
|
if (inputs_shape_.at(1).size() == 0) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Don't support scalar index."; |
|
|
|
MS_LOG(DEBUG) << name_ << ": Don't support scalar index."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
// axis=0, index_shape(0)%param_strategy(0) must be 0 |
|
|
|
Shape index_shape = inputs_shape_.at(1); |
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { |
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; |
|
|
|
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 |
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { |
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; |
|
|
|
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -95,7 +107,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
auto index_strategy = strategy->GetInputDim().at(1); |
|
|
|
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>()); |
|
|
|
if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) { |
|
|
|
MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; |
|
|
|
MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -104,7 +116,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>()); |
|
|
|
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; |
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -290,18 +302,85 @@ Status GatherV2PInfo::InferBias() { |
|
|
|
} |
|
|
|
|
|
|
|
// Infers the communication group of this operator and stores it in group_.
//
// The group is made of the devices that lie along one dimension of the device matrix
// (dev_matrix_shape_), seen from the current global rank. The dimension is axis_, except when
// the param is 2-D and split along axis_: then the other dimension, (axis_ + 1) % 2, is used.
//
// Returns FAILED when the devices along the chosen dimension cannot be obtained, SUCCESS
// otherwise. When only one device lies along the dimension, group_ is left untouched.
Status GatherV2PInfo::InferGroup() {
  auto param_strategy = strategy_->GetInputDim().at(0);
  size_t dim = IntToSize(axis_);
  if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
    dim = (axis_ + 1) % 2;
  }

  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  int32_t rank = g_device_manager->global_rank();
  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);

  // Collect the ranks that share every device-matrix coordinate with `rank` except `dim`.
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
  }
  // A single device along the dimension means there is nothing to communicate with.
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "the group is empty";
    return SUCCESS;
  }

  group_ = g_device_manager->CreateGroup(group_devices);
  return SUCCESS;
}
|
|
|
|
|
|
|
std::vector<int32_t> GetRankFromGroup(const Group &group) { |
|
|
|
std::vector<int32_t> rank_list; |
|
|
|
auto device_list = group.GetDevicesList(); |
|
|
|
for (auto &device : device_list) { |
|
|
|
rank_list.insert(rank_list.end(), device.rank() % 8); |
|
|
|
} |
|
|
|
return rank_list; |
|
|
|
} |
|
|
|
|
|
|
|
// Infers the forward communication operator and stores it in forward_op_.
//
// forward_op_ is always cleared first. No operator is added when target_ is not CPU, or when the
// param is not split along axis_. Otherwise the group is inferred with InferGroup() and one
// reduce-scatter style operator (REDUCE_OP_SUM) is pushed:
//   - group size <= 8: HOST_REDUCE_SCATTER, whose GROUP attribute is the rank list returned by
//     GetRankFromGroup(group_); reduce_scatter_flag_ is set to false.
//   - group size  > 8: REDUCE_SCATTER over a newly created group made of the ranks
//     rank + 8 * i for i in [0, group_size / 8); its GROUP attribute is that group's name,
//     reduce_scatter_flag_ is set to true and split_num_ to group_size / 8.
//
// Returns FAILED when InferGroup() fails, SUCCESS otherwise.
Status GatherV2PInfo::InferForwardCommunication() {
  forward_op_.clear();
  if (target_ != CPU) {
    return SUCCESS;
  }
  auto param_strategy = strategy_->GetInputDim().at(0);
  // don't split axis, no need forward communication
  if (param_strategy.at(IntToSize(axis_)) == 1) {
    return SUCCESS;
  }
  // split axis
  OperatorName operator_name;
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  auto group_size = group_.GetDevNum();
  Attr attr_group;
  std::vector<int32_t> rank_list;
  if (group_size <= 8) {
    // group size <= 8
    reduce_scatter_flag_ = false;
    operator_name = HOST_REDUCE_SCATTER;
    rank_list = GetRankFromGroup(group_);
    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
  } else {
    // group size > 8
    const size_t repeat = group_size / 8;
    reduce_scatter_flag_ = true;
    split_num_ = SizeToInt(repeat);
    CheckGlobalDeviceManager();
    operator_name = REDUCE_SCATTER;
    int32_t rank = g_device_manager->global_rank();
    // Pick the ranks that are a multiple of 8 away from the current rank.
    for (size_t i = 0; i < repeat; ++i) {
      rank_list.push_back(rank + SizeToInt(i * 8));
    }
    Group g = g_device_manager->CreateGroup(rank_list);
    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
  }
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  OperatorAttrs attrs = {attr_op, attr_group};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(operator_name, args);

  // group_ was already set by InferGroup(); it must not be overwritten here.
  forward_op_.push_back(op);
  return SUCCESS;
}
|
|
|
|
|
|
|
@@ -346,6 +425,10 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
|
|
|
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
// target_ == CPU, no need to replace graph |
|
|
|
if (target_ == CPU) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; |
|
|
|
return nullptr; |
|
|
|
@@ -353,11 +436,34 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
return replace_graph_; |
|
|
|
} |
|
|
|
|
|
|
|
// Builds the EMBEDDING_LOOKUP replacement operator and appends it to replace_op_.
//
// InferBias() is run first; on failure an error is logged and FAILED is returned. The operator
// carries no attributes and three parameters: "offset" (bias_) at index 4, "reduce_scatter_flag"
// (reduce_scatter_flag_) at index 5 and "split_num" (split_num_) at index 6.
Status GatherV2PInfo::ComputeReplaceOp() {
  const Status bias_status = InferBias();
  if (bias_status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }

  // Parameters of the replacement operator, each paired with its input index below.
  Attr offset_attr = std::make_pair("offset", MakeValue(bias_));
  Attr flag_attr = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
  Attr split_num_attr = std::make_pair("split_num", MakeValue(split_num_));
  OperatorParams lookup_params = {std::make_pair(offset_attr, 4), std::make_pair(flag_attr, 5),
                                  std::make_pair(split_num_attr, 6)};

  OperatorAttrs lookup_attrs;
  OperatorArgs lookup_args = std::make_pair(lookup_attrs, lookup_params);

  OperatorName lookup_name = EMBEDDING_LOOKUP;
  Operator lookup_op = std::make_pair(lookup_name, lookup_args);
  replace_op_.push_back(lookup_op);
  return SUCCESS;
}
|
|
|
|
|
|
|
// Initializes the operator info with the given strategy.
//
// Runs InitWithAutoRepeatCalc(strategy) and, only when target_ is CPU, ComputeReplaceOp().
// Returns FAILED (after logging an error) as soon as either step fails, SUCCESS otherwise.
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  // only target_ == CPU, we need to replace op
  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    // Report the failure instead of falling through to the "Init success" path.
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}
|
|
|
|