Browse Source

[feat][assistant][I480FL] add the dynamic shape for iou

tags/v1.6.0
shicaiwei 4 years ago
parent
commit
14a0ea18e4
7 changed files with 167 additions and 15 deletions
  1. +1
    -0
      mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc
  2. +1
    -0
      mindspore/core/base/core_ops.h
  3. +86
    -0
      mindspore/core/ops/iou.cc
  4. +38
    -0
      mindspore/core/ops/iou.h
  5. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  6. +39
    -0
      mindspore/ops/_op_impl/tbe/iou_ds.py
  7. +1
    -15
      mindspore/ops/operations/other_ops.py

+ 1
- 0
mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc View File

@@ -47,6 +47,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
{"Concat", "ConcatD"},
{"Softmax", "SoftmaxV2"},
{"DropoutDoMask", "DropOutDoMask"},
{"IOU", "Iou"},
};
auto iter = kOpTypeMap.find(op_type);
if (iter == kOpTypeMap.end()) {


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -587,6 +587,7 @@ inline const PrimitivePtr kPrimErrorOnDynamicShapeInput = std::make_shared<Primi

// Other miscellaneous
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend", kSideEffectPropagate);
inline const PrimitivePtr kPrimIOU = std::make_shared<Primitive>("IOU");
inline const PrimitivePtr kPrimReformat = std::make_shared<Primitive>("Reformat");
inline const PrimitivePtr kPrimLoad = std::make_shared<Primitive>("Load");
inline const PrimitivePtr kPrimUpdateState = std::make_shared<Primitive>("UpdateState");


+ 86
- 0
mindspore/core/ops/iou.cc View File

@@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/iou.h"
#include <algorithm>
#include <set>

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
auto x_shape_ptr = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x_shape_ptr);
auto y_shape_ptr = input_args[1]->BuildShape();
MS_EXCEPTION_IF_NULL(y_shape_ptr);
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr);
auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(y_shape_ptr);
auto x_shp = x_shape_map[kShape];
auto y_shp = y_shape_map[kShape];
if (x_shp.size() != 2 || y_shp.size() != 2) {
MS_EXCEPTION(ValueError) << "For BatchMatMul, input x, y should have the same dimension size and should be greater"
<< "or equal to 3, while x size = " << x_shp.size() << ", y size = " << y_shp.size();
}
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(x_shp[1]), kGreaterEqual, 4, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(y_shp[1]), kGreaterEqual, 4, prim_name);
ShapeVector x_min_shape = x_shape_map[kMinShape];
ShapeVector x_max_shape = x_shape_map[kMaxShape];
ShapeVector y_min_shape = y_shape_map[kMinShape];
ShapeVector y_max_shape = y_shape_map[kMaxShape];
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
ret_shape.push_back(y_shp[0]);
ret_shape.push_back(x_shp[0]);
if (y_shape_ptr->IsDynamic()) {
ret_min_shape.push_back(y_min_shape[0]);
ret_max_shape.push_back(y_max_shape[0]);
} else {
ret_min_shape.push_back(y_shp[0]);
ret_max_shape.push_back(y_shp[0]);
}
if (x_shape_ptr->IsDynamic()) {
ret_min_shape.push_back(x_min_shape[0]);
ret_max_shape.push_back(x_max_shape[0]);
} else {
ret_min_shape.push_back(x_shp[0]);
ret_max_shape.push_back(x_shp[0]);
}
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr IOUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(IOU, prim::kPrimIOU, IOUInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 38
- 0
mindspore/core/ops/iou.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_IOU_H_
#define MINDSPORE_CORE_OPS_IOU_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
class MS_CORE_API IOU : public PrimitiveC {
public:
IOU() : PrimitiveC(prim::kPrimIOU->name()) { InitIOName({"x,y"}, {"output"}); }
~IOU() = default;
MS_DECLARE_PARENT(IOU, PrimitiveC);
void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IOU_H_

+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -278,6 +278,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe
from .bounding_box_encode import _bounding_box_encode_tbe
from .check_valid import _check_valid_tbe
from .iou import _iou_tbe
from .iou_ds import _iou_ds_tbe
from .arg_max import _arg_max_tbe
from .nms_with_mask import _nms_with_mask_tbe
from .sgd import _sgd_tbe


+ 39
- 0
mindspore/ops/_op_impl/tbe/iou_ds.py View File

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Iou op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

iou_op_info = TBERegOp("IOU") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("iou.so") \
.compute_cost(10) \
.kernel_name("iou") \
.partial_flag(True) \
.dynamic_shape(True)\
.attr("mode", "optional", "str", "all", "iou") \
.attr("eps", "optional", "float", "all", "1.0") \
.input(0, "bboxes", False, "required", "all") \
.input(1, "gtboxes", False, "required", "all") \
.output(0, "overlap", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.get_op_info()


@op_info_register(iou_op_info)
def _iou_ds_tbe():
"""Iou TBE register"""
return

+ 1
- 15
mindspore/ops/operations/other_ops.py View File

@@ -332,7 +332,7 @@ class CheckValid(PrimitiveWithInfer):
return mstype.bool_


class IOU(PrimitiveWithInfer):
class IOU(Primitive):
r"""
Calculates intersection over union for boxes.

@@ -384,20 +384,6 @@ class IOU(PrimitiveWithInfer):
raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.")
self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

def infer_shape(self, anchor_boxes, gt_boxes):
validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
iou = [gt_boxes[0], anchor_boxes[0]]
return iou

def infer_dtype(self, anchor_boxes, gt_boxes):
valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
return anchor_boxes


class Partial(Primitive):
"""


Loading…
Cancel
Save