|
|
@@ -48,7 +48,7 @@ Status GatherV2PInfo::GetAttrs() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { |
|
|
|
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { |
|
|
if (is_auto_parallel_) { |
|
|
if (is_auto_parallel_) { |
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy."; |
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy."; |
|
|
} else { |
|
|
} else { |
|
|
@@ -84,12 +84,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// Don't support repeated calc |
|
|
|
|
|
auto params_strategy = strategy->GetInputDim().at(0); |
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, index can't be splited |
|
|
|
|
|
auto index_strategy = strategy->GetInputDim().at(1); |
|
|
|
|
|
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>()); |
|
|
|
|
|
if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, Don't support repeated calc |
|
|
CheckGlobalDeviceManager(); |
|
|
CheckGlobalDeviceManager(); |
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies<int>()); |
|
|
|
|
|
if (dev_num != IntToSize(product)) { |
|
|
|
|
|
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>()); |
|
|
|
|
|
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { |
|
|
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; |
|
|
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
@@ -97,26 +104,66 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferMirrorOps() { |
|
|
|
|
|
mirror_ops_.clear(); |
|
|
|
|
|
Shape input_a_tensor_map = inputs_tensor_map_.at(0); |
|
|
|
|
|
std::vector<Group> input_a_group; |
|
|
|
|
|
if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Create group for input a failed."; |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
OperatorVector op_for_input_a, op_for_input_b, op_for_axis; |
|
|
|
|
|
if (input_a_group.empty()) { |
|
|
|
|
|
MS_LOG(INFO) << name_ << " : The mirror group is empty."; |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} else { |
|
|
|
|
|
op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); |
|
|
|
|
|
MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
mirror_ops_.push_back(op_for_input_a); |
|
|
|
|
|
mirror_ops_.push_back(op_for_input_b); |
|
|
|
|
|
mirror_ops_.push_back(op_for_axis); |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferDevMatrixShape() { |
|
|
Status GatherV2PInfo::InferDevMatrixShape() { |
|
|
dev_matrix_shape_.clear(); |
|
|
dev_matrix_shape_.clear(); |
|
|
out_dev_matrix_shape_.clear(); |
|
|
out_dev_matrix_shape_.clear(); |
|
|
// infer input dev_matrix_shape |
|
|
// infer input dev_matrix_shape |
|
|
auto params_strategy = strategy_->GetInputDim().at(0); |
|
|
|
|
|
dev_matrix_shape_ = params_strategy; |
|
|
|
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
|
|
auto index_strategy = strategy_->GetInputDim().at(1); |
|
|
|
|
|
dev_matrix_shape_ = param_strategy; |
|
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis)!=1, |
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1) { |
|
|
|
|
|
std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end()); |
|
|
|
|
|
} else { |
|
|
|
|
|
dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// infer out dev_matrix_shape |
|
|
// infer out dev_matrix_shape |
|
|
// axis!=0, split axis |
|
|
// axis!=0, split axis |
|
|
if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) { |
|
|
|
|
|
out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_))); |
|
|
|
|
|
for (size_t i = 1; i < params_strategy.size(); ++i) { |
|
|
|
|
|
|
|
|
if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) { |
|
|
|
|
|
out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_))); |
|
|
|
|
|
for (size_t i = 1; i < param_strategy.size(); ++i) { |
|
|
if (i == IntToSize(axis_)) { |
|
|
if (i == IntToSize(axis_)) { |
|
|
out_dev_matrix_shape_.push_back(1); |
|
|
out_dev_matrix_shape_.push_back(1); |
|
|
} else { |
|
|
} else { |
|
|
out_dev_matrix_shape_.push_back(params_strategy.at(i)); |
|
|
|
|
|
|
|
|
out_dev_matrix_shape_.push_back(param_strategy.at(i)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
out_dev_matrix_shape_ = params_strategy; |
|
|
|
|
|
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_; |
|
|
|
|
|
} |
|
|
|
|
|
auto product_out = |
|
|
|
|
|
std::accumulate(out_dev_matrix_shape_.begin(), out_dev_matrix_shape_.end(), 1, std::multiplies<int>()); |
|
|
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
|
|
|
if (product_out == 1) { |
|
|
|
|
|
out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), dev_num); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
@@ -124,28 +171,56 @@ Status GatherV2PInfo::InferDevMatrixShape() { |
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferTensorMap() { |
|
|
Status GatherV2PInfo::InferTensorMap() { |
|
|
// infer input tensor map |
|
|
// infer input tensor map |
|
|
|
|
|
// param_strategy(axis) != 1 |
|
|
size_t param_size = inputs_shape_.at(0).size(); |
|
|
size_t param_size = inputs_shape_.at(0).size(); |
|
|
size_t index_size = inputs_shape_.at(1).size(); |
|
|
size_t index_size = inputs_shape_.at(1).size(); |
|
|
std::vector<int32_t> tensor_map_index(index_size, -1); |
|
|
|
|
|
|
|
|
size_t total_size = dev_matrix_shape_.size(); |
|
|
|
|
|
std::vector<int32_t> tensor_map_index; |
|
|
std::vector<int32_t> tensor_map_params; |
|
|
std::vector<int32_t> tensor_map_params; |
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
|
|
|
tensor_map_params.push_back(SizeToInt(param_size - i - 1)); |
|
|
|
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1) { |
|
|
|
|
|
tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); |
|
|
|
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
|
|
|
tensor_map_params.push_back(SizeToInt(i)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
// param_strategy(axis) == 1 |
|
|
|
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
|
|
|
tensor_map_params.push_back(SizeToInt(total_size - i - 1)); |
|
|
|
|
|
} |
|
|
|
|
|
for (size_t i = 0; i < index_size; ++i) { |
|
|
|
|
|
tensor_map_index.push_back(SizeToInt(index_size - i - 1)); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// infer output tensor map |
|
|
// infer output tensor map |
|
|
std::vector<int32_t> tensor_map_out; |
|
|
std::vector<int32_t> tensor_map_out; |
|
|
if (axis_ == 0) { |
|
|
|
|
|
tensor_map_out.push_back(SizeToInt(param_size - 1)); |
|
|
|
|
|
tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); |
|
|
|
|
|
for (size_t i = 1; i < param_size; ++i) { |
|
|
|
|
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1)); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) == 1) { |
|
|
|
|
|
// param_strategy(axis) == 1 |
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
if (i == IntToSize(axis_)) { |
|
|
if (i == IntToSize(axis_)) { |
|
|
tensor_map_out.insert(tensor_map_out.end(), index_size, -1); |
|
|
|
|
|
|
|
|
for (size_t j = 0; j < index_size; ++j) { |
|
|
|
|
|
tensor_map_out.push_back(SizeToInt(index_size - j - 1)); |
|
|
|
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1)); |
|
|
|
|
|
|
|
|
tensor_map_out.push_back(SizeToInt(total_size - i - 1)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
// param_strategy(axis) != 1 |
|
|
|
|
|
if (axis_ == 0) { |
|
|
|
|
|
tensor_map_out.insert(tensor_map_out.end(), 0); |
|
|
|
|
|
tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); |
|
|
|
|
|
for (size_t i = 1; i < param_size; ++i) { |
|
|
|
|
|
tensor_map_out.push_back(i); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
for (size_t i = 0; i < param_size; ++i) { |
|
|
|
|
|
if (i == IntToSize(axis_)) { |
|
|
|
|
|
tensor_map_out.insert(tensor_map_out.end(), index_size, -1); |
|
|
|
|
|
} else { |
|
|
|
|
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1)); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -209,7 +284,12 @@ Status GatherV2PInfo::InferBias() { |
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferGroup() { |
|
|
Status GatherV2PInfo::InferGroup() { |
|
|
std::vector<Group> group_list; |
|
|
std::vector<Group> group_list; |
|
|
if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) { |
|
|
|
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
|
|
size_t dim = IntToSize(axis_); |
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { |
|
|
|
|
|
dim = (axis_ + 1) % 2; |
|
|
|
|
|
} |
|
|
|
|
|
if (CreateGroupByDim(dim, &group_list) != SUCCESS) { |
|
|
MS_LOG(ERROR) << name_ << ": Create group failed."; |
|
|
MS_LOG(ERROR) << name_ << ": Create group failed."; |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
@@ -231,7 +311,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); |
|
|
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); |
|
|
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); |
|
|
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); |
|
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); |
|
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); |
|
|
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum}); |
|
|
|
|
|
|
|
|
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); |
|
|
auto gather_v2 = |
|
|
auto gather_v2 = |
|
|
gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); |
|
|
gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); |
|
|
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); |
|
|
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); |
|
|
@@ -250,8 +330,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); |
|
|
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); |
|
|
OperatorAttrs attrs = {attr_op, attr_group}; |
|
|
OperatorAttrs attrs = {attr_op, attr_group}; |
|
|
auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); |
|
|
auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); |
|
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1), |
|
|
|
|
|
std::make_pair(equal, 2)}; |
|
|
|
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; |
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( |
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( |
|
|
std::make_pair(input_nodes, reduce_scatter)); |
|
|
std::make_pair(input_nodes, reduce_scatter)); |
|
|
|
|
|
|
|
|
@@ -309,11 +388,11 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { |
|
|
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { |
|
|
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { |
|
|
is_auto_parallel_ = true; |
|
|
is_auto_parallel_ = true; |
|
|
Shape input0_split(inputs_shape_[0].size(), 1); |
|
|
Shape input0_split(inputs_shape_[0].size(), 1); |
|
|
Shapes splittable_inputs = {input0_split}; |
|
|
|
|
|
|
|
|
Shape input1_split(inputs_shape_[1].size(), 1); |
|
|
|
|
|
Shapes splittable_inputs = {input0_split, input1_split}; |
|
|
|
|
|
|
|
|
std::vector<StrategyPtr> sp_vector; |
|
|
std::vector<StrategyPtr> sp_vector; |
|
|
if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != |
|
|
|
|
|
SUCCESS) { |
|
|
|
|
|
|
|
|
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { |
|
|
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; |
|
|
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
@@ -331,12 +410,13 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { |
|
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() { |
|
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() { |
|
|
CheckGlobalDeviceManager(); |
|
|
CheckGlobalDeviceManager(); |
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
Dimensions strategy; |
|
|
|
|
|
strategy.push_back(SizeToInt(dev_num)); |
|
|
|
|
|
for (size_t i = 1; i < inputs_shape_[0].size(); i++) { |
|
|
|
|
|
strategy.push_back(1); |
|
|
|
|
|
|
|
|
Dimensions param_strategy(inputs_shape_[0].size(), 1); |
|
|
|
|
|
Dimensions index_strategy; |
|
|
|
|
|
index_strategy.push_back(SizeToInt(dev_num)); |
|
|
|
|
|
for (size_t i = 1; i < inputs_shape_[1].size(); i++) { |
|
|
|
|
|
index_strategy.push_back(1); |
|
|
} |
|
|
} |
|
|
std::vector<Dimensions> strategy_v = {strategy}; |
|
|
|
|
|
|
|
|
std::vector<Dimensions> strategy_v = {param_strategy, index_strategy}; |
|
|
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); |
|
|
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); |
|
|
} |
|
|
} |
|
|
} // namespace parallel |
|
|
} // namespace parallel |
|
|
|