Merge pull request !31066 from liuluobin/fix_roialignr1.7
@@ -91,8 +91,13 @@ AnfNodePtr CreateInt32Tensor(int64_t value) {
   return anf_node_ptr;
 }
-AnfNodePtr CreatTypeInt(int64_t value) {
-  ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
+AnfNodePtr CreateTypeInt(int64_t nbits) {
+  ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
   return ValuePtrToAnfNodePtr(value_ptr);
 }
+AnfNodePtr CreateTypeFloat(int64_t nbits) {
+  ValuePtr value_ptr = MakeValue(std::make_shared<Float>(nbits));
+  return ValuePtrToAnfNodePtr(value_ptr);
+}
@@ -37,7 +37,8 @@ std::string GetOpPythonPath(const OperatorName &op_name);
 // Init python operator Instance
 ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
-AnfNodePtr CreatTypeInt(int64_t value);
+AnfNodePtr CreateTypeInt(int64_t nbits);
+AnfNodePtr CreateTypeFloat(int64_t nbits);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateFP32Imm(float value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
@@ -147,18 +147,19 @@ Status CropAndResizeInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
   auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
-  auto not_equal = gen_g.PushBack({gen_g.NewOpInst(NOT_EQUAL), minimum, sub});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), minimum, sub});
+  auto cast_equal = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreateTypeFloat(32)});
   auto crop_and_resize = gen_g.PushBack({gen_g.NewOpInst(CROP_AND_RESIZE), gen_g.virtual_input_node(),
                                          gen_g.virtual_input_node(), minimum, gen_g.virtual_input_node()});
-  auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), not_equal, CreatInt64Imm(-1)});
+  auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast_equal, CreatInt64Imm(-1)});
   auto expand_dims_1 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_0, CreatInt64Imm(-1)});
   auto expand_dims_2 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_1, CreatInt64Imm(-1)});
-  auto masked_fill = gen_g.PushBack({gen_g.NewOpInst(MASKED_FILL), crop_and_resize, expand_dims_2, CreateFP32Imm(0.0)});
+  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), crop_and_resize, expand_dims_2});
   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
-  AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), masked_fill});
+  AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 3), std::make_pair(crop_and_resize, 1),
                                                              std::make_pair(crop_and_resize, 2),
@@ -220,9 +221,12 @@ std::vector<StrategyPtr> CropAndResizeInfo::GenerateOpStrategies(int64_t stage_i
 }
 Status CropAndResizeInfo::InferMirrorOps() {
-  if (OperatorInfo::InferMirrorOps() == FAILED) {
+  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
     return FAILED;
   }
+  if (mirror_ops_.empty()) {
+    return SUCCESS;
+  }
   OperatorVector op_for_crop_size;
   (void)mirror_ops_.emplace_back(std::move(op_for_crop_size));
@@ -183,7 +183,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)});
   auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1});
   auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
-  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
+  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreateTypeInt(32)});
   auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
   auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(ADD), mul2, CreateInt32Tensor(1)});
   auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
@@ -158,10 +158,10 @@ Status ROIAlignInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
                                   std::make_pair(SHRINK_AXIS_MASK, MakeValue(2))};
   auto strided_slice = gen_g.PushBack(
     {gen_g.NewOpInst(STRIDEDSLICE, strided_slice_attrs), gen_g.virtual_input_node(), begin, end, strides});
-  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
-  auto cast_bias = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(bias_), dtype});
+  auto dtype_rois = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
+  auto cast_bias = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(bias_), dtype_rois});
   auto cast_slice_max_index =
-    gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(features_slice_size_ - 1), dtype});
+    gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(features_slice_size_ - 1), dtype_rois});
   auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), strided_slice, cast_bias});
   auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
   auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, cast_slice_max_index});
@@ -173,19 +173,21 @@ Status ROIAlignInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
     gen_g.PushBack({gen_g.NewOpInst(TENSOR_SCATTER_UPDATE), gen_g.virtual_input_node(), stack, minimum});
   auto roi_align =
     gen_g.PushBack({gen_g.NewOpInst(ROI_ALIGN, roi_align_attrs), gen_g.virtual_input_node(), tensor_scatter_update});
-  auto not_equal = gen_g.PushBack({gen_g.NewOpInst(NOT_EQUAL), strided_slice, minimum});
-  auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), not_equal, CreatInt64Imm(-1)});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
+  auto dtype_features = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gen_g.virtual_input_node()});
+  auto cast_equal = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype_features});
+  auto expand_dims_0 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast_equal, CreatInt64Imm(-1)});
   auto expand_dims_1 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_0, CreatInt64Imm(-1)});
   auto expand_dims_2 = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), expand_dims_1, CreatInt64Imm(-1)});
-  auto masked_fill = gen_g.PushBack({gen_g.NewOpInst(MASKED_FILL), roi_align, expand_dims_2, CreateFP32Imm(0.0)});
+  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), roi_align, expand_dims_2});
   Attr attr_reduce_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
   Attr attr_reduce_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs_reduce = {attr_reduce_op, attr_reduce_group};
-  AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs_reduce), masked_fill});
+  AnfNodePtr reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs_reduce), mul});
   std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {
-    std::make_pair(strided_slice, 2), std::make_pair(dtype, 2), std::make_pair(tensor_scatter_update, 2),
-    std::make_pair(roi_align, 1)};
+    std::make_pair(strided_slice, 2), std::make_pair(dtype_rois, 2), std::make_pair(tensor_scatter_update, 2),
+    std::make_pair(roi_align, 1), std::make_pair(dtype_features, 1)};
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
     std::make_pair(inputs_nodes, reduce_op));
   return SUCCESS;
@@ -117,8 +117,10 @@ def test_roi_align_layout():
     # check sub_graph
     sub_graph = {
         'TensorScatterUpdate-0': ['rois', 'Stack-0', 'Minimum-0'],
+        'Equal-0': ['Sub-0', 'Minimum-0'],
         'ROIAlign-0': ['features', 'TensorScatterUpdate-0'],
-        'MaskedFill-0': ['ROIAlign-0', 'ExpandDims-2', 0.0],
-        'AllReduce-0': ['MaskedFill-0']
+        'Mul-0': ['ROIAlign-0', 'ExpandDims-2'],
+        'AllReduce-0': ['Mul-0']
     }
     assert validator.check_graph_structure(sub_graph)