|
|
|
@@ -5017,7 +5017,7 @@ class Sort(PrimitiveWithInfer): |
|
|
|
return x_dtype, mstype.tensor_type(mstype.int32) |
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingLookup(PrimitiveWithInfer): |
|
|
|
class EmbeddingLookup(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Returns a slice of input tensor based on the specified indices. |
|
|
|
|
|
|
|
@@ -5063,28 +5063,13 @@ class EmbeddingLookup(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset'], |
|
|
|
outputs=['output']) |
|
|
|
|
|
|
|
def __infer__(self, params, indices, offset): |
|
|
|
def __check__(self, params, indices, offset): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) |
|
|
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) |
|
|
|
params_shp = params['shape'] |
|
|
|
if len(params_shp) > 2: |
|
|
|
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp)) |
|
|
|
out_shape = indices['shape'] + params_shp[1:] |
|
|
|
if 'max_shape' in indices: |
|
|
|
out_max_shape = indices['max_shape'] + params_shp[1:] |
|
|
|
else: |
|
|
|
out_max_shape = out_shape |
|
|
|
if 'min_shape' in indices: |
|
|
|
out_min_shape = indices['min_shape'] + params_shp[1:] |
|
|
|
else: |
|
|
|
out_min_shape = out_shape |
|
|
|
out = {'shape': out_shape, |
|
|
|
'dtype': params['dtype'], |
|
|
|
'value': None, |
|
|
|
'max_shape': out_max_shape, |
|
|
|
'min_shape': out_min_shape} |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class GatherD(PrimitiveWithInfer): |
|
|
|
|