|
|
|
@@ -24,7 +24,7 @@ import itertools |
|
|
|
import numbers |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ..._checkparam import ParamValidator as validator |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.tensor import Tensor |
|
|
|
@@ -32,12 +32,12 @@ from ..operations.math_ops import _infer_shape_reduce |
|
|
|
from .._utils import _get_concat_offset |
|
|
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register |
|
|
|
|
|
|
|
def _check_infer_attr_reduce(axis, keep_dims): |
|
|
|
validator.check_type('keep_dims', keep_dims, [bool]) |
|
|
|
validator.check_type('axis', axis, [int, tuple]) |
|
|
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name): |
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) |
|
|
|
validator.check_value_type('axis', axis, [int, tuple], prim_name) |
|
|
|
if isinstance(axis, tuple): |
|
|
|
for index, value in enumerate(axis): |
|
|
|
validator.check_type('axis[%d]' % index, value, [int]) |
|
|
|
validator.check_value_type('axis[%d]' % index, value, [int], prim_name) |
|
|
|
|
|
|
|
|
|
|
|
class ExpandDims(PrimitiveWithInfer): |
|
|
|
@@ -74,13 +74,11 @@ class ExpandDims(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) |
|
|
|
|
|
|
|
def __infer__(self, x, axis): |
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) |
|
|
|
x_shape = list(x['shape']) |
|
|
|
axis_v = axis['value'] |
|
|
|
rank = len(x_shape) |
|
|
|
validator.check_const_input('axis', axis_v) |
|
|
|
validator.check_type("axis", axis_v, [int]) |
|
|
|
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH) |
|
|
|
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name) |
|
|
|
if axis_v < 0: |
|
|
|
axis_v = rank + 1 + axis_v |
|
|
|
x_shape.insert(axis_v, 1) |
|
|
|
@@ -110,7 +108,7 @@ class DType(PrimitiveWithInfer): |
|
|
|
"""init DType""" |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) |
|
|
|
out = {'shape': (), |
|
|
|
'dtype': mstype.type_type, |
|
|
|
'value': x['dtype'].element_type()} |
|
|
|
@@ -144,19 +142,17 @@ class SameTypeShape(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def __call__(self, x, y): |
|
|
|
"""run in PyNative mode""" |
|
|
|
if x.dtype() != y.dtype(): |
|
|
|
raise TypeError(f"The {x} and {y} should be same dtype.") |
|
|
|
if x.shape() != y.shape(): |
|
|
|
raise TypeError(f"The {x} and {y} should have same shape.") |
|
|
|
validator.check_subclass('x', x.dtype(), mstype.tensor, self.name) |
|
|
|
validator.check_subclass('y', y.dtype(), mstype.tensor, self.name) |
|
|
|
validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError) |
|
|
|
validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name) |
|
|
|
return x |
|
|
|
|
|
|
|
def __infer__(self, x, y): |
|
|
|
if x['dtype'] != y['dtype']: |
|
|
|
raise TypeError(f"The {x} and {y} should be same dtype," |
|
|
|
f" but got {x['dtype']} {y['dtype']}.") |
|
|
|
if x['shape'] != y['shape']: |
|
|
|
raise ValueError(f"The {x} and {y} should be same shape," |
|
|
|
f" but got {x['shape']} {y['shape']}.") |
|
|
|
validator.check_subclass('x', x['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_subclass('y', y['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError) |
|
|
|
validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
@@ -191,8 +187,8 @@ class Cast(PrimitiveWithInfer): |
|
|
|
src_type = x['dtype'] |
|
|
|
dst_type = t['value'] |
|
|
|
|
|
|
|
validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number]) |
|
|
|
validator.check_subclass("type", dst_type, mstype.number, with_type_of=False) |
|
|
|
validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name) |
|
|
|
validator.check_subclass("type", dst_type, mstype.number, self.name) |
|
|
|
|
|
|
|
if isinstance(src_type, type(mstype.tensor)): |
|
|
|
src_type = x['dtype'].element_type() |
|
|
|
@@ -238,8 +234,8 @@ class IsSubClass(PrimitiveWithInfer): |
|
|
|
sub_type_t = sub_type['value'] |
|
|
|
type_v = type_['value'] |
|
|
|
|
|
|
|
validator.check_type("sub_type", sub_type_t, [mstype.Type]) |
|
|
|
validator.check_type("type_", type_v, [mstype.Type]) |
|
|
|
validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name) |
|
|
|
validator.check_value_type("type_", type_v, [mstype.Type], self.name) |
|
|
|
|
|
|
|
value = mstype.issubclass_(sub_type_t, type_v) |
|
|
|
|
|
|
|
@@ -273,8 +269,8 @@ class IsInstance(PrimitiveWithInfer): |
|
|
|
sub_type_t = inst['dtype'] |
|
|
|
type_v = type_['value'] |
|
|
|
|
|
|
|
validator.check_const_input("inst", inst['value']) |
|
|
|
validator.check_type("type_", type_v, [mstype.Type]) |
|
|
|
validator.check_const_input("inst", inst['value'], self.name) |
|
|
|
validator.check_value_type("type_", type_v, [mstype.Type], self.name) |
|
|
|
|
|
|
|
value = mstype.issubclass_(sub_type_t, type_v) |
|
|
|
|
|
|
|
@@ -316,14 +312,13 @@ class Reshape(PrimitiveWithInfer): |
|
|
|
def __infer__(self, x, shape): |
|
|
|
shape_v = shape['value'] |
|
|
|
x_shp = x['shape'] |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_const_input("shape", shape_v) |
|
|
|
validator.check_type("shape", shape_v, [tuple]) |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_value_type("shape", shape_v, [tuple], self.name) |
|
|
|
shape_v = list(shape_v) |
|
|
|
neg_index = -1 |
|
|
|
dim_prod = 1 |
|
|
|
for i, shp_i in enumerate(shape_v): |
|
|
|
validator.check_type("shape[%d]" % i, shp_i, [int]) |
|
|
|
validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name) |
|
|
|
if shp_i == -1: |
|
|
|
if neg_index != -1: |
|
|
|
raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.') |
|
|
|
@@ -332,7 +327,7 @@ class Reshape(PrimitiveWithInfer): |
|
|
|
dim_prod *= shp_i |
|
|
|
arr_prod = np.prod(x_shp) |
|
|
|
if dim_prod <= 0 or arr_prod % dim_prod != 0: |
|
|
|
raise ValueError(f'The product of shape should > 0 and' |
|
|
|
raise ValueError(f'For \'{self.name}\' the product of shape should > 0 and' |
|
|
|
f' can be divided by prod of input {arr_prod},' |
|
|
|
f' but shape {shape}, product of shape {dim_prod}.') |
|
|
|
|
|
|
|
@@ -340,7 +335,7 @@ class Reshape(PrimitiveWithInfer): |
|
|
|
shape_v[neg_index] = int(arr_prod / dim_prod) |
|
|
|
dim_prod *= shape_v[neg_index] |
|
|
|
if dim_prod != arr_prod: |
|
|
|
raise ValueError(f'The shape arg for reshape must match array''s size' |
|
|
|
raise ValueError(f'For \'{self.name}\' The shape arg for reshape must match array''s size' |
|
|
|
f' input shape {arr_prod}, shape {dim_prod}.') |
|
|
|
|
|
|
|
value = None |
|
|
|
@@ -406,10 +401,10 @@ class Squeeze(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=()): |
|
|
|
"""init Squeeze""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output']) |
|
|
|
validator.check_type('axis', axis, [int, tuple]) |
|
|
|
validator.check_value_type('axis', axis, [int, tuple], self.name) |
|
|
|
if isinstance(axis, tuple): |
|
|
|
for item in axis: |
|
|
|
validator.check_type("item", item, [int]) |
|
|
|
for idx, item in enumerate(axis): |
|
|
|
validator.check_value_type("axis[%d]" % idx, item, [int], self.name) |
|
|
|
else: |
|
|
|
self.axis = (axis,) |
|
|
|
self.add_prim_attr("axis", (axis,)) |
|
|
|
@@ -422,14 +417,14 @@ class Squeeze(PrimitiveWithInfer): |
|
|
|
ret = [d for d in x_shape if d != 1] |
|
|
|
else: |
|
|
|
for a in axis: |
|
|
|
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH) |
|
|
|
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name) |
|
|
|
if x_shape[a] != 1: |
|
|
|
raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') |
|
|
|
ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] |
|
|
|
return ret |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("x", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("x", x_dtype, mstype.tensor, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -467,14 +462,13 @@ class Transpose(PrimitiveWithInfer): |
|
|
|
if len(x_shape) != len(p_value): |
|
|
|
raise ValueError('The dimension of x and perm must be equal.') |
|
|
|
|
|
|
|
validator.check_const_input("perm", p_value) |
|
|
|
validator.check_type("p_value", p_value, [tuple]) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor) |
|
|
|
validator.check_value_type("p_value", p_value, [tuple], self.name) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) |
|
|
|
|
|
|
|
tmp = list(p_value) |
|
|
|
for i, dim in enumerate(p_value): |
|
|
|
validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE) |
|
|
|
validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT) |
|
|
|
validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name) |
|
|
|
validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name) |
|
|
|
tmp.remove(dim) |
|
|
|
if dim in tmp: |
|
|
|
raise ValueError('The value of perm is wrong.') |
|
|
|
@@ -517,15 +511,13 @@ class GatherV2(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) |
|
|
|
|
|
|
|
def __infer__(self, params, indices, axis): |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("indices", indices['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("axis", axis['dtype'], mstype.int_) |
|
|
|
validator.check_typename("element of indices", indices['dtype'], mstype.int_type) |
|
|
|
validator.check_const_input("axis", axis['value']) |
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) |
|
|
|
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) |
|
|
|
axis_v = axis['value'] |
|
|
|
params_shp = params['shape'] |
|
|
|
rank = len(params_shp) |
|
|
|
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT) |
|
|
|
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) |
|
|
|
if axis_v < 0: |
|
|
|
axis_v += rank |
|
|
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] |
|
|
|
@@ -564,19 +556,20 @@ class Split(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, axis=0, output_num=1): |
|
|
|
"""init Split""" |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_type("output_num", output_num, [int]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
validator.check_value_type("output_num", output_num, [int], self.name) |
|
|
|
self.axis = axis |
|
|
|
self.output_num = output_num |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) |
|
|
|
x_shape = list(x['shape']) |
|
|
|
dim = len(x_shape) |
|
|
|
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) |
|
|
|
validator.check_integer("output_num", self.output_num, 0, Rel.GT) |
|
|
|
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) |
|
|
|
validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name) |
|
|
|
output_valid_check = x_shape[self.axis] % self.output_num |
|
|
|
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ) |
|
|
|
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ, |
|
|
|
self.name) |
|
|
|
x_shape[self.axis] = int(x_shape[self.axis] / self.output_num) |
|
|
|
out_shapes = [] |
|
|
|
out_dtypes = [] |
|
|
|
@@ -615,7 +608,7 @@ class Rank(PrimitiveWithInfer): |
|
|
|
"""init Rank""" |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) |
|
|
|
out = {'shape': None, |
|
|
|
'dtype': None, |
|
|
|
'value': len(x['shape'])} |
|
|
|
@@ -647,15 +640,14 @@ class TruncatedNormal(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, seed=0, dtype=mstype.float32): |
|
|
|
"""init TruncatedNormal""" |
|
|
|
validator.check_type('seed', seed, [int]) |
|
|
|
validator.check_typename('dtype', dtype, mstype.number_type) |
|
|
|
validator.check_value_type('seed', seed, [int], self.name) |
|
|
|
validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name) |
|
|
|
|
|
|
|
def __infer__(self, shape): |
|
|
|
shape_value = shape['value'] |
|
|
|
validator.check_const_input("shape", shape_value) |
|
|
|
validator.check_type("shape", shape_value, [tuple]) |
|
|
|
validator.check_value_type("shape", shape_value, [tuple], self.name) |
|
|
|
for i, value in enumerate(shape_value): |
|
|
|
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT) |
|
|
|
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name) |
|
|
|
out = {'shape': shape_value, |
|
|
|
'dtype': mstype.tensor_type(self.dtype), |
|
|
|
'value': None} |
|
|
|
@@ -687,7 +679,7 @@ class Size(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
size = 1 |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) |
|
|
|
shp = x['shape'] |
|
|
|
if not shp: |
|
|
|
size = 0 |
|
|
|
@@ -723,25 +715,20 @@ class Fill(PrimitiveWithInfer): |
|
|
|
"""init Fill""" |
|
|
|
|
|
|
|
def __infer__(self, dtype, dims, x): |
|
|
|
validator.check_const_input("type", dtype['value']) |
|
|
|
validator.check_const_input("shape", dims['value']) |
|
|
|
validator.check_const_input("value", x['value']) |
|
|
|
validator.check_type("shape", dims['value'], [tuple]) |
|
|
|
validator.check_type("value", x['value'], [numbers.Number, bool]) |
|
|
|
for item in dims['value']: |
|
|
|
validator.check_type("item", item, [int]) |
|
|
|
validator.check_integer("item", item, 0, Rel.GT) |
|
|
|
x_dtype = dtype['value'] |
|
|
|
validator.check_value_type("shape", dims['value'], [tuple], self.name) |
|
|
|
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) |
|
|
|
for idx, item in enumerate(dims['value']): |
|
|
|
validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name) |
|
|
|
valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64, |
|
|
|
mstype.uint8, mstype.uint32, mstype.uint64, |
|
|
|
mstype.float16, mstype.float32, mstype.float64] |
|
|
|
validator.check_typename("value", x_dtype, valid_types) |
|
|
|
x_nptype = mstype.dtype_to_nptype(x_dtype) |
|
|
|
validator.check_type_same({"value": dtype['value']}, valid_types, self.name) |
|
|
|
x_nptype = mstype.dtype_to_nptype(dtype['value']) |
|
|
|
ret = np.full(dims['value'], x['value'], x_nptype) |
|
|
|
out = { |
|
|
|
'value': Tensor(ret), |
|
|
|
'shape': dims['value'], |
|
|
|
'dtype': x_dtype, |
|
|
|
'dtype': x['dtype'], |
|
|
|
} |
|
|
|
return out |
|
|
|
|
|
|
|
@@ -772,8 +759,7 @@ class OnesLike(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("x", x_dtype, mstype.tensor) |
|
|
|
validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,)) |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -804,8 +790,7 @@ class ZerosLike(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("x", x_dtype, mstype.tensor) |
|
|
|
validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,)) |
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -830,14 +815,13 @@ class TupleToArray(PrimitiveWithInfer): |
|
|
|
"""init TupleToArray""" |
|
|
|
|
|
|
|
def infer_value(self, x): |
|
|
|
validator.check_const_input("x", x) |
|
|
|
validator.check_type("x", x, [tuple]) |
|
|
|
validator.check("size of x", len(x), '', 0, Rel.GT) |
|
|
|
validator.check_value_type("x", x, [tuple], self.name) |
|
|
|
validator.check("size of x", len(x), '', 0, Rel.GT, self.name) |
|
|
|
dtype = type(x[0]) |
|
|
|
for i, item in enumerate(x): |
|
|
|
validator.check_type(f"x[{i}]", item, [numbers.Number]) |
|
|
|
validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name) |
|
|
|
if not all(isinstance(item, dtype) for item in x): |
|
|
|
raise TypeError("All elements of input x must be have same type.") |
|
|
|
raise TypeError("For \'{self.name}\' all elements of input x must be have same type.") |
|
|
|
if isinstance(x[0], int): |
|
|
|
ret = np.array(x, np.int32) |
|
|
|
else: |
|
|
|
@@ -867,8 +851,7 @@ class ScalarToArray(PrimitiveWithInfer): |
|
|
|
pass |
|
|
|
|
|
|
|
def infer_value(self, x): |
|
|
|
validator.check_const_input("x", x) |
|
|
|
validator.check_type("x", x, [int, float]) |
|
|
|
validator.check_value_type("x", x, [int, float], self.name) |
|
|
|
if isinstance(x, int): |
|
|
|
ret = np.array(x, np.int32) |
|
|
|
else: |
|
|
|
@@ -899,9 +882,8 @@ class ScalarToTensor(PrimitiveWithInfer): |
|
|
|
pass |
|
|
|
|
|
|
|
def infer_value(self, x, dtype=mstype.float32): |
|
|
|
validator.check_const_input("x", x) |
|
|
|
validator.check_type("x", x, [int, float]) |
|
|
|
validator.check_subclass("dtype", dtype, mstype.number, with_type_of=False) |
|
|
|
validator.check_value_type("x", x, [int, float], self.name) |
|
|
|
validator.check_subclass("dtype", dtype, mstype.number, self.name) |
|
|
|
data_type = mstype.dtype_to_nptype(dtype) |
|
|
|
return Tensor(np.array(x, data_type)) |
|
|
|
|
|
|
|
@@ -943,15 +925,14 @@ class InvertPermutation(PrimitiveWithInfer): |
|
|
|
def __infer__(self, x): |
|
|
|
x_shp = x['shape'] |
|
|
|
x_value = x['value'] |
|
|
|
validator.check_const_input("shape", x_shp) |
|
|
|
validator.check_type("shape", x_shp, [tuple]) |
|
|
|
validator.check_value_type("shape", x_shp, [tuple], self.name) |
|
|
|
z = [x_value[i] for i in range(len(x_value))] |
|
|
|
z.sort() |
|
|
|
|
|
|
|
y = [None]*len(x_value) |
|
|
|
for i, value in enumerate(x_value): |
|
|
|
validator.check_type("input[%d]" % i, value, [int]) |
|
|
|
validator.check(f'value', z[i], f'index', i) |
|
|
|
validator.check_value_type("input[%d]" % i, value, [int], self.name) |
|
|
|
validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) |
|
|
|
y[value] = i |
|
|
|
z.append(value) |
|
|
|
return {'shape': x_shp, |
|
|
|
@@ -986,8 +967,8 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=-1, output_type=mstype.int64): |
|
|
|
"""init Argmax""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output']) |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_typename('output_type', output_type, [mstype.int32, mstype.int64]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name) |
|
|
|
self.axis = axis |
|
|
|
self.add_prim_attr('output_type', output_type) |
|
|
|
|
|
|
|
@@ -996,14 +977,13 @@ class Argmax(PrimitiveWithInfer): |
|
|
|
if axis is None: |
|
|
|
axis = 0 |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) |
|
|
|
axis = axis + x_rank if axis < 0 else axis |
|
|
|
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] |
|
|
|
return ouput_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16]) |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
|
return mstype.tensor_type(self.output_type) |
|
|
|
|
|
|
|
|
|
|
|
@@ -1035,7 +1015,7 @@ class Argmin(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=-1, output_type=mstype.int64): |
|
|
|
"""init Argmin""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['output']) |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
self.axis = axis |
|
|
|
self.add_prim_attr('output_type', output_type) |
|
|
|
|
|
|
|
@@ -1044,13 +1024,13 @@ class Argmin(PrimitiveWithInfer): |
|
|
|
if axis is None: |
|
|
|
axis = 0 |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) |
|
|
|
axis = axis + x_rank if axis < 0 else axis |
|
|
|
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] |
|
|
|
return ouput_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
|
return mstype.tensor_type(self.output_type) |
|
|
|
|
|
|
|
|
|
|
|
@@ -1087,17 +1067,17 @@ class ArgMaxWithValue(PrimitiveWithInfer): |
|
|
|
"""init ArgMaxWithValue""" |
|
|
|
self.axis = axis |
|
|
|
self.keep_dims = keep_dims |
|
|
|
_check_infer_attr_reduce(axis, keep_dims) |
|
|
|
_check_infer_attr_reduce(axis, keep_dims, self.name) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
axis = self.axis |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) |
|
|
|
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) |
|
|
|
return ouput_shape, ouput_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
|
return mstype.tensor_type(mstype.int32), x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -1133,17 +1113,17 @@ class ArgMinWithValue(PrimitiveWithInfer): |
|
|
|
"""init ArgMinWithValue""" |
|
|
|
self.axis = axis |
|
|
|
self.keep_dims = keep_dims |
|
|
|
_check_infer_attr_reduce(axis, keep_dims) |
|
|
|
_check_infer_attr_reduce(axis, keep_dims, self.name) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
axis = self.axis |
|
|
|
x_rank = len(x_shape) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) |
|
|
|
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) |
|
|
|
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) |
|
|
|
return ouput_shape, ouput_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) |
|
|
|
return mstype.tensor_type(mstype.int32), x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -1183,13 +1163,11 @@ class Tile(PrimitiveWithInfer): |
|
|
|
def __infer__(self, x, multiples): |
|
|
|
multiples_v = multiples['value'] |
|
|
|
x_shp = x['shape'] |
|
|
|
validator.check_const_input("shape", multiples_v) |
|
|
|
validator.check_type("shape", multiples_v, [tuple]) |
|
|
|
validator.check_value_type("shape", multiples_v, [tuple], self.name) |
|
|
|
for i, multiple in enumerate(multiples_v): |
|
|
|
validator.check_type("multiples[%d]" % i, multiple, [int]) |
|
|
|
validator.check_typename('x', x['dtype'], |
|
|
|
[mstype.int16, mstype.int32, mstype.bool_, |
|
|
|
mstype.float16, mstype.float32]) |
|
|
|
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) |
|
|
|
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32] |
|
|
|
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name) |
|
|
|
len_sub = len(multiples_v) - len(x_shp) |
|
|
|
multiples_w = None |
|
|
|
if len_sub == 0: |
|
|
|
@@ -1199,7 +1177,8 @@ class Tile(PrimitiveWithInfer): |
|
|
|
x_shp.insert(0, 1) |
|
|
|
multiples_w = multiples_v |
|
|
|
elif len_sub < 0: |
|
|
|
raise ValueError("The length of multiples can not be smaller than the length of dimension in input_x.") |
|
|
|
raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than ' |
|
|
|
f'the length of dimension in input_x.') |
|
|
|
for i, a in enumerate(multiples_w): |
|
|
|
x_shp[i] *= a |
|
|
|
value = None |
|
|
|
@@ -1246,23 +1225,23 @@ class UnsortedSegmentSum(PrimitiveWithInfer): |
|
|
|
def __infer__(self, x, segment_ids, num_segments): |
|
|
|
x_type = x['dtype'] |
|
|
|
x_shp = x['shape'] |
|
|
|
validator.check_subclass("input_x", x_type, mstype.tensor) |
|
|
|
validator.check_type("x_shape", x_shp, [list]) |
|
|
|
validator.check_subclass("input_x", x_type, mstype.tensor, self.name) |
|
|
|
validator.check_value_type("x_shape", x_shp, [list], self.name) |
|
|
|
x_shp_len = len(x_shp) |
|
|
|
validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT) |
|
|
|
validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name) |
|
|
|
segment_ids_shp = segment_ids['shape'] |
|
|
|
segment_ids_type = segment_ids['dtype'] |
|
|
|
validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor) |
|
|
|
validator.check_type("segment_ids", segment_ids_shp, [list]) |
|
|
|
validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name) |
|
|
|
validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name) |
|
|
|
segment_ids_shp_len = len(segment_ids_shp) |
|
|
|
validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT) |
|
|
|
validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name) |
|
|
|
validator.check(f'rank of input_x', len(x_shp), |
|
|
|
'rank of segments_id', len(segment_ids_shp), Rel.GE) |
|
|
|
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name) |
|
|
|
for i, value in enumerate(segment_ids_shp): |
|
|
|
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i]) |
|
|
|
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) |
|
|
|
num_segments_v = num_segments['value'] |
|
|
|
validator.check_type('num_segments', num_segments_v, [int]) |
|
|
|
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT) |
|
|
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name) |
|
|
|
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) |
|
|
|
shp = [num_segments_v] |
|
|
|
shp += x_shp[segment_ids_shp_len:] |
|
|
|
out = {'shape': shp, |
|
|
|
@@ -1306,7 +1285,7 @@ class Concat(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=0): |
|
|
|
"""init Tile""" |
|
|
|
self.__setattr_flag__ = True |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
|
|
|
|
def __infer__(self, input_x): |
|
|
|
axis = self.axis |
|
|
|
@@ -1323,25 +1302,25 @@ class Concat(PrimitiveWithInfer): |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def _get_pack_shape(x_shape, x_type, axis): |
|
|
|
def _get_pack_shape(x_shape, x_type, axis, prim_name): |
|
|
|
"""for pack output shape""" |
|
|
|
validator.check_type("shape", x_shape, [tuple, list]) |
|
|
|
validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) |
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor) |
|
|
|
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) |
|
|
|
validator.check_value_type("shape", x_shape, [tuple, list], prim_name) |
|
|
|
validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT, prim_name) |
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) |
|
|
|
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name) |
|
|
|
rank_base = len(x_shape[0]) |
|
|
|
N = len(x_shape) |
|
|
|
out_shape = x_shape[0] |
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) |
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) |
|
|
|
if axis < 0: |
|
|
|
axis = axis + rank_base + 1 |
|
|
|
for i in range(1, N): |
|
|
|
v = x_shape[i] |
|
|
|
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base) |
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) |
|
|
|
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name) |
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name) |
|
|
|
for j in range(rank_base): |
|
|
|
if v[j] != x_shape[0][j]: |
|
|
|
raise ValueError("Pack evaluator element %d shape in input can not pack with first element" % i) |
|
|
|
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") |
|
|
|
out_shape.insert(axis, N) |
|
|
|
return out_shape |
|
|
|
|
|
|
|
@@ -1376,14 +1355,14 @@ class Pack(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=0): |
|
|
|
"""init Pack""" |
|
|
|
self.__setattr_flag__ = True |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
self.axis = axis |
|
|
|
|
|
|
|
def __infer__(self, value): |
|
|
|
x_shape = value['shape'] |
|
|
|
x_type = value['dtype'] |
|
|
|
self.add_prim_attr('num', len(x_shape)) |
|
|
|
all_shape = _get_pack_shape(x_shape, x_type, self.axis) |
|
|
|
all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name) |
|
|
|
out = {'shape': all_shape, |
|
|
|
'dtype': x_type[0], |
|
|
|
'value': None} |
|
|
|
@@ -1429,22 +1408,23 @@ class Unpack(PrimitiveWithInfer): |
|
|
|
def __init__(self, axis=0): |
|
|
|
"""init Unpack""" |
|
|
|
self.__setattr_flag__ = True |
|
|
|
validator.check_type("axis", axis, [int]) |
|
|
|
validator.check_value_type("axis", axis, [int], self.name) |
|
|
|
self.axis = axis |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) |
|
|
|
x_shape = list(x['shape']) |
|
|
|
dim = len(x_shape) |
|
|
|
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) |
|
|
|
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) |
|
|
|
if self.axis < 0: |
|
|
|
self.axis = self.axis + dim |
|
|
|
output_num = x_shape[self.axis] |
|
|
|
validator.check_type("num", output_num, [int]) |
|
|
|
validator.check_integer("output_num", output_num, 0, Rel.GT) |
|
|
|
validator.check_value_type("num", output_num, [int], self.name) |
|
|
|
validator.check_integer("output_num", output_num, 0, Rel.GT, self.name) |
|
|
|
self.add_prim_attr('num', output_num) |
|
|
|
output_valid_check = x_shape[self.axis] - output_num |
|
|
|
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ) |
|
|
|
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ, |
|
|
|
self.name) |
|
|
|
out_shapes = [] |
|
|
|
out_dtypes = [] |
|
|
|
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:] |
|
|
|
@@ -1486,8 +1466,8 @@ class Slice(PrimitiveWithInfer): |
|
|
|
def __infer__(self, x, begin, size): |
|
|
|
x_shape = x['shape'] |
|
|
|
x_shp_len = len(x_shape) |
|
|
|
validator.check_const_input('begin', begin['value']) |
|
|
|
validator.check_const_input('size', size['value']) |
|
|
|
validator.check_const_input('begin', begin['value'], self.name) |
|
|
|
validator.check_const_input('size', size['value'], self.name) |
|
|
|
begin_v, size_v = begin['value'], size['value'] |
|
|
|
if begin_v is None or size_v is None: |
|
|
|
return {'shape': None, |
|
|
|
@@ -1499,7 +1479,8 @@ class Slice(PrimitiveWithInfer): |
|
|
|
for i in range(x_shp_len): |
|
|
|
if x_shape[i] < begin_v[i] + size_v[i]: |
|
|
|
y = begin_v[i] + size_v[i] |
|
|
|
raise ValueError("Slice shape can not bigger than orign shape %d, %d." % (x_shape[i], y)) |
|
|
|
raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." % |
|
|
|
(self.name, x_shape[i], y)) |
|
|
|
return {'shape': size_v, |
|
|
|
'dtype': x['dtype'], |
|
|
|
'value': None} |
|
|
|
@@ -1565,11 +1546,11 @@ class Select(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def infer_dtype(self, cond_type, x_type, y_type): |
|
|
|
self.add_prim_attr('T', x_type) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor) |
|
|
|
validator.check_subclass("y_type", y_type, mstype.tensor) |
|
|
|
validator.check_typename("cond_type", cond_type, [mstype.bool_]) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) |
|
|
|
validator.check_subclass("y_type", y_type, mstype.tensor, self.name) |
|
|
|
validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name) |
|
|
|
if x_type != y_type: |
|
|
|
raise TypeError('The x_type %s must be the same as y_type %s.' % (x_type, y_type)) |
|
|
|
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) |
|
|
|
return x_type |
|
|
|
|
|
|
|
|
|
|
|
@@ -1637,26 +1618,23 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
shrink_axis_mask=0): |
|
|
|
"""init StrideSlice""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) |
|
|
|
validator.check_type('begin_mask', begin_mask, [int]) |
|
|
|
validator.check_type('end_mask', end_mask, [int]) |
|
|
|
validator.check_type('ellipsis_mask', ellipsis_mask, [int]) |
|
|
|
validator.check_type('new_axis_mask', new_axis_mask, [int]) |
|
|
|
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) |
|
|
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name) |
|
|
|
validator.check_value_type('end_mask', end_mask, [int], self.name) |
|
|
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) |
|
|
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) |
|
|
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) |
|
|
|
|
|
|
|
def __infer__(self, x, begin, end, strides): |
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] |
|
|
|
validator.check_const_input("begin", begin_v) |
|
|
|
validator.check_const_input("end", end_v) |
|
|
|
validator.check_const_input("strides", strides_v) |
|
|
|
validator.check_type("begin", begin_v, [tuple]) |
|
|
|
validator.check_type("end", end_v, [tuple]) |
|
|
|
validator.check_type("strides", strides_v, [tuple]) |
|
|
|
validator.check_value_type("begin", begin_v, [tuple], self.name) |
|
|
|
validator.check_value_type("end", end_v, [tuple], self.name) |
|
|
|
validator.check_value_type("strides", strides_v, [tuple], self.name) |
|
|
|
|
|
|
|
x_shape = x['shape'] |
|
|
|
x_shp_len = len(x_shape) |
|
|
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len: |
|
|
|
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} " |
|
|
|
f"must be equal to the dims({x_shp_len}) of input.") |
|
|
|
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and " |
|
|
|
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.") |
|
|
|
|
|
|
|
ret_shape = [] |
|
|
|
append_dimensions = [] |
|
|
|
@@ -1669,8 +1647,8 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)]) |
|
|
|
continue |
|
|
|
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1': |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE) |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT) |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name) |
|
|
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name) |
|
|
|
continue |
|
|
|
|
|
|
|
begin_idx = begin_v[i] |
|
|
|
@@ -1680,9 +1658,9 @@ class StridedSlice(PrimitiveWithInfer): |
|
|
|
begin_idx = 0 |
|
|
|
if self.end_mask: |
|
|
|
end_idx = x_shape[i] |
|
|
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE) |
|
|
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE) |
|
|
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE) |
|
|
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name) |
|
|
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name) |
|
|
|
if strides_idx > 0: |
|
|
|
# If sliced forward , end_idx >= begin_idx |
|
|
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE) |
|
|
|
@@ -1736,7 +1714,7 @@ class Diag(PrimitiveWithInfer): |
|
|
|
"""init Diag""" |
|
|
|
|
|
|
|
def infer_dtype(self, x_type): |
|
|
|
validator.check_subclass('input_x', x_type, mstype.tensor) |
|
|
|
validator.check_subclass('input_x', x_type, mstype.tensor, self.name) |
|
|
|
return x_type |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
@@ -1748,7 +1726,7 @@ class Diag(PrimitiveWithInfer): |
|
|
|
def infer_value(self, x): |
|
|
|
if x is None: |
|
|
|
return None |
|
|
|
validator.check("input x rank", len(x.shape()), "", 1) |
|
|
|
validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name) |
|
|
|
ret = np.diag(x.asnumpy()) |
|
|
|
return Tensor(ret) |
|
|
|
|
|
|
|
@@ -1783,13 +1761,13 @@ class DiagPart(PrimitiveWithInfer): |
|
|
|
"""init DiagPart""" |
|
|
|
|
|
|
|
def infer_dtype(self, x_type): |
|
|
|
validator.check_subclass('input_x', x_type, mstype.tensor) |
|
|
|
validator.check_subclass('input_x', x_type, mstype.tensor, self.name) |
|
|
|
return x_type |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
if len(x_shape)%2 != 0 or \ |
|
|
|
not x_shape: |
|
|
|
raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, " |
|
|
|
raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " |
|
|
|
f"with shapes {x_shape}") |
|
|
|
length = len(x_shape) // 2 |
|
|
|
ret_shape = x_shape[0:length] |
|
|
|
@@ -1798,7 +1776,7 @@ class DiagPart(PrimitiveWithInfer): |
|
|
|
def infer_value(self, x): |
|
|
|
if x is None: |
|
|
|
return None |
|
|
|
validator.check("x rank", len(x.shape()), "", 2) |
|
|
|
validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name) |
|
|
|
ret = np.diag(x.asnumpy()) |
|
|
|
return Tensor(ret) |
|
|
|
|
|
|
|
@@ -1826,12 +1804,10 @@ class Eye(PrimitiveWithInfer): |
|
|
|
"""init Eye""" |
|
|
|
|
|
|
|
def infer_value(self, n, m, t): |
|
|
|
validator.check_type("n", n, [int]) |
|
|
|
validator.check_integer("n", n, 0, Rel.GT) |
|
|
|
validator.check_type("m", m, [int]) |
|
|
|
validator.check_integer("m", m, 0, Rel.GT) |
|
|
|
validator.check_integer("n", n, 0, Rel.GT, self.name) |
|
|
|
validator.check_integer("m", m, 0, Rel.GT, self.name) |
|
|
|
args = {"dtype": t} |
|
|
|
validator.check_type_same(args, mstype.number_type + (mstype.bool_,)) |
|
|
|
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) |
|
|
|
np_type = mstype.dtype_to_nptype(t) |
|
|
|
ret = np.eye(n, m, dtype=np_type) |
|
|
|
return Tensor(ret) |
|
|
|
@@ -1866,16 +1842,15 @@ class ScatterNd(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def __infer__(self, indices, update, shape): |
|
|
|
shp = shape['value'] |
|
|
|
validator.check_subclass("indices_dtype", indices['dtype'], mstype.tensor) |
|
|
|
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor) |
|
|
|
validator.check_typename("indices_dtype", indices['dtype'], mstype.int_type) |
|
|
|
validator.check_type("shape", shp, [tuple]) |
|
|
|
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) |
|
|
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) |
|
|
|
validator.check_value_type("shape", shp, [tuple], self.name) |
|
|
|
for i, x in enumerate(shp): |
|
|
|
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT) |
|
|
|
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name) |
|
|
|
|
|
|
|
indices_shape, update_shape = indices["shape"], update["shape"] |
|
|
|
if indices_shape[0] != update_shape[0]: |
|
|
|
raise ValueError('The indices_shape[0] and update_shape[0] must be equal.') |
|
|
|
raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.') |
|
|
|
|
|
|
|
return {'shape': shp, |
|
|
|
'dtype': update['dtype'], |
|
|
|
@@ -1913,7 +1888,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) |
|
|
|
|
|
|
|
def infer_shape(self, x): |
|
|
|
validator.check('the dimension of input_x', len(x), '', 2, Rel.GE) |
|
|
|
validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name) |
|
|
|
return tuple(x)[:-2] + tuple(self.size) |
|
|
|
|
|
|
|
def infer_dtype(self, x): |
|
|
|
@@ -1947,13 +1922,12 @@ class GatherNd(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, indices_shape): |
|
|
|
validator.check('the dimension of x', len(x_shape), |
|
|
|
'the dimension of indices', indices_shape[-1], Rel.GE) |
|
|
|
'the dimension of indices', indices_shape[-1], Rel.GE, self.name) |
|
|
|
return indices_shape[:-1] + x_shape[indices_shape[-1]:] |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype): |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor) |
|
|
|
validator.check_typename("indices_dtype", indices_dtype, mstype.int_type) |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) |
|
|
|
validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -1995,12 +1969,9 @@ class ScatterNdUpdate(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("value_dtype", value_dtype, mstype.tensor) |
|
|
|
validator.check_typename('indices_dtype', indices_dtype, mstype.int_type) |
|
|
|
args = {"x_dtype": x_dtype, "value_dtype": value_dtype} |
|
|
|
validator.check_type_same(args, (mstype.bool_,) + mstype.number_type) |
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) |
|
|
|
args = {"x": x_dtype, "value": value_dtype} |
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -2038,7 +2009,7 @@ class SpaceToDepth(PrimitiveWithInfer): |
|
|
|
def __init__(self, block_size): |
|
|
|
"""Init SpaceToDepth""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['y']) |
|
|
|
validator.check_type('block_size', block_size, [int]) |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE) |
|
|
|
self.block_size = block_size |
|
|
|
self.add_prim_attr("data_format", "NCHW") |
|
|
|
@@ -2048,7 +2019,7 @@ class SpaceToDepth(PrimitiveWithInfer): |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
for i in range(2): |
|
|
|
if out_shape[i+2] % self.block_size != 0: |
|
|
|
raise ValueError(f'SpaceToDepth input shape[{i+2}] {out_shape[i+2]} should be ' |
|
|
|
raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be ' |
|
|
|
f'fully divided by block_size {self.block_size}') |
|
|
|
out_shape[i+2] //= self.block_size |
|
|
|
|
|
|
|
@@ -2056,7 +2027,7 @@ class SpaceToDepth(PrimitiveWithInfer): |
|
|
|
return out_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -2096,8 +2067,8 @@ class DepthToSpace(PrimitiveWithInfer): |
|
|
|
def __init__(self, block_size): |
|
|
|
"""Init DepthToSpace""" |
|
|
|
self.init_prim_io_names(inputs=['x'], outputs=['y']) |
|
|
|
validator.check_type('block_size', block_size, [int]) |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE) |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 2, Rel.GE, self.name) |
|
|
|
self.block_size = block_size |
|
|
|
self.add_prim_attr("data_format", "NCHW") |
|
|
|
|
|
|
|
@@ -2107,12 +2078,13 @@ class DepthToSpace(PrimitiveWithInfer): |
|
|
|
for i in range(2): |
|
|
|
out_shape[i+2] *= self.block_size |
|
|
|
|
|
|
|
validator.check('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), '', 0) |
|
|
|
validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), |
|
|
|
0, Rel.EQ, self.name) |
|
|
|
out_shape[1] //= self.block_size * self.block_size |
|
|
|
return out_shape |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) |
|
|
|
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
|
|
|
|
@@ -2159,27 +2131,26 @@ class SpaceToBatch(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, block_size, paddings): |
|
|
|
"""Init SpaceToBatch""" |
|
|
|
validator.check_type('block_size', block_size, [int]) |
|
|
|
validator.check('block_size', block_size, '', 1, Rel.GT) |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 1, Rel.GT, self.name) |
|
|
|
self.block_size = block_size |
|
|
|
validator.check('paddings shape', np.array(paddings).shape, '', (2, 2)) |
|
|
|
validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name) |
|
|
|
for elem in itertools.chain(*paddings): |
|
|
|
validator.check_type('paddings element', elem, [int]) |
|
|
|
validator.check_value_type('paddings element', elem, [int], self.name) |
|
|
|
self.paddings = paddings |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_typename('input_x', x_dtype, mstype.number_type) |
|
|
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check('rank of input_x', len(x_shape), '', 4) |
|
|
|
validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name) |
|
|
|
out_shape = copy.deepcopy(x_shape) |
|
|
|
for i in range(2): |
|
|
|
padded = out_shape[i+2] + self.paddings[i][0] + \ |
|
|
|
self.paddings[i][1] |
|
|
|
if padded % self.block_size != 0: |
|
|
|
raise ValueError(f'padded[{i}] {padded} should be divisible by ' |
|
|
|
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' |
|
|
|
f'block_size {self.block_size}') |
|
|
|
out_shape[i+2] = padded // self.block_size |
|
|
|
out_shape[0] *= self.block_size * self.block_size |
|
|
|
@@ -2227,17 +2198,16 @@ class BatchToSpace(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, block_size, crops): |
|
|
|
"""Init BatchToSpace""" |
|
|
|
validator.check_type('block_size', block_size, [int]) |
|
|
|
validator.check('block_size', block_size, '', 1, Rel.GT) |
|
|
|
validator.check_value_type('block_size', block_size, [int], self.name) |
|
|
|
validator.check('block_size', block_size, '', 1, Rel.GT, self.name) |
|
|
|
self.block_size = block_size |
|
|
|
validator.check('crops shape', np.array(crops).shape, '', (2, 2)) |
|
|
|
for elem in itertools.chain(*crops): |
|
|
|
validator.check_type('crops element', elem, [int]) |
|
|
|
validator.check_value_type('crops element', elem, [int], self.name) |
|
|
|
self.crops = crops |
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
|
validator.check_subclass("input_x", x_dtype, mstype.tensor) |
|
|
|
validator.check_typename('input_x', x_dtype, mstype.number_type) |
|
|
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) |
|
|
|
return x_dtype |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
@@ -2246,11 +2216,11 @@ class BatchToSpace(PrimitiveWithInfer): |
|
|
|
for i in range(2): |
|
|
|
x_block_prod = out_shape[i+2] * self.block_size |
|
|
|
crops_sum = self.crops[i][0] + self.crops[i][1] |
|
|
|
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT) |
|
|
|
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name) |
|
|
|
out_shape[i+2] = x_block_prod - crops_sum |
|
|
|
block_size_prod = self.block_size * self.block_size |
|
|
|
if out_shape[0] % block_size_prod != 0: |
|
|
|
raise ValueError(f'input_x dimension 0 {out_shape[0]} should be divisible by ' |
|
|
|
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by ' |
|
|
|
f'block_size_prod {block_size_prod}') |
|
|
|
out_shape[0] = out_shape[0] // block_size_prod |
|
|
|
return out_shape |