diff --git a/mindspore/ops/_op_impl/tbe/arg_max.py b/mindspore/ops/_op_impl/tbe/arg_max.py index dbfe2ad923..b91df1cfb6 100644 --- a/mindspore/ops/_op_impl/tbe/arg_max.py +++ b/mindspore/ops/_op_impl/tbe/arg_max.py @@ -23,8 +23,8 @@ arg_max_op_info = TBERegOp("Argmax") \ .compute_cost(10) \ .kernel_name("arg_max_d") \ .partial_flag(True) \ - .attr("dimension", "required", "int", "all") \ - .attr("dtype", "optional", "type", "all") \ + .attr("axis", "required", "int", "all") \ + .attr("output_dtype", "optional", "type", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.I32_Default) \ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b27865c528..e8cdbe5e90 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): Args: axis (int): Axis on which Argmax operation applies. Default: -1. - output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and - `mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. + output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. + Default: `mindspore.dtype.int32`. Inputs: - **input_x** (Tensor) - Input tensor. @@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): Tensor, indices of the max value of input tensor across the axis. Examples: - >>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) + >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) >>> index = P.Argmax(output_type=mindspore.int32)(input_x) """ @prim_attr_register - def __init__(self, axis=-1, output_type=mstype.int64): + def __init__(self, axis=-1, output_type=mstype.int32): """init Argmax""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name)