Browse Source

!8114 rectify and optimize the type checking function

Merge pull request !8114 from zhangbuxue/rectify_and_optimize_the_type_checking_function
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e9cd12e904
30 changed files with 648 additions and 712 deletions
  1. +38
    -108
      mindspore/_checkparam.py
  2. +18
    -4
      mindspore/common/tensor.py
  3. +10
    -8
      mindspore/nn/graph_kernels/graph_kernels.py
  4. +8
    -8
      mindspore/nn/layer/normalization.py
  5. +2
    -2
      mindspore/nn/layer/quant.py
  6. +1
    -1
      mindspore/nn/probability/bijector/gumbel_cdf.py
  7. +1
    -1
      mindspore/nn/probability/distribution/bernoulli.py
  8. +1
    -1
      mindspore/nn/probability/distribution/categorical.py
  9. +1
    -1
      mindspore/nn/probability/distribution/exponential.py
  10. +1
    -1
      mindspore/nn/probability/distribution/geometric.py
  11. +1
    -1
      mindspore/nn/probability/distribution/gumbel.py
  12. +1
    -1
      mindspore/nn/probability/distribution/logistic.py
  13. +1
    -1
      mindspore/nn/probability/distribution/normal.py
  14. +1
    -1
      mindspore/nn/probability/distribution/uniform.py
  15. +5
    -8
      mindspore/ops/operations/_cache_ops.py
  16. +57
    -51
      mindspore/ops/operations/_grad_ops.py
  17. +20
    -19
      mindspore/ops/operations/_inner_ops.py
  18. +63
    -78
      mindspore/ops/operations/_quant_ops.py
  19. +13
    -8
      mindspore/ops/operations/_thor_ops.py
  20. +64
    -59
      mindspore/ops/operations/array_ops.py
  21. +7
    -7
      mindspore/ops/operations/comm_ops.py
  22. +2
    -3
      mindspore/ops/operations/control_ops.py
  23. +1
    -1
      mindspore/ops/operations/debug_ops.py
  24. +5
    -5
      mindspore/ops/operations/image_ops.py
  25. +90
    -89
      mindspore/ops/operations/math_ops.py
  26. +214
    -222
      mindspore/ops/operations/nn_ops.py
  27. +10
    -11
      mindspore/ops/operations/other_ops.py
  28. +9
    -9
      mindspore/ops/operations/random_ops.py
  29. +2
    -2
      mindspore/train/serialization.py
  30. +1
    -1
      tests/ut/python/ir/test_row_tensor.py

+ 38
- 108
mindspore/_checkparam.py View File

@@ -415,37 +415,20 @@ class Validator:
break
if not hit:
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')

