|
|
|
@@ -861,7 +861,7 @@ class GatherV2(PrimitiveWithCheck): |
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) |
|
|
|
|
|
|
|
|
|
|
|
class SparseGatherV2(Gather): |
|
|
|
class SparseGatherV2(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Returns a slice of input tensor based on the specified indices and axis. |
|
|
|
|
|
|
|
@@ -893,6 +893,22 @@ class SparseGatherV2(Gather): |
|
|
|
[2. 55.]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Initialize index_select""" |
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) |
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, params, indices, axis): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) |
|
|
|
axis_v = axis['value'] |
|
|
|
validator.check_value_type('axis', axis_v, [int], self.name) |
|
|
|
rank = len(params['shape']) |
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Padding(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
|