Browse Source

auto_parallel_dynamic_shape_supplements

tags/v1.1.0
yao_yf 5 years ago
parent
commit
e76c8b708d
1 changed files with 9 additions and 4 deletions
  1. +9
    -4
      mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc

+ 9
- 4
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc View File

@@ -409,7 +409,7 @@ void GatherV2PInfo::InferOutputsTensorMap() {
} else {
// param_strategy(axis) != 1
if (axis_ == 0) {
if (dynamic_shape_indices_) {
if (dynamic_shape_indices_ && target_ != CPU) {
tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
} else {
tensor_map_out.insert(tensor_map_out.end(), 0);
@@ -423,7 +423,7 @@ void GatherV2PInfo::InferOutputsTensorMap() {
if (i == IntToSize(axis_)) {
tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
} else {
if (i == 0 && dynamic_shape_indices_) {
if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
tensor_map_out.push_back(MAP_NONE);
}
tensor_map_out.push_back(SizeToInt(param_size - i - 1));
@@ -648,10 +648,15 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
OperatorAttrs attrs = {attr_op, attr_group};
auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
AnfNodePtr reduce_op;
if (dynamic_shape_indices_) {
reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
} else {
reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
}
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
std::make_pair(input_nodes, reduce_scatter));
std::make_pair(input_nodes, reduce_op));

return SUCCESS;
}


Loading…
Cancel
Save