@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
"""Checks valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.')
return arg_value

@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')

if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if arg_value in valid_types:
return arg_value
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()

@staticmethod
def check_type_same(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same."""
def _check_tensor_type(arg):
def check_types_same_and_valid(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same and valid."""
def _check_type_valid(arg):
arg_key, arg_val = arg
elem_type = arg_val
Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
@@ -455,21 +438,27 @@ class Validator:
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
return arg1

elem_types = map(_check_tensor_type, args.items())
elem_types = map(_check_type_valid, args.items())
reduce(_check_types_same, elem_types)

@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""Checks whether the element types of input tensors are the same."""
tensor_types = [mstype.tensor_type(t) for t in valid_values]
Validator.check_type_same(args, tensor_types, prim_name)
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are the same and valid."""
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_types_same_and_valid(args, tensor_types, prim_name)

@staticmethod
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
"""Checks whether the element types of input tensors are valid."""
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)

@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
@@ -480,7 +469,7 @@ class Validator:
if isinstance(arg_val, type(mstype.tensor)):
arg_val = arg_val.element_type()
if not arg_val in valid_values:
raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {arg_val}.')
return arg

@@ -512,40 +501,40 @@ class Validator:

def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
num_types = len(valid_types)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
f'{type_names if num_types > 1 else type_names[0]}, '
f'but got {arg_value} with type {type(arg_value).__name__}.')

# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
# `check_value_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
if not isinstance(arg_value, tuple(valid_types)):
raise_error_msg()
return arg_value

@staticmethod
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
num_types = len(valid_types)
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
f"{type_names if num_types > 1 else type_names[0]}, "
f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")

if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()

if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
if arg_type not in valid_types:
raise_error_msg()
return arg_type

@staticmethod
def check_reduce_shape(ori_shape, shape, axis, prim_name):
@@ -611,65 +600,6 @@ def check_output_data(data):
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, bool, np.bool_)


def check_type(arg_name, arg_value, valid_types):
"""Check value type."""
# if input type is Tensor ,get element type
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()

# First, check if arg_value has argvalid_types
if isinstance(arg_value, tuple(valid_types)):
return type(arg_value).__name__

# Second, wrap arg_value with numpy array so that it can be checked through numpy api
if isinstance(arg_value, (list, tuple)):
arg_value = np.array(arg_value)

# Thirdly, check the data type by numpy's dtype api
valid = False
if isinstance(arg_value, np.ndarray):
valid = arg_value.dtype in valid_data_types

# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
valid = False

if not valid:
type_names = [t.__name__ for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {type(arg_value).__name__}.')

return type(arg_value).__name__


def check_typename(arg_name, arg_type, valid_types):
"""Check type name."""

def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)

if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()

if arg_type in valid_types:
return arg_type
if isinstance(arg_type, tuple(valid_types)):
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')


def args_type_check(*type_args, **type_kwargs):


+ 18
- 4
mindspore/common/tensor.py View File

@@ -19,7 +19,7 @@ from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor as MetaTensor_
from .._checkparam import check_type, check_typename
from .._checkparam import Validator as validator
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry

@@ -64,9 +64,19 @@ class Tensor(Tensor_):
input_data = np.array(input_data)

# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int))
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is "
f"{input_data.dtype} that is not supported to initialize a Tensor.")
if isinstance(input_data, (tuple, list)):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
if dtype is not None:
check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")

if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
input_data = np.ascontiguousarray(input_data)
if dtype is None:
@@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_):
Returns:
Array, an array after being initialized.
"""

def __init__(self, dtype, shape, init=None):
#check param
# check param
self.init = init
MetaTensor_.__init__(self, dtype, shape)

@@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_):
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)

class seed_context:
'''set and restore seed'''

def __init__(self, init):
self.init = init
from .seed import get_seed
@@ -482,4 +495,5 @@ def _vm_compare(*args):
y = args[0]
return Tensor(np.array(fn(y)))


tensor_operator_registry.register('vm_compare', _vm_compare)

+ 10
- 8
mindspore/nn/graph_kernels/graph_kernels.py View File

@@ -21,7 +21,7 @@ from ...ops import operations as P
from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
from ...ops.composite import multitype_ops as C
from ...ops.operations import _grad_ops as G
from ..._checkparam import Validator
from ..._checkparam import Validator as validator
from ..cell import Cell, GraphKernel


@@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
use_locking=False,
gradient_scale=1.0):
super(ApplyMomentum, self).__init__()
self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__)
self.fake_output_assign_1 = InplaceAssign()
self.fake_output_assign_1.add_prim_attr("fake_output", True)
self.fake_output_assign_2 = InplaceAssign()
@@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):

def __init__(self, keep_dims=True):
super(ReduceMean, self).__init__()
self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__)
self.sum = P.ReduceSum(self.keep_dims)

def construct(self, x, axis):
@@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel):
""" Forward function of the LayerNorm operator. """
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
super(LayerNormForward, self).__init__()
self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int],
type(self).__name__)
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int],
type(self).__name__)
self.mul = P.Mul()
self.sum_keep_dims = P.ReduceSum(keep_dims=True)
self.sub = P.Sub()
@@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel):

def __init__(self, axis=-1):
super(LogSoftmax, self).__init__()
self.axis = Validator.check_type('axis', axis, [int])
self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)
self.max_keep_dims = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()
@@ -952,13 +954,13 @@ class Softmax(GraphKernel):

def __init__(self, axis):
super(Softmax, self).__init__()
Validator.check_type("axis", axis, [int, tuple])
validator.check_value_type("axis", axis, [int, tuple], type(self).__name__)
if isinstance(axis, int):
self.axis = (axis,)
else:
self.axis = axis
for item in self.axis:
Validator.check_type("item of axis", item, [int])
validator.check_value_type("item of axis", item, [int], type(self).__name__)
self.max = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()


+ 8
- 8
mindspore/nn/layer/normalization.py View File

@@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import Validator, check_typename
from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
@@ -52,7 +52,7 @@ class _BatchNorm(Cell):

