|
|
|
@@ -807,7 +807,7 @@ class GatherV2(PrimitiveWithCheck): |
|
|
|
def __check__(self, params, indices, axis): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.int_], self.name) |
|
|
|
|
|
|
|
|
|
|
|
class SparseGatherV2(GatherV2): |
|
|
|
|