|
|
|
@@ -409,7 +409,7 @@ void GatherV2PInfo::InferOutputsTensorMap() { |
|
|
|
} else { |
|
|
|
// param_strategy(axis) != 1 |
|
|
|
if (axis_ == 0) { |
|
|
|
if (dynamic_shape_indices_) { |
|
|
|
if (dynamic_shape_indices_ && target_ != CPU) { |
|
|
|
tensor_map_out.insert(tensor_map_out.end(), MAP_NONE); |
|
|
|
} else { |
|
|
|
tensor_map_out.insert(tensor_map_out.end(), 0); |
|
|
|
@@ -423,7 +423,7 @@ void GatherV2PInfo::InferOutputsTensorMap() { |
|
|
|
if (i == IntToSize(axis_)) { |
|
|
|
tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE); |
|
|
|
} else { |
|
|
|
if (i == 0 && dynamic_shape_indices_) { |
|
|
|
if (i == 0 && dynamic_shape_indices_ && target_ != CPU) { |
|
|
|
tensor_map_out.push_back(MAP_NONE); |
|
|
|
} |
|
|
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1)); |
|
|
|
@@ -648,10 +648,15 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); |
|
|
|
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); |
|
|
|
OperatorAttrs attrs = {attr_op, attr_group}; |
|
|
|
auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); |
|
|
|
AnfNodePtr reduce_op; |
|
|
|
if (dynamic_shape_indices_) { |
|
|
|
reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul}); |
|
|
|
} else { |
|
|
|
reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); |
|
|
|
} |
|
|
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; |
|
|
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( |
|
|
|
std::make_pair(input_nodes, reduce_scatter)); |
|
|
|
std::make_pair(input_nodes, reduce_op)); |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|