Browse Source

[feat][assistant][I48OBK] add dynamic shape for Flatten operation

tags/v1.6.0
Victor Chen 4 years ago
parent
commit
0d7a9731c4
4 changed files with 78 additions and 18 deletions
  1. +29
    -6
      mindspore/core/ops/flatten.cc
  2. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  3. +46
    -0
      mindspore/ops/_op_impl/tbe/flatten_ds.py
  4. +2
    -12
      mindspore/ops/operations/nn_ops.py

+ 29
- 6
mindspore/core/ops/flatten.cc View File

@@ -26,14 +26,36 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
// Shape inference for Flatten: collapses all axes after axis 0 into one,
// producing a rank-2 output {batch, prod(dims[1:])}. Dynamic shapes are
// supported via min/max shape propagation.
// NOTE(review): this span is diff residue from a commit page — the next line
// (single-expression x_shape lookup) is the REMOVED old line; the shape_map
// lines after it are its replacement. Only one variant belongs in the file.
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input args size", SizeToLong(input_args.size()), kGreaterEqual, 1,
prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];  // empty unless the input is dynamic
auto max_shape = shape_map[kMaxShape];  // empty unless the input is dynamic
// Product of all trailing dims; -1 (unknown dim) poisons the whole product.
int64_t prod = 1;
size_t size = x_shape.size();
for (size_t i = 1; i < size; i++) {
if (x_shape[i] == -1) {
prod = -1;
break;
}
prod = prod * x_shape[i];
}
// NOTE(review): diff residue again — the std::vector line + early return are
// the REMOVED old code; the ShapeVector path below is the new implementation.
std::vector<int64_t> out_shape = {x_shape[0], prod};
return std::make_shared<abstract::Shape>(out_shape);
ShapeVector out_shape = {x_shape[0], prod};
// Static shape: no min/max bounds available, return the plain shape.
if (min_shape.empty() || max_shape.empty()) {
return std::make_shared<abstract::Shape>(out_shape);
}
// Dynamic shape: fold min bounds the same way (no -1 check needed — bounds
// are concrete); result pairs with out_max_shape below.
int64_t min_prod = 1;
size_t min_size = min_shape.size();
for (size_t i = 1; i < min_size; i++) {
min_prod = min_prod * min_shape[i];
}
ShapeVector out_min_shape = {min_shape[0], min_prod};
int64_t max_prod = 1;
size_t max_size = max_shape.size();
for (size_t i = 1; i < max_size; i++) {
max_prod = max_prod * max_shape[i];
}
ShapeVector out_max_shape = {max_shape[0], max_prod};
// Shape(shape, min_shape, max_shape) carries the dynamic bounds downstream.
return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -49,9 +71,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

// Joint inference entry point for Flatten: combines the inferred shape
// (with dynamic min/max bounds, if any) and dtype into one abstract value.
// NOTE(review): the diff residue kept both the old single-return path and the
// new MakeAbstract path in the same body (unreachable duplicate code); only
// the new path is retained here. The new code also had its two locals named
// backwards (infer_type held the InferShape result) — renamed for clarity;
// the positional argument order to MakeAbstract (shape first, then type) is
// unchanged, so behavior is identical.
AbstractBasePtr FlattenInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  auto infer_shape = InferShape(primitive, input_args);
  auto infer_type = InferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_C(kNameFlatten, Flatten);
REGISTER_PRIMITIVE_EVAL_IMPL(Flatten, prim::kPrimFlatten, FlattenInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -310,6 +310,7 @@ from .log1p import _log1p_tbe
from .resize_bilinear import _resize_bilinear_tbe
from .resize_bilinear_grad import _resize_bilinear_grad_tbe
from .flatten import _flatten_tbe
from .flatten_ds import _flatten_ds_tbe
from .roi_align import _roi_align_tbe
from .roi_align_grad import _roi_align_grad_tbe
from .bounding_box_decode import _bounding_box_decode_tbe


+ 46
- 0
mindspore/ops/_op_impl/tbe/flatten_ds.py View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE registration metadata for the dynamic-shape ("ds") Flatten kernel.
# Identical registration to the static flatten op, with dynamic_shape(True)
# so the framework dispatches dynamic-shape graphs to this implementation.
flatten_ds_op_info = TBERegOp("Flatten") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("flatten.so") \
    .compute_cost(10) \
    .kernel_name("flatten") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .attr("axis", "optional", "int", "all", "1") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(flatten_ds_op_info)
def _flatten_ds_tbe():
    """Register the dynamic-shape Flatten TBE operator info.

    The decorator performs the registration as a side effect; the function
    body is intentionally empty.
    """
    return

+ 2
- 12
mindspore/ops/operations/nn_ops.py View File

@@ -16,8 +16,7 @@
"""Operators for nn."""

import math
import operator
from functools import reduce, partial
from functools import partial
import numpy as np
from mindspore import log as logger
from mindspore._checkparam import _check_3d_int_or_tuple
@@ -140,7 +139,7 @@ class CeLU(Primitive):
self.add_prim_attr('alpha2', self.alpha2)


class Flatten(PrimitiveWithInfer):
class Flatten(Primitive):
r"""
Flattens a tensor without changing its batch size on the 0-th axis.

@@ -170,15 +169,6 @@ class Flatten(PrimitiveWithInfer):
def __init__(self):
pass

def infer_shape(self, input_x):
validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:])
return input_x[0], prod

def infer_dtype(self, input_x):
validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
return input_x


class AdaptiveAvgPool2D(PrimitiveWithInfer):
r"""


Loading…
Cancel
Save