|
|
|
@@ -1185,7 +1185,8 @@ class ArgMaxWithValue(PrimitiveWithInfer): |
|
|
|
"""init ArgMaxWithValue""" |
|
|
|
self.axis = axis |
|
|
|
self.keep_dims = keep_dims |
|
|
|
_check_infer_attr_reduce(axis, keep_dims, self.name) |
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], self.name) |
|
|
|
validator.check_value_type('axis', axis, [int], self.name) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
axis = self.axis |
|
|
|
|