|
|
|
@@ -20,6 +20,7 @@ |
|
|
|
#include <numeric> |
|
|
|
#include <functional> |
|
|
|
#include <utility> |
|
|
|
#include <algorithm> |
|
|
|
|
|
|
|
#include "parallel/device_matrix.h" |
|
|
|
#include "parallel/graph_util/generate_graph.h" |
|
|
|
@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
auto manual_split_iter = attrs_.find("manual_split"); |
|
|
|
if (manual_split_iter != attrs_.end()) { |
|
|
|
param_split_shapes_.clear(); |
|
|
|
manual_split_ = true; |
|
|
|
auto var = manual_split_iter->second->cast<ValueTuplePtr>(); |
|
|
|
MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString(); |
|
|
|
|
|
|
|
if (var->size() > 0) { |
|
|
|
std::vector<ValuePtr> elements = var->value(); |
|
|
|
for (auto &ele : elements) { |
|
|
|
if (ele->isa<ValueSequeue>()) { |
|
|
|
auto value_tuple = ele->cast<ValueTuplePtr>(); |
|
|
|
std::vector<ValuePtr> value_vector = value_tuple->value(); |
|
|
|
if (value_vector.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
param_split_shapes_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[0]))); |
|
|
|
index_offsets_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[1]))); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (param_split_shapes_.empty()) { |
|
|
|
MS_LOG(ERROR) << "Failed to extract param split strategy."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::CheckManualSplit() { |
|
|
|
auto param_shape = inputs_shape_.at(0); |
|
|
|
int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, |
|
|
|
[](int32_t s, int32_t shape) { return s + shape; }); |
|
|
|
if (split_shape_sum < param_shape.at(0)) { |
|
|
|
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { |
|
|
|
MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (manual_split_) { |
|
|
|
if (CheckManualSplit() != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
// when using manual_split, no need to check belowings. |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 |
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; |
|
|
|
@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::InferMirrorOps() { |
|
|
|
// There is no mirror operators for manual split |
|
|
|
if (manual_split_) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
mirror_ops_.clear(); |
|
|
|
Shape input_a_tensor_map = inputs_tensor_map_.at(0); |
|
|
|
std::vector<Group> input_a_group; |
|
|
|
@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() { |
|
|
|
// infer input dev_matrix_shape |
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
auto index_strategy = strategy_->GetInputDim().at(1); |
|
|
|
|
|
|
|
if (manual_split_) { |
|
|
|
dev_matrix_shape_ = param_strategy; |
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
dev_matrix_shape_ = param_strategy; |
|
|
|
|
|
|
|
// param_strategy(axis)!=1, |
|
|
|
@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() { |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::InferTensorMap() { |
|
|
|
if (manual_split_) { |
|
|
|
inputs_tensor_map_.push_back({1, 0}); |
|
|
|
inputs_tensor_map_.push_back({-1, 1}); |
|
|
|
outputs_tensor_map_.push_back({-1, 1, 0}); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
// infer input tensor map |
|
|
|
// param_strategy(axis) != 1 |
|
|
|
size_t param_size = inputs_shape_.at(0).size(); |
|
|
|
@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() { |
|
|
|
Shape input_shape = inputs_shape_.at(0); |
|
|
|
Shape input_index_shape = inputs_shape_.at(1); |
|
|
|
Shape output_shape = outputs_shape_.at(0); |
|
|
|
int32_t rank = g_device_manager->global_rank(); |
|
|
|
// infer tensor layout |
|
|
|
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; |
|
|
|
if (manual_split_) { |
|
|
|
input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]]; |
|
|
|
input_shape[0] = input_shape[0] * dev_matrix_shape_[0]; |
|
|
|
} |
|
|
|
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || |
|
|
|
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || |
|
|
|
(output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != |
|
|
|
@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() { |
|
|
|
TensorInfo input_index_info(input_index_layout); |
|
|
|
TensorInfo output_tensor_info(output_tensor_layout); |
|
|
|
|
|
|
|
Shape slice_shape = input_tensor_info.slice_shape(); |
|
|
|
MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape); |
|
|
|
|
|
|
|
inputs_tensor_info_.push_back(input_tensor_info); |
|
|
|
inputs_tensor_info_.push_back(input_index_info); |
|
|
|
outputs_tensor_info_.push_back(output_tensor_info); |
|
|
|
@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::InferOffset() { |
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
size_t rank = g_device_manager->global_rank(); |
|
|
|
if (rank < index_offsets_.size()) { |
|
|
|
index_offset_ = index_offsets_.at(rank); |
|
|
|
MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size(); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::InferGroup() { |
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
size_t dim = IntToSize(axis_); |
|
|
|
@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
MS_LOG(ERROR) << "GenerateGraph Init failed"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (manual_split_) { |
|
|
|
if (InferOffset() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)}); |
|
|
|
auto gather_v2 = |
|
|
|
gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)}); |
|
|
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; |
|
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( |
|
|
|
std::make_pair(input_nodes, gather_v2)); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
if (InferBias() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed."; |
|
|
|
return FAILED; |
|
|
|
@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
if (manual_split_) { |
|
|
|
if (ComputeReplaceGraph(cnode) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return replace_graph_; |
|
|
|
} |
|
|
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
// target_ == CPU, no need to raplace graph |
|
|
|
if (target_ == CPU) { |
|
|
|
|