|
|
|
@@ -951,8 +951,8 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Args: |
|
|
|
axis (int): Axis on which Argmax operation applies. Default: -1. |
|
|
|
output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` and |
|
|
|
`mindspore.dtype.int64`. Default: `mindspore.dtype.int64`. |
|
|
|
output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`. |
|
|
|
Default: `mindspore.dtype.int32`. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor) - Input tensor. |
|
|
|
@@ -961,12 +961,12 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
Tensor, indices of the max value of input tensor across the axis. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.array([2.0, 3.1, 1.2])) |
|
|
|
>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) |
|
|
|
>>> index = P.Argmax(output_type=mindspore.int32)(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, axis=-1, output_type=mstype.int64): |
|
|
|
def __init__(self, axis=-1, output_type=mstype.int32): |
|
|
|
"""init Argmax""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output']) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
|