|
|
|
@@ -1185,6 +1185,7 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32) |
|
|
|
>>> index = P.Argmax(output_type=mindspore.int32)(input_x) |
|
|
|
1 |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
@@ -1192,7 +1193,7 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
"""Initialize Argmax""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output']) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name) |
|
|
|
validator.check_type_same({'output': output_type}, [mstype.int32], self.name) |
|
|
|
self.axis = axis |
|
|
|
self.add_prim_attr('output_type', output_type) |
|
|
|
|
|
|
|
@@ -1996,7 +1997,7 @@ class Select(PrimitiveWithInfer): |
|
|
|
and :math:`y`. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N)`. |
|
|
|
- **input_x** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. |
|
|
|
The condition tensor, decides which element is chosen. |
|
|
|
- **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`. |
|
|
|
The first input tensor. |
|
|
|
|