Merge pull request !1242 from jiangjinsheng/vm_arg_maxtags/v0.3.0-alpha
| @@ -77,6 +77,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, | {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, | ||||
| {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, | {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, | ||||
| {"pad", "pad_d"}, | {"pad", "pad_d"}, | ||||
| {"argmax", "arg_max_d"}, | |||||
| {"space_to_batch", "space_to_batch_d"}, | {"space_to_batch", "space_to_batch_d"}, | ||||
| {"batch_to_space", "batch_to_space_d"}, | {"batch_to_space", "batch_to_space_d"}, | ||||
| {"resize_bilinear", "resize_bilinear_v2_d"}, | {"resize_bilinear", "resize_bilinear_v2_d"}, | ||||
| @@ -175,6 +175,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe | |||||
| from .bounding_box_encode import _bounding_box_encode_tbe | from .bounding_box_encode import _bounding_box_encode_tbe | ||||
| from .check_valid import _check_valid_tbe | from .check_valid import _check_valid_tbe | ||||
| from .iou import _iou_tbe | from .iou import _iou_tbe | ||||
| from .arg_max import _arg_max_tbe | |||||
| from .nms_with_mask import nms_with_mask_op_info | from .nms_with_mask import nms_with_mask_op_info | ||||
| from .random_choice_with_mask import random_choice_with_mask_op_info | from .random_choice_with_mask import random_choice_with_mask_op_info | ||||
| from .sgd import sgd_op_info | from .sgd import sgd_op_info | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Argmax op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| arg_max_op_info = TBERegOp("Argmax") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("arg_max_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("arg_max_d") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("axis", "required", "int", "all") \ | |||||
| .attr("output_dtype", "optional", "type", "all") \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(arg_max_op_info) | |||||
| def _arg_max_tbe(): | |||||
| """Argmax TBE register""" | |||||
| return | |||||
| @@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): | |||||
| Args: | Args: | ||||
| axis (int): Axis on which Argmax operation applies. Default: -1. | axis (int): Axis on which Argmax operation applies. Default: -1. | ||||
| output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and | |||||
| `mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. | |||||
| output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. | |||||
| Default: `mindspore.dtype.int32`. | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - Input tensor. | - **input_x** (Tensor) - Input tensor. | ||||
| @@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): | |||||
| Tensor, indices of the max value of input tensor across the axis. | Tensor, indices of the max value of input tensor across the axis. | ||||
| Examples: | Examples: | ||||
| >>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) | |||||
| >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) | |||||
| >>> index = P.Argmax(output_type=mindspore.int32)(input_x) | >>> index = P.Argmax(output_type=mindspore.int32)(input_x) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, axis=-1, output_type=mstype.int64): | |||||
| def __init__(self, axis=-1, output_type=mstype.int32): | |||||
| """init Argmax""" | """init Argmax""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||