| @@ -47,6 +47,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type) | |||
| {"Concat", "ConcatD"}, | |||
| {"Softmax", "SoftmaxV2"}, | |||
| {"DropoutDoMask", "DropOutDoMask"}, | |||
| {"IOU", "Iou"}, | |||
| }; | |||
| auto iter = kOpTypeMap.find(op_type); | |||
| if (iter == kOpTypeMap.end()) { | |||
| @@ -587,6 +587,7 @@ inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared<Primi | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend", kSideEffectPropagate); | |||
| inline const PrimitivePtr kPrimIOU = std::make_shared<Primitive>("IOU"); | |||
| inline const PrimitivePtr kPrimReformat = std::make_shared<Primitive>("Reformat"); | |||
| inline const PrimitivePtr kPrimLoad = std::make_shared<Primitive>("Load"); | |||
| inline const PrimitivePtr kPrimUpdateState = std::make_shared<Primitive>("UpdateState"); | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/iou.h" | |||
| #include <algorithm> | |||
| #include <set> | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, prim_name); | |||
| (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0); | |||
| (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1); | |||
| auto x_shape_ptr = input_args[0]->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(x_shape_ptr); | |||
| auto y_shape_ptr = input_args[1]->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(y_shape_ptr); | |||
| auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr); | |||
| auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(y_shape_ptr); | |||
| auto x_shp = x_shape_map[kShape]; | |||
| auto y_shp = y_shape_map[kShape]; | |||
| if (x_shp.size() != 2 || y_shp.size() != 2) { | |||
| MS_EXCEPTION(ValueError) << "For BatchMatMul, input x, y should have the same dimension size and should be greater" | |||
| << "or equal to 3, while x size = " << x_shp.size() << ", y size = " << y_shp.size(); | |||
| } | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(x_shp[1]), kGreaterEqual, 4, prim_name); | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(y_shp[1]), kGreaterEqual, 4, prim_name); | |||
| ShapeVector x_min_shape = x_shape_map[kMinShape]; | |||
| ShapeVector x_max_shape = x_shape_map[kMaxShape]; | |||
| ShapeVector y_min_shape = y_shape_map[kMinShape]; | |||
| ShapeVector y_max_shape = y_shape_map[kMaxShape]; | |||
| ShapeVector ret_shape; | |||
| ShapeVector ret_min_shape; | |||
| ShapeVector ret_max_shape; | |||
| ret_shape.push_back(y_shp[0]); | |||
| ret_shape.push_back(x_shp[0]); | |||
| if (y_shape_ptr->IsDynamic()) { | |||
| ret_min_shape.push_back(y_min_shape[0]); | |||
| ret_max_shape.push_back(y_max_shape[0]); | |||
| } else { | |||
| ret_min_shape.push_back(y_shp[0]); | |||
| ret_max_shape.push_back(y_shp[0]); | |||
| } | |||
| if (x_shape_ptr->IsDynamic()) { | |||
| ret_min_shape.push_back(x_min_shape[0]); | |||
| ret_max_shape.push_back(x_max_shape[0]); | |||
| } else { | |||
| ret_min_shape.push_back(x_shp[0]); | |||
| ret_max_shape.push_back(x_shp[0]); | |||
| } | |||
| return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | |||
| std::map<std::string, TypePtr> types; | |||
| (void)types.emplace("x", input_args[0]->BuildType()); | |||
| (void)types.emplace("y", input_args[1]->BuildType()); | |||
| return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr IOUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto type = InferType(primitive, input_args); | |||
| auto shape = InferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(shape, type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(IOU, prim::kPrimIOU, IOUInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_IOU_H_ | |||
| #define MINDSPORE_CORE_OPS_IOU_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| class MS_CORE_API IOU : public PrimitiveC { | |||
| public: | |||
| IOU() : PrimitiveC(prim::kPrimIOU->name()) { InitIOName({"x,y"}, {"output"}); } | |||
| ~IOU() = default; | |||
| MS_DECLARE_PARENT(IOU, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_IOU_H_ | |||
| @@ -278,6 +278,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe | |||
| from .bounding_box_encode import _bounding_box_encode_tbe | |||
| from .check_valid import _check_valid_tbe | |||
| from .iou import _iou_tbe | |||
| from .iou_ds import _iou_ds_tbe | |||
| from .arg_max import _arg_max_tbe | |||
| from .nms_with_mask import _nms_with_mask_tbe | |||
| from .sgd import _sgd_tbe | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Iou op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| iou_op_info = TBERegOp("IOU") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("iou.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("iou") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True)\ | |||
| .attr("mode", "optional", "str", "all", "iou") \ | |||
| .attr("eps", "optional", "float", "all", "1.0") \ | |||
| .input(0, "bboxes", False, "required", "all") \ | |||
| .input(1, "gtboxes", False, "required", "all") \ | |||
| .output(0, "overlap", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(iou_op_info) | |||
| def _iou_ds_tbe(): | |||
| """Iou TBE register""" | |||
| return | |||
| @@ -332,7 +332,7 @@ class CheckValid(PrimitiveWithInfer): | |||
| return mstype.bool_ | |||
| class IOU(PrimitiveWithInfer): | |||
| class IOU(Primitive): | |||
| r""" | |||
| Calculates intersection over union for boxes. | |||
| @@ -384,20 +384,6 @@ class IOU(PrimitiveWithInfer): | |||
| raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.") | |||
| self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap']) | |||
| def infer_shape(self, anchor_boxes, gt_boxes): | |||
| validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name) | |||
| validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name) | |||
| validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name) | |||
| validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name) | |||
| iou = [gt_boxes[0], anchor_boxes[0]] | |||
| return iou | |||
| def infer_dtype(self, anchor_boxes, gt_boxes): | |||
| valid_type = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name) | |||
| return anchor_boxes | |||
| class Partial(Primitive): | |||
| """ | |||