| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/sort.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t input_num = 1; | |||
| (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num, | |||
| prim_name); | |||
| CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0); | |||
| auto x = input_args[0]->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| auto shape_element = x->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape_element); | |||
| return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{shape_element, shape_element}); | |||
| } | |||
| TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(infer_type); | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; | |||
| auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, prim->name()); | |||
| std::vector<TypePtr> type_tuple; | |||
| type_tuple.push_back(type); | |||
| type_tuple.push_back(kInt32); | |||
| return std::make_shared<Tuple>(type_tuple); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto infertype = SortInferType(primitive, input_args); | |||
| auto infershape = SortInferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(infershape, infertype); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Sort, prim::kPrimSort, SortInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_SORT_H_ | |||
| #define MINDSPORE_CORE_OPS_SORT_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameSort = "Sort"; | |||
| class Sort : public PrimitiveC { | |||
| public: | |||
| Sort() : PrimitiveC(kNameSort) { InitIOName({"x"}, {"y1", "y2"}); } | |||
| ~Sort() = default; | |||
| MS_DECLARE_PARENT(Sort, PrimitiveC); | |||
| }; | |||
| AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_SORT_H_ | |||
| @@ -266,6 +266,7 @@ from .depth_to_space import _depth_to_space_tbe | |||
| from .space_to_depth import _space_to_depth_tbe | |||
| from .extract_image_patches import _extract_image_patches_tbe | |||
| from .sort import _sort_tbe | |||
| from .sort_ds import _sort_ds_tbe | |||
| from .floor import _floor_tbe | |||
| from .ceil import _ceil_tbe | |||
| from .log1p import _log1p_tbe | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Sort op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sort_op_info = TBERegOp("Sort") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sort.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sort") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("axis", "optional", "int", "all", "-1") \ | |||
| .attr("descending", "optional", "bool", "all", "false") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y1", False, "required", "all") \ | |||
| .output(1, "y2", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(sort_op_info) | |||
| def _sort_ds_tbe(): | |||
| """Sort TBE register""" | |||
| return | |||
| @@ -5868,7 +5868,7 @@ class TransShape(PrimitiveWithInfer): | |||
| 'value': None} | |||
| class Sort(PrimitiveWithInfer): | |||
| class Sort(Primitive): | |||
| """ | |||
| Sorts the elements of the input tensor along a given dimension in ascending order by value. | |||
| @@ -5877,6 +5877,10 @@ class Sort(PrimitiveWithInfer): | |||
| descending (bool): Controls the sorting order. If descending is True then the elements | |||
| are sorted in descending order by value. Default: False. | |||
| .. warning:: | |||
| Currently, only the data type of Float16 is supported. If use Float32, it may cause loss | |||
| of accuracy. | |||
| Inputs: | |||
| - **x** (Tensor) - The input to sort, with float16 or float32 data type. | |||
| The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. | |||
| @@ -5906,19 +5910,12 @@ class Sort(PrimitiveWithInfer): | |||
| [2, 0, 1], | |||
| [0, 1, 2]])) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, axis=-1, descending=False): | |||
| """Initialize Sort""" | |||
| self.axis = validator.check_value_type("axis", axis, [int], self.name) | |||
| self.descending = validator.check_value_type("descending", descending, [bool], self.name) | |||
| def infer_shape(self, x_shape): | |||
| return x_shape, x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name) | |||
| return x_dtype, mstype.tensor_type(mstype.int32) | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y1', 'y2']) | |||
| class EmbeddingLookup(PrimitiveWithCheck): | |||