if momentum < 0 or momentum > 1:
raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics
@@ -67,7 +67,7 @@ class _BatchNorm(Cell):
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.group = Validator.check_positive_int(device_num_each_group)
self.group = validator.check_positive_int(device_num_each_group)
self.is_global = False
if self.group != 1:
self.rank_id = get_rank()
@@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm):
use_batch_statistics,
device_num_each_group,
input_dims='both')
self.group = Validator.check_positive_int(device_num_each_group)
self.group = validator.check_positive_int(device_num_each_group)
if self.group <= 1:
raise ValueError("the number of group must be greater than 1.")

@@ -607,12 +607,12 @@ class GroupNorm(Cell):

def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
super(GroupNorm, self).__init__()
self.num_groups = Validator.check_positive_int(num_groups)
self.num_channels = Validator.check_positive_int(num_channels)
self.num_groups = validator.check_positive_int(num_groups)
self.num_channels = validator.check_positive_int(num_channels)
if num_channels % num_groups != 0:
raise ValueError("num_channels should be divided by num_groups")
self.eps = check_typename('eps', eps, (float,))
self.affine = Validator.check_bool(affine)
self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
self.affine = validator.check_bool(affine)

gamma = initializer(gamma_init, num_channels)
beta = initializer(beta_init, num_channels)


+ 2
- 2
mindspore/nn/layer/quant.py View File

@@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__)
Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__)
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init


+ 1
- 1
mindspore/nn/probability/bijector/gumbel_cdf.py View File

@@ -68,7 +68,7 @@ class GumbelCDF(Bijector):
"""
param = dict(locals())
valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)



+ 1
- 1
mindspore/nn/probability/distribution/bernoulli.py View File

@@ -119,7 +119,7 @@ class Bernoulli(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Bernoulli, self).__init__(seed, dtype, name, param)

self._probs = self._add_parameter(probs, 'probs')


+ 1
- 1
mindspore/nn/probability/distribution/categorical.py View File

@@ -109,7 +109,7 @@ class Categorical(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type
Validator.check_type("Categorical", dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Categorical, self).__init__(seed, dtype, name, param)
self._probs = self._add_parameter(probs, 'probs')


+ 1
- 1
mindspore/nn/probability/distribution/exponential.py View File

@@ -121,7 +121,7 @@ class Exponential(Distribution):
param = dict(locals())
param['param_dict'] = {'rate': rate}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Exponential, self).__init__(seed, dtype, name, param)

self._rate = self._add_parameter(rate, 'rate')


+ 1
- 1
mindspore/nn/probability/distribution/geometric.py View File

@@ -122,7 +122,7 @@ class Geometric(Distribution):
param = dict(locals())
param['param_dict'] = {'probs': probs}
valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Geometric, self).__init__(seed, dtype, name, param)

self._probs = self._add_parameter(probs, 'probs')


+ 1
- 1
mindspore/nn/probability/distribution/gumbel.py View File

@@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution):
Constructor of Gumbel distribution.
"""
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
super(Gumbel, self).__init__(
distribution=msd.Uniform(0.0, 1.0, dtype=dtype),


+ 1
- 1
mindspore/nn/probability/distribution/logistic.py View File

@@ -111,7 +111,7 @@ class Logistic(Distribution):
param = dict(locals())
param['param_dict'] = {'loc': loc, 'scale': scale}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Logistic, self).__init__(seed, dtype, name, param)

self._loc = self._add_parameter(loc, 'loc')


+ 1
- 1
mindspore/nn/probability/distribution/normal.py View File

@@ -127,7 +127,7 @@ class Normal(Distribution):
param = dict(locals())
param['param_dict'] = {'mean': mean, 'sd': sd}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Normal, self).__init__(seed, dtype, name, param)

self._mean_value = self._add_parameter(mean, 'mean')


+ 1
- 1
mindspore/nn/probability/distribution/uniform.py View File

@@ -126,7 +126,7 @@ class Uniform(Distribution):
param = dict(locals())
param['param_dict'] = {'low': low, 'high': high}
valid_dtype = mstype.float_type
Validator.check_type(type(self).__name__, dtype, valid_dtype)
Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
super(Uniform, self).__init__(seed, dtype, name, param)

self._low = self._add_parameter(low, 'low')


+ 5
- 8
mindspore/ops/operations/_cache_ops.py View File

