|
|
|
@@ -589,7 +589,7 @@ class Squeeze(PrimitiveWithInfer): |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
class Transpose(PrimitiveWithInfer): |
|
|
|
class Transpose(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Permutes the dimensions of input tensor according to input permutation. |
|
|
|
|
|
|
|
@@ -621,32 +621,13 @@ class Transpose(PrimitiveWithInfer): |
|
|
|
"""Initialize Transpose""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) |
|
|
|
|
|
|
|
def __infer__(self, x, perm): |
|
|
|
x_shape = x['shape'] |
|
|
|
p_value = perm['value'] |
|
|
|
x_type = x['dtype'] |
|
|
|
validator.check_value_type("p_value", p_value, [tuple], self.name) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) |
|
|
|
|
|
|
|
if len(x_shape) != len(p_value): |
|
|
|
def check_shape(self, x, perm): |
|
|
|
validator.check_value_type("perm", perm, [tuple], self.name) |
|
|
|
if len(x) != len(perm): |
|
|
|
raise ValueError('The dimension of x and perm must be equal.') |
|
|
|
|
|
|
|
tmp = list(p_value) |
|
|
|
for i, dim in enumerate(p_value): |
|
|
|
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name) |
|
|
|
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name) |
|
|
|
tmp.remove(dim) |
|
|
|
if dim in tmp: |
|
|
|
raise ValueError('The value of perm is wrong.') |
|
|
|
|
|
|
|
out_shapes = [] |
|
|
|
for i in p_value: |
|
|
|
out_shapes.append(x_shape[i]) |
|
|
|
out = {'shape': tuple(out_shapes), |
|
|
|
'dtype': x['dtype'], |
|
|
|
'value': None} |
|
|
|
return out |
|
|
|
|
|
|
|
def check_dtype(self, x, perm): |
|
|
|
validator.check_subclass("x", x, mstype.tensor, self.name) |
|
|
|
|
|
|
|
class Unique(Primitive): |
|
|
|
""" |
|
|
|
|