From aff60404365357375564e291922fa291058bac29 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 19 May 2020 09:59:14 +0800 Subject: [PATCH 1/2] support vm for Argmax --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/arg_max.py | 38 +++++++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/arg_max.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 9e69cc7445..e758a20b35 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -77,6 +77,7 @@ static std::map tbe_func_adapter_map = { {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, {"pad", "pad_d"}, + {"argmax", "arg_max_d"}, {"space_to_batch", "space_to_batch_d"}, {"batch_to_space", "batch_to_space_d"}, {"resize_bilinear", "resize_bilinear_v2_d"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index c8aa30f2c2..aa604d18de 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -175,6 +175,7 @@ from .bounding_box_decode import _bounding_box_decode_tbe from .bounding_box_encode import _bounding_box_encode_tbe from .check_valid import _check_valid_tbe from .iou import _iou_tbe +from .arg_max import _arg_max_tbe from .nms_with_mask import nms_with_mask_op_info from .random_choice_with_mask import random_choice_with_mask_op_info from .sgd import sgd_op_info diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py new file mode 100644 index 0000000000..dbfe2ad923 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Argmax op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +arg_max_op_info = TBERegOp("Argmax") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("arg_max_d.so") \ + .compute_cost(10) \ + .kernel_name("arg_max_d") \ + .partial_flag(True) \ + .attr("dimension", "required", "int", "all") \ + .attr("dtype", "optional", "type", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(arg_max_op_info) +def _arg_max_tbe(): + """Argmax TBE register""" + return From 64a287a02953fd4e3ac407b53bbc3b07a34fecc3 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 19 May 2020 15:30:30 +0800 Subject: [PATCH 2/2] fixed arg_max --- mindspore/ops/_op_impl/tbe/arg_max.py | 4 ++-- mindspore/ops/operations/array_ops.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py index dbfe2ad923..b91df1cfb6 100644 --- a/mindspore/ops/_op_impl/tbe/arg_max.py +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -23,8 +23,8 @@ arg_max_op_info = TBERegOp("Argmax") \ .compute_cost(10) \ .kernel_name("arg_max_d") \ .partial_flag(True) \ - .attr("dimension", "required", "int", "all") \ - .attr("dtype", "optional", "type", "all") \ + .attr("axis", "required", "int", "all") \ + .attr("output_dtype", "optional", "type", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.I32_Default) \ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b27865c528..e8cdbe5e90 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): Args: axis (int): Axis on which Argmax operation applies. Default: -1. - output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and - `mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. + output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. + Default: `mindspore.dtype.int32`. Inputs: - **input_x** (Tensor) - Input tensor. @@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): Tensor, indices of the max value of input tensor across the axis. Examples: - >>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) + >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) >>> index = P.Argmax(output_type=mindspore.int32)(input_x) """ @prim_attr_register - def __init__(self, axis=-1, output_type=mstype.int64): + def __init__(self, axis=-1, output_type=mstype.int32): """init Argmax""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name)