@@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer):
return [1]

def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
args = {"indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype


@@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer):

def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
return out_dtype

@@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer):
return out_shape

def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
args = {"miss_emb_idx": miss_emb_idx_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
return out_dtype

@@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer):
return miss_value_shape

def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
args = {"swap_cache_idx": swap_cache_idx_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
return miss_value_dtype


@@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer):

def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensor_type_same(args, mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype)
return out_dtype

+ 57
- 51
mindspore/ops/operations/_grad_ops.py View File

@@ -14,6 +14,7 @@
# ============================================================================

"""Operators for gradients."""
from functools import partial

from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -23,6 +24,7 @@ from ...common import dtype as mstype
from .. import functional as F
from ... import context


class AbsGrad(PrimitiveWithInfer):
"""Computes gradients for abs operation."""

@@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer):

def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x


@@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer):

def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x


@@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer):

def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x


@@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer):

def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x


@@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype


@@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
self.name)
return x_dtype


@@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype


@@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
return x_dtype


@@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer):

def infer_dtype(self, x_type, y_type, doutput_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, y_type


@@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):

def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if weight_type:
validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
return x_type
@@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
for i, dim_len in enumerate(w_size_v):
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
self.name)
out = {
'value': None,
'shape': w_size_v,
@@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
def __infer__(self, x, w_size, dout):
w_size_v = w_size['value']
args = {'x': x['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
out = {
'value': None,
'shape': w_size_v,
@@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):

def __infer__(self, x_size, w, dout):
args = {'w': w['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
x_size_v = x_size['value']
out = {
'value': None,
@@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer):
return dy_shape

def infer_dtype(self, dy_dtype, mask_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name)
return dy_dtype


@@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("y_backprop", "x", "y"),
(y_backprop_dtype, x_dtype, y_dtype)))
return x_dtype


@@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad):

def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
return x1_dtype


@@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad):

def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
validator.check_tensor_type_same(args, [mstype.float16], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
return grad_dtype


@@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer):

def infer_dtype(self, input_x, out, dout):
args = {'input_x': input_x, 'out': out, 'dout': dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return input_x


@@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
hx_dtype, cx_dtype, reserve_dtype, state_dtype):
args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (dy_dtype, dy_dtype, dy_dtype)


@@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
"dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
"reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name)
validator.check_tensor_type_same(args, valid_types, self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
if seq_dtype is not None:
validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name)
validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
if mask_dtype is not None:
validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name)
validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype


@@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer):
return y_backprop_shape, w_shape

def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
('y_backprop', "input_x", "weight"),
(y_backprop_dtype, A_dtype, w_dtype)))
return y_backprop_dtype, w_dtype


@@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype


@@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer):
return gradients_shape

def infer_dtype(self, gradients_dtype, mask_dtype):
validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name)
return gradients_dtype


@@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer):

def infer_dtype(self, y_grad_dtype, x_dtype):
args = {'y_grad': y_grad_dtype, 'x': x_dtype}
validator.check_tensor_type_same(args, mstype.float_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
return x_dtype


@@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer):

def infer_dtype(self, out, dout):
args = {'out': out, 'dout': dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return out


@@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype


@@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
return x_dtype


@@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return dout_dtype


@@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):

def infer_dtype(self, prediction, target, dloss):
args = {"prediction": prediction, "target": target, 'dloss': dloss}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return dloss


@@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

def __infer__(self, dy, shapex, begin, end, strides):
args = {"dy": dy['dtype']}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)

for idx, item in enumerate(shapex['value']):
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
@@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer):

def infer_dtype(self, dout_dtype, x_dtype):
args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
validator.check_tensor_type_same(args, mstype.float_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
return x_dtype


@@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer):

def infer_dtype(self, out, dout):
args = {"out": out, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return out


@@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer):

def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x


@@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer):

def infer_dtype(self, grads, x, y):
args = {"grads": grads, "x": x, "y": y}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
return x

def infer_shape(self, grads, x, y):


+ 20
- 19
mindspore/ops/operations/_inner_ops.py View File

@@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"):
"""init"""

def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
@@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer):

def infer_dtype(self, input_x):
"""infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name)
return input_x


@@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
return x_dtype


@@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer):
>>> dequant = P.Dequant(False, False)
>>> y = dequant(input_x)
"""

