|
|
|
@@ -33,6 +33,7 @@ from .. import signature as sig |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common._decorator import deprecated |
|
|
|
from ...common.parameter import Parameter |
|
|
|
from ...common.tensor import Tensor |
|
|
|
|
|
|
|
@@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck): |
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) |
|
|
|
|
|
|
|
|
|
|
|
def GatherV2(): |
|
|
|
"""Warning: This will be changed later""" |
|
|
|
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") |
|
|
|
return Gather() |
|
|
|
class GatherV2(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Same as operator Gather. GatherV2 will be deprecated in the future. |
|
|
|
Please use Gather instead. |
|
|
|
""" |
|
|
|
#deprecate_new_name = "Gather" |
|
|
|
|
|
|
|
@deprecated("1.1", "Gather", True) |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self): |
|
|
|
"""Initialize index_select""" |
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) |
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2]) |
|
|
|
|
|
|
|
def __check__(self, params, indices, axis): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) |
|
|
|
axis_v = axis['value'] |
|
|
|
validator.check_value_type('axis', axis_v, [int], self.name) |
|
|
|
rank = len(params['shape']) |
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) |
|
|
|
|
|
|
|
|
|
|
|
class SparseGatherV2(Gather): |
|
|
|
""" |
|
|
|
|