| @@ -26,14 +26,36 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("input args size", SizeToLong(input_args.size()), kGreaterEqual, 1, | |||
| prim_name); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||
| auto x_shape = shape_map[kShape]; | |||
| auto min_shape = shape_map[kMinShape]; | |||
| auto max_shape = shape_map[kMaxShape]; | |||
| int64_t prod = 1; | |||
| size_t size = x_shape.size(); | |||
| for (size_t i = 1; i < size; i++) { | |||
| if (x_shape[i] == -1) { | |||
| prod = -1; | |||
| break; | |||
| } | |||
| prod = prod * x_shape[i]; | |||
| } | |||
| std::vector<int64_t> out_shape = {x_shape[0], prod}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| ShapeVector out_shape = {x_shape[0], prod}; | |||
| if (min_shape.empty() || max_shape.empty()) { | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| int64_t min_prod = 1; | |||
| size_t min_size = min_shape.size(); | |||
| for (size_t i = 1; i < min_size; i++) { | |||
| min_prod = min_prod * min_shape[i]; | |||
| } | |||
| ShapeVector out_min_shape = {min_shape[0], min_prod}; | |||
| int64_t max_prod = 1; | |||
| size_t max_size = max_shape.size(); | |||
| for (size_t i = 1; i < max_size; i++) { | |||
| max_prod = max_prod * max_shape[i]; | |||
| } | |||
| ShapeVector out_max_shape = {max_shape[0], max_prod}; | |||
| return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| @@ -49,9 +71,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||
| AbstractBasePtr FlattenInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| auto infer_type = InferShape(primitive, input_args); | |||
| auto infer_shape = InferType(primitive, input_args); | |||
| return abstract::MakeAbstract(infer_type, infer_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameFlatten, Flatten); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Flatten, prim::kPrimFlatten, FlattenInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -310,6 +310,7 @@ from .log1p import _log1p_tbe | |||
| from .resize_bilinear import _resize_bilinear_tbe | |||
| from .resize_bilinear_grad import _resize_bilinear_grad_tbe | |||
| from .flatten import _flatten_tbe | |||
| from .flatten_ds import _flatten_ds_tbe | |||
| from .roi_align import _roi_align_tbe | |||
| from .roi_align_grad import _roi_align_grad_tbe | |||
| from .bounding_box_decode import _bounding_box_decode_tbe | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Flatten op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| flatten_ds_op_info = TBERegOp("Flatten") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("flatten.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("flatten") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("axis", "optional", "int", "all", "1") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(flatten_ds_op_info) | |||
| def _flatten_ds_tbe(): | |||
| """Flatten TBE register""" | |||
| return | |||
| @@ -16,8 +16,7 @@ | |||
| """Operators for nn.""" | |||
| import math | |||
| import operator | |||
| from functools import reduce, partial | |||
| from functools import partial | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import _check_3d_int_or_tuple | |||
| @@ -140,7 +139,7 @@ class CeLU(Primitive): | |||
| self.add_prim_attr('alpha2', self.alpha2) | |||
| class Flatten(PrimitiveWithInfer): | |||
| class Flatten(Primitive): | |||
| r""" | |||
| Flattens a tensor without changing its batch size on the 0-th axis. | |||
| @@ -170,15 +169,6 @@ class Flatten(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| pass | |||
| def infer_shape(self, input_x): | |||
| validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name) | |||
| prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) | |||
| return input_x[0], prod | |||
| def infer_dtype(self, input_x): | |||
| validator.check_subclass("input_x", input_x, mstype.tensor, self.name) | |||
| return input_x | |||
| class AdaptiveAvgPool2D(PrimitiveWithInfer): | |||
| r""" | |||