@prim_attr_register
def __init__(self, sqrt_mode=False, relu_flag=False):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
@@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer):
return assist

def infer_dtype(self, assist, start, stop, num):
args = {"num": num}
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
return assist


@@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, assist_shape):
validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
validator.check('rank of x', len(x_shape)+1,
validator.check('rank of x', len(x_shape) + 1,
'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
assist_shape[-1], Rel.EQ, self.name)
@@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer):
while r_idx >= r_end_dim:
if x_shape[r_idx] != 1:
validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
r_idx = r_idx - 1

return assist_shape
@@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, assist_shape):
@@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, diagonal_shape, assist_shape):
@@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer):
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape

def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name)
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
b_dtype = mstype.float32
if binput_dtype is not None:
validator.check_tensor_type_same({"bias input dtype": binput_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype
elif h_dtype is not None:
validator.check_tensor_type_same({"init_h dtype": h_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = h_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype



+ 63
- 78
mindspore/ops/operations/_quant_ops.py View File

@@ -14,6 +14,7 @@
# ============================================================================

"""Operators for quantization."""
from functools import partial

import mindspore.context as context
from ..._checkparam import Validator as validator
@@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
return min_shape, max_shape

def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type


@@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
return min_shape, max_shape

def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type


@@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
"""

@prim_attr_register
def __init__(self,
num_bits=8,
@@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type


@@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
>>> min_gradient shape: (1,) data type: mstype.float32
>>> max_gradient shape: (1,) data type: mstype.float32
"""

@prim_attr_register
def __init__(self,
num_bits=8,
@@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape

def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
('dout', "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type


@@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
"""

@prim_attr_register
def __init__(self,
num_bits=8,
@@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type


@@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
>>> min_gradient shape: (4,) data type: mstype.float32
>>> max_gradient shape: (4,) data type: mstype.float32
"""

@prim_attr_register
def __init__(self,
num_bits=8,
@@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape

def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type


@@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):

def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type


@@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):

def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type


@@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):

def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type


@@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):

def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type


@@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type, x_type, x_type, x_type


@@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
global_step_type):
args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
"batch_mean": batch_mean_type, "batch_std": batch_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type


@@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer):

def infer_dtype(self, x_type, batch_std_type, running_std_type):
args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type


@@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):

def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if context.get_context('device_target') == "Ascend":
return x_type, x_type
return x_type, gamma_type
@@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer):
running_mean_type, global_step_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type


@@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
"dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type


@@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type, x_type, x_type, x_type, x_type, x_type


@@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
validator.check("input type", x_type, "batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type)
args = {"input type": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
return x_type


@@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer):
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type


@@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
"dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type




+ 13
- 8
mindspore/ops/operations/_thor_ops.py View File

@@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
def infer_shape(self, data1_shape):
ll = []
if len(data1_shape) == 2:
ll = [1,]
ll = [1]
else:
ll = [32, 64]
return ll
@@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer):
>>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2)
>>> output = img2col(input_x)
"""

@prim_attr_register
def __init__(self,
kernel_size,
@@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer):
return out_shape

def infer_dtype(self, x_dtype):
args = {'x': x_dtype}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
return x_dtype


@@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer):
return x2_shape

def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(
{'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype


class Cholesky(PrimitiveWithInfer):
"""
Inner API for resnet50 THOR GPU backend
"""

@prim_attr_register
def __init__(self, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer):
return out_shape

def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype


class DetTriangle(PrimitiveWithInfer):
"""
Calculate the determinant of triangle matrices
"""

@prim_attr_register
def __init__(self, fill_mode=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer):
return out_shape

def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype

+ 64
- 59
mindspore/ops/operations/array_ops.py View File

@@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype


@@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp):
"""
Defines _ScatterNd operators
"""

def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
@@ -627,6 +628,7 @@ class Unique(Primitive):
>>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
"""

@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
@@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])

def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']
@@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer):
>>> out = P.Padding(pad_dim_size)(x)
[[8, 0, 0, 0], [10, 0, 0, 0]]
"""

@prim_attr_register
def __init__(self, pad_dim_size=8):
"""Initialize padding"""
@@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer):
>>> out = P.UniqueWithPad()(x, pad_num)
([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
"""

@prim_attr_register
def __init__(self):
"""init UniqueWithPad"""

