Browse Source

!31066 Fix a bug where the ROIAlign and CropAndResize distributed ops do not support GPU

Merge pull request !31066 from liuluobin/fix_roialign
r1.7
i-robot Gitee 4 years ago
parent
commit
67d10ce3be
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 34 additions and 20 deletions
  1. +7
    -2
      mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
  3. +9
    -5
      mindspore/ccsrc/frontend/parallel/ops_info/crop_and_resize_info.cc
  4. +1
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
  5. +11
    -9
      mindspore/ccsrc/frontend/parallel/ops_info/roi_align_info.cc
  6. +4
    -2
      tests/ut/python/parallel/test_roi_align.py

+ 7
- 2
mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc View File

@@ -91,8 +91,13 @@ AnfNodePtr CreateInt32Tensor(int64_t value) {
return anf_node_ptr;
}

AnfNodePtr CreatTypeInt(int64_t value) {
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
AnfNodePtr CreateTypeInt(int64_t nbits) {
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateTypeFloat(int64_t nbits) {
ValuePtr value_ptr = MakeValue(std::make_shared<Float>(nbits));
return ValuePtrToAnfNodePtr(value_ptr);
}



+ 2
- 1
mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h View File

@@ -37,7 +37,8 @@ std::string GetOpPythonPath(const OperatorName &op_name);
// Init python operator Instance
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);

AnfNodePtr CreatTypeInt(int64_t value);
AnfNodePtr CreateTypeInt(int64_t nbits);
AnfNodePtr CreateTypeFloat(int64_t nbits);
AnfNodePtr CreatInt64Imm(int64_t value);
AnfNodePtr CreateFP32Imm(float value);
AnfNodePtr CreateInt32Tensor(int64_t value);


+ 9
- 5
mindspore/ccsrc/frontend/parallel/ops_info/crop_and_resize_info.cc View File

@@ -147,18 +147,19 @@ Status CropAndResizeInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
auto not_equal = gen_g.PushBack({gen_g.NewOpInst(NOT_EQUAL), minimum, sub});
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), minimum, sub});
auto cast_equal = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreateTypeFloat(32)});
auto crop_and_resize = gen_g.PushBack({gen_g.NewOpInst(CROP_AND_RESIZE), gen_g.virtual_input_node(),
gen_g.virtual_input_node(), minimum, gen_g.virtual_input_node()});
auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), not_equal, CreatInt64Imm(-1)});
auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast_equal, CreatInt64Imm(-1)});
auto expand_dims_1 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_0, CreatInt64Imm(-1)});
auto expand_dims_2 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_1, CreatInt64Imm(-1)});
auto masked_fill = gen_g.PushBack({gen_g.NewOpInst(MASKED_FILL), crop_and_resize, expand_dims_2, CreateFP32Imm(0.0)});
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), crop_and_resize, expand_dims_2});

Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
OperatorAttrs attrs = {attr_op, attr_group};
AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), masked_fill});
AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});

std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 3), std::make_pair(crop_and_resize, 1),
std::make_pair(crop_and_resize, 2),
@@ -220,9 +221,12 @@ std::vector<StrategyPtr> CropAndResizeInfo::GenerateOpStrategies(int64_t stage_i
}

Status CropAndResizeInfo::InferMirrorOps() {
if (OperatorInfo::InferMirrorOps() == FAILED) {
if (OperatorInfo::InferMirrorOps() != SUCCESS) {
return FAILED;
}
if (mirror_ops_.empty()) {
return SUCCESS;
}

OperatorVector op_for_crop_size;
(void)mirror_ops_.emplace_back(std::move(op_for_crop_size));


+ 1
- 1
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc View File

@@ -183,7 +183,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)});
auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1});
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreateTypeInt(32)});
auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});


+ 11
- 9
mindspore/ccsrc/frontend/parallel/ops_info/roi_align_info.cc View File

@@ -158,10 +158,10 @@ Status ROIAlignInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
std::make_pair(SHRINK_AXIS_MASK, MakeValue(2))};
auto strided_slice = gen_g.PushBack(
{gen_g.NewOpInst(STRIDEDSLICE, strided_slice_attrs), gen_g.virtual_input_node(), begin, end, strides});
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
auto cast_bias = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(bias_), dtype});
auto dtype_rois = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
auto cast_bias = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(bias_), dtype_rois});
auto cast_slice_max_index =
gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(features_slice_size_ - 1), dtype});
gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(features_slice_size_ - 1), dtype_rois});
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), strided_slice, cast_bias});
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, cast_slice_max_index});
@@ -173,19 +173,21 @@ Status ROIAlignInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
gen_g.PushBack({gen_g.NewOpInst(TENSOR_SCATTER_UPDATE), gen_g.virtual_input_node(), stack, minimum});
auto roi_align =
gen_g.PushBack({gen_g.NewOpInst(ROI_ALIGN, roi_align_attrs), gen_g.virtual_input_node(), tensor_scatter_update});
auto not_equal = gen_g.PushBack({gen_g.NewOpInst(NOT_EQUAL), strided_slice, minimum});
auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), not_equal, CreatInt64Imm(-1)});
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
auto dtype_features = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
auto cast_equal = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype_features});
auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast_equal, CreatInt64Imm(-1)});
auto expand_dims_1 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_0, CreatInt64Imm(-1)});
auto expand_dims_2 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_1, CreatInt64Imm(-1)});
auto masked_fill = gen_g.PushBack({gen_g.NewOpInst(MASKED_FILL), roi_align, expand_dims_2, CreateFP32Imm(0.0)});
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), roi_align, expand_dims_2});
Attr attr_reduce_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
Attr attr_reduce_group = std::make_pair(GROUP, MakeValue(group_.name()));
OperatorAttrs attrs_reduce = {attr_reduce_op, attr_reduce_group};
AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs_reduce), masked_fill});
AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs_reduce), mul});

std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {
std::make_pair(strided_slice, 2), std::make_pair(dtype, 2), std::make_pair(tensor_scatter_update, 2),
std::make_pair(roi_align, 1)};
std::make_pair(strided_slice, 2), std::make_pair(dtype_rois, 2), std::make_pair(tensor_scatter_update, 2),
std::make_pair(roi_align, 1), std::make_pair(dtype_features, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(inputs_nodes, reduce_op));
return SUCCESS;


+ 4
- 2
tests/ut/python/parallel/test_roi_align.py View File

@@ -117,8 +117,10 @@ def test_roi_align_layout():

# check sub_graph
sub_graph = {
'TensorScatterUpdate-0': ['rois', 'Stack-0', 'Minimum-0'],
'Equal-0': ['Sub-0', 'Minimum-0'],
'ROIAlign-0': ['features', 'TensorScatterUpdate-0'],
'MaskedFill-0': ['ROIAlign-0', 'ExpandDims-2', 0.0],
'AllReduce-0': ['MaskedFill-0']
'Mul-0': ['ROIAlign-0', 'ExpandDims-2'],
'AllReduce-0': ['Mul-0']
}
assert validator.check_graph_structure(sub_graph)

Loading…
Cancel
Save