def __infer__(self, x, pad_num):
validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
@@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer):
def __init__(self, seed=0, dtype=mstype.float32):
"""Initialize TruncatedNormal"""
validator.check_value_type('seed', seed, [int], self.name)
validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)

def __infer__(self, shape):
shape_value = shape['value']
@@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer):
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.full(dims['value'], x['value'], x_nptype)
out = {
@@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype


@@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype


@@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer):
"""Initialize Argmax"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name)
validator.check_type_same({'output': output_type}, [mstype.int32], self.name)
validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name)
self.axis = axis
self.add_prim_attr('output_type', output_type)

@@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __init__(self):
"""Initialize UnsortedSegmentSum"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])

def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype']
@@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name)
shp = [-1]
else:
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
@@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
x_shape = x['shape']
segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_value_type("x_shape", x_shape, [list], self.name)
valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer):
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)

args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)

first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]):
@@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


@@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype


@@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer):
self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
if x_type != y_type:
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type
@@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer):
validator.check_positive_int(n, "n", self.name)
validator.check_positive_int(m, "m", self.name)
args = {"dtype": t}
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
np_type = mstype.dtype_to_nptype(t)
ret = np.eye(n, m, dtype=np_type)
return Tensor(ret)
@@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape):
shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp):
validator.check_positive_int(x, f'shape[{i}]', self.name)
@@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

def infer_shape(self, x):
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
return tuple(x)[:-2] + tuple(self.size)
def infer_shape(self, x_shape):
validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
return tuple(x_shape)[:-2] + tuple(self.size)

def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name)
return x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype


class GatherNd(PrimitiveWithInfer):
@@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer):
return indices_shape[:-1] + x_shape[indices_shape[-1]:]

def infer_dtype(self, x_dtype, indices_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
return x_dtype


@@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


@@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


@@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp):
[0.4 2.2 -3.2]]
"""


@prim_attr_register
def __init__(self, use_locking=True):
"""Initialize ScatterNdUpdate"""
@@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


@@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype


@@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer):
self.paddings = paddings

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_shape(self, x_shape):
@@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer):
self.crops = crops

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_shape(self, x_shape):
@@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
self.add_prim_attr("paddings", paddings_append)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_shape(self, x_shape):
@@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
self.add_prim_attr("crops", crops_append)

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_shape(self, x_shape):
@@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer):
out_shape = tuple(tuple(shape_0) for _ in range(n))
return out_shape


def infer_dtype(self, x_type):
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
n = len(x_type)
@@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer):
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
return x_type


class InplaceUpdate(PrimitiveWithInfer):
r"""
Updates specified rows with values in `v`.
@@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, v_shape):
@@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer):
return x

def infer_dtype(self, x, seq_lengths):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
return x


@@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer):
validator.check_const_input('truth_shape', truth_shape['value'], self.name)
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name)
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)

hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
@@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer):
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`.
"""

@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
@@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer):
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
@@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer):
return x_shape, x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype, mstype.tensor_type(mstype.int32)


@@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
"""

@prim_attr_register
def __init__(self):
"""Initialize index_select"""
@@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer):

def __infer__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape']
if len(params_shp) != 2:
@@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer):
>>> out = P.GatherD()(x, dim, index)
[[1, 1], [4, 3]]
"""

@prim_attr_register
def __init__(self):
"""Initialize GatherD"""
@@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer):

def __infer__(self, x, dim, index):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name)
x_shp = x['shape']
idx_shp = index['shape']
@@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer):
>>> y = P.Identity()(x)
[1, 2, 3, 4]
"""

@prim_attr_register
def __init__(self):
"""Initialize identity"""


+ 7
- 7
mindspore/ops/operations/comm_ops.py View File

@@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype


@@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype

def __call__(self, tensor):
@@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype

def __call__(self, tensor):
@@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype

def __call__(self, tensor):
@@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype

def __call__(self, tensor):
@@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer):
if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype:
validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
return x_dtype


@@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
return x_dtype

def __call__(self, tensor):


+ 2
- 3
mindspore/ops/operations/control_ops.py View File

@@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer):
def infer_dtype(self, data_type, pred_type):
validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
{"pred": pred_type}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
return (data_type, data_type)


@@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item

validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

+ 1
- 1
mindspore/ops/operations/debug_ops.py View File

@@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer):
return [1]

def infer_dtype(self, condition, inputs):
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
for dtype in inputs:
validator.check_subclass("input", dtype, [mstype.tensor], self.name)
return mstype.int32

+ 5
- 5
mindspore/ops/operations/image_ops.py View File

@@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer):
box_index_dtype = box_index['dtype']
crop_size_dtype = crop_size['dtype']
# check dytpe
validator.check_tensor_type_same({"x": x_dtype},
[mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x_dtype,
[mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name)
validator.check_value_type("crop_size", crop_size_value, [tuple], self.name)
# check input shape rank
validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)


+ 90
- 89
mindspore/ops/operations/math_ops.py View File

@@ -16,6 +16,8 @@
"""Operators for math."""

import copy
from functools import partial

import numpy as np
from ... import context
from .. import signature as sig
@@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp):
@staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
args_type = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_type, valid_dtype, prim_name)
validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
return x_dtype

def infer_dtype(self, x_dtype, y_dtype):
@@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp):
@staticmethod
def _check_bitwise_op_input_type(x1_type, x2_type, prim):
args = {'x1': x1_type, 'x2': x2_type}
valid_types = mstype.int_type + mstype.uint_type
validator.check_tensor_type_same(args, valid_types, prim)
valid_dtypes = mstype.int_type + mstype.uint_type
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim)
return x1_type

def infer_dtype(self, x1_type, x2_type):
@@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer):

def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value


@@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer):

def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
return value


@@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer):
axis_v = axis['value']
input_shp = input_x['shape']
args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name)

if axis_v is None:
raise ValueError(f"For {self.name}, axis must be const.")
@@ -504,6 +506,7 @@ class ReduceMax(_Reduce):
def __infer__(self, input_x, axis):
return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,))


class ReduceMin(_Reduce):
"""
Reduce a dimension of a tensor by the minimum value in the dimension.
@@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer):

def infer_dtype(self, x_type, axis_type):
cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type

@@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer):

def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
if x1.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x1
@@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer):
self.axes = axes
validator.check_value_type('axes', axes, [int, tuple, list], self.name)
if not isinstance(self.axes, int):
self.axes = list(self.axes) # to avoid immutability issues
self.axes = list(self.axes) # to avoid immutability issues
if len(self.axes) != 2:
raise ValueError("Require two axes inputs, given less")
self.int_to_tuple_conv() # convert before length checks
self.int_to_tuple_conv() # convert before length checks
if len(self.axes[0]) != len(self.axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
@@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer):
if isinstance(self.axes, int):
if self.axes <= 0:
# outer product, no input validation required
self.axes = ([], []) # no axes selected for either
self.axes = ([], []) # no axes selected for either
return
if self.axes > len(x1_shape) or self.axes > len(x2_shape):
raise ValueError(
@@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer):

def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x1


@@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer):
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
return {'shape': x_shp,
'dtype': x['dtype'],
'value': None}
@@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer):
if dtype == mstype.undetermined:
contains_undetermined = True
if not contains_undetermined:
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0]

def infer_value(self, inputs):
@@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer):
args = {}
for i, dtype in enumerate(inputs):
args[f"inputs[{i}]"] = dtype
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
return inputs[0]


@@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer):
"""Initialize Neg"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])

def infer_shape(self, input_x):
return input_x
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_value(self, input_x):
if input_x is not None:
@@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, v_shape):
@@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype

def infer_shape(self, x_shape, v_shape):
@@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_value(self, x):
if x is not None:
@@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype

def infer_value(self, x):
if x is not None:
@@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['x'], outputs=['output'])

def check_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name)

def infer_value(self, x):
if x is not None:
@@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return x_type


@@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
return (self.nbins,)

def infer_dtype(self, x_dtype, range_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
valid_types = (mstype.float16, mstype.float32, mstype.int32)
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32, mstype.int32)
validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name)
y_dtype = mstype.int32
return y_dtype

@@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer):
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['y'])

def infer_shape(self, x):
return x
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name)
return x
def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype


class Erf(PrimitiveWithInfer):
@@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype


class Erfc(PrimitiveWithInfer):
@@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
return x_type


@@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name)
return x_dtype


@@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype


@@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp):
@staticmethod
def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
args_dtype = {"x": x_dtype, "y": y_dtype}
validator.check_tensor_type_same(args_dtype, valid_type, prim_name)
validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name)
return mstype.tensor_type(mstype.bool_)

def infer_dtype(self, x_dtype, y_dtype):
@@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp):
def infer_dtype(self, x_dtype, y_dtype):
args_dtype = {"x": x_dtype, "y": y_dtype}
valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name)
validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name)
return mstype.tensor_type(mstype.bool_)


@@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer):

def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype


@@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_)


@@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return mstype.bool_


@@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1]

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype


@@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
return [8]

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32


@@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
return [8]

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return mstype.float32


@@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,))

def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_)


@@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type

def infer_value(self, x):
@@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


@@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer):
def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
return x_dtype


class Tan(PrimitiveWithInfer):
@@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
valid_types = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name)
return x_type


@@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type


@@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
return x_type


@@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer):
return [], []

def infer_dtype(self, x_type, y_type):
validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name)
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name)
validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name)
return x_type, y_type


@@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer):
return x

def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x


@@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer):
return x

def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
return x


@@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32,
validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32,
mstype.int32], self.name)
return x_dtype

@@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name)
validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name)
return x_dtype


@@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['input_x'], outputs=['y'])

def __infer__(self, input_x):
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name)

x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type())
if x_nptype == np.float16:
@@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer):
return (1,), (1,)

def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
valid_types = [mstype.float32, mstype.float16]
validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("input_value", "input_min", "input_max"),
(data_dtype, data_min_dtype, data_max_dtype)))
validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)

+ 214
- 222
mindspore/ops/operations/nn_ops.py
File diff suppressed because it is too large
View File


+ 10
- 11
mindspore/ops/operations/other_ops.py View File

@@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck):

def check_dtype(self, variable, value):
if variable != mstype.type_refkey:
validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name)
validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)


class BoundingBoxEncode(PrimitiveWithInfer):
@@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer):

def infer_dtype(self, anchor_box, groundtruth_box):
args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box


@@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer):

def infer_dtype(self, anchor_box, deltas):
args = {"anchor_box": anchor_box, "deltas": deltas}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return anchor_box


@@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer):

def infer_dtype(self, bboxes_type, metas_type):
valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name)
validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name)
validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
return mstype.bool_


@@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer):

def infer_dtype(self, anchor_boxes, gt_boxes):
valid_type = [mstype.float32, mstype.float16]
validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name)
validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name)
validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
return anchor_boxes


@@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
if weights is not None:
validator.check_subclass('weights', weights, mstype.tensor, self.name)
args = {"labels": labels, "predictions": predictions}
validator.check_tensor_type_same(args, (mstype.number_type), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
return labels


@@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
args = {"x": x_dtype}
validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
return mstype.tensor_type(mstype.uint8)

class Push(PrimitiveWithInfer):


+ 9
- 9
mindspore/ops/operations/random_ops.py View File

@@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
out = {
@@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
out = {
'shape': broadcast_shape,
@@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer):
Validator.check_value_type("shape", shape_v, [tuple], self.name)
for i, shape_i in enumerate(shape_v):
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name)
Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name)
minval_shape = minval['shape']
maxval_shape = maxval['shape']
Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
@@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
return ([self.count, len(x_shape)], [self.count])

def infer_dtype(self, x_dtype):
Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_)


@@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer):

def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype']
valid_types = (mstype.float32, mstype.float16, mstype.float64)
Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
valid_dtypes = (mstype.float32, mstype.float16, mstype.float64)
Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name)
num_samples_v = num_samples['value']
seed_v = seed['value']
Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
@@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer):
input_shape = inputs["shape"]
if len(input_shape) != 1 and len(input_shape) != 2:
raise ValueError("input dim must be 1 or 2")
Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
num_samples_value = num_samples["value"]
if num_samples_value is None:
raise ValueError(f"For {self.name}, shape nust be const")


+ 2
- 2
mindspore/train/serialization.py View File

@@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
if quant_mode not in quant_mode_formats:
raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')

mean = Validator.check_type("mean", mean, (int, float))
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
mean = Validator.check_value_type("mean", mean, (int, float))
std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))

if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))


+ 1
- 1
tests/ut/python/ir/test_row_tensor.py View File

@@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer):

def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value']
params_shp = params['shape']


Loading…
Cancel
Save