
!8114 rectify and optimize the type checking function

Merge pull request !8114 from zhangbuxue/rectify_and_optimize_the_type_checking_function
tags/v1.1.0
mindspore-ci-bot, 5 years ago
parent commit e9cd12e904
30 changed files with 648 additions and 712 deletions
 1. mindspore/_checkparam.py  +38 -108
 2. mindspore/common/tensor.py  +18 -4
 3. mindspore/nn/graph_kernels/graph_kernels.py  +10 -8
 4. mindspore/nn/layer/normalization.py  +8 -8
 5. mindspore/nn/layer/quant.py  +2 -2
 6. mindspore/nn/probability/bijector/gumbel_cdf.py  +1 -1
 7. mindspore/nn/probability/distribution/bernoulli.py  +1 -1
 8. mindspore/nn/probability/distribution/categorical.py  +1 -1
 9. mindspore/nn/probability/distribution/exponential.py  +1 -1
10. mindspore/nn/probability/distribution/geometric.py  +1 -1
11. mindspore/nn/probability/distribution/gumbel.py  +1 -1
12. mindspore/nn/probability/distribution/logistic.py  +1 -1
13. mindspore/nn/probability/distribution/normal.py  +1 -1
14. mindspore/nn/probability/distribution/uniform.py  +1 -1
15. mindspore/ops/operations/_cache_ops.py  +5 -8
16. mindspore/ops/operations/_grad_ops.py  +57 -51
17. mindspore/ops/operations/_inner_ops.py  +20 -19
18. mindspore/ops/operations/_quant_ops.py  +63 -78
19. mindspore/ops/operations/_thor_ops.py  +13 -8
20. mindspore/ops/operations/array_ops.py  +64 -59
21. mindspore/ops/operations/comm_ops.py  +7 -7
22. mindspore/ops/operations/control_ops.py  +2 -3
23. mindspore/ops/operations/debug_ops.py  +1 -1
24. mindspore/ops/operations/image_ops.py  +5 -5
25. mindspore/ops/operations/math_ops.py  +90 -89
26. mindspore/ops/operations/nn_ops.py  +214 -222
27. mindspore/ops/operations/other_ops.py  +10 -11
28. mindspore/ops/operations/random_ops.py  +9 -9
29. mindspore/train/serialization.py  +2 -2
30. tests/ut/python/ir/test_row_tensor.py  +1 -1

mindspore/_checkparam.py  +38 -108

@@ -415,37 +415,20 @@ class Validator:
                 break
             if not hit:
                 type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
-                raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
-                                f' of {",".join((str(x) for x in template_types))}, but got {type_str}.')
+                raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass'
+                                f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.')

     @staticmethod
     def check_const_input(arg_name, arg_value, prim_name):
         """Checks valid value."""
         if arg_value is None:
-            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
+            raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.')
         return arg_value

     @staticmethod
-    def check_type(arg_name, arg_value, valid_types):
-        """Type checking."""
-        def raise_error_msg():
-            """func for raising error message when check failed"""
-            raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.')
-
-        if isinstance(arg_value, type(mstype.tensor)):
-            arg_value = arg_value.element_type()
-        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-            raise_error_msg()
-        if arg_value in valid_types:
-            return arg_value
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
-
-    @staticmethod
-    def check_type_same(args, valid_values, prim_name):
-        """Checks whether the types of inputs are the same."""
-        def _check_tensor_type(arg):
+    def check_types_same_and_valid(args, valid_values, prim_name):
+        """Checks whether the types of inputs are the same and valid."""
+        def _check_type_valid(arg):
             arg_key, arg_val = arg
             elem_type = arg_val
             Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
@@ -455,21 +438,27 @@ class Validator:
             arg1_name, arg1_type = arg1
             arg2_name, arg2_type = arg2
             if arg1_type != arg2_type:
-                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
+                raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,'
                                 f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
             return arg1

-        elem_types = map(_check_tensor_type, args.items())
+        elem_types = map(_check_type_valid, args.items())
         reduce(_check_types_same, elem_types)

     @staticmethod
-    def check_tensor_type_same(args, valid_values, prim_name):
-        """Checks whether the element types of input tensors are the same."""
-        tensor_types = [mstype.tensor_type(t) for t in valid_values]
-        Validator.check_type_same(args, tensor_types, prim_name)
+    def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
+        """Checks whether the element types of input tensors are the same and valid."""
+        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
+        Validator.check_types_same_and_valid(args, tensor_types, prim_name)
+
+    @staticmethod
+    def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
+        """Checks whether the element types of input tensors are valid."""
+        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
+        Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)

     @staticmethod
-    def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
+    def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
         """
         Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
         If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
@@ -480,7 +469,7 @@ class Validator:
             if isinstance(arg_val, type(mstype.tensor)):
                 arg_val = arg_val.element_type()
             if not arg_val in valid_values:
-                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
+                raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},'
                                 f' but `{arg_key}` is {arg_val}.')
             return arg
@@ -512,40 +501,40 @@ class Validator:

         def raise_error_msg():
             """func for raising error message when check failed"""
-            type_names = [t.__name__ for t in valid_types]
+            type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
             num_types = len(valid_types)
-            msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
             raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
-                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
+                            f'{type_names if num_types > 1 else type_names[0]}, '
+                            f'but got {arg_value} with type {type(arg_value).__name__}.')

         # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
         #         `check_value_type('x', True, [bool, int])` will check pass
         if isinstance(arg_value, bool) and bool not in tuple(valid_types):
             raise_error_msg()
-        if isinstance(arg_value, tuple(valid_types)):
-            return arg_value
-        raise_error_msg()
+        if not isinstance(arg_value, tuple(valid_types)):
+            raise_error_msg()
+        return arg_value

     @staticmethod
     def check_type_name(arg_name, arg_type, valid_types, prim_name):
         """Checks whether a type in some specified types"""
         valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

-        def get_typename(t):
-            return t.__name__ if hasattr(t, '__name__') else str(t)
+        def raise_error_msg():
+            """func for raising error message when check failed"""
+            type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
+            num_types = len(valid_types)
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
+            raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
+                            f"{type_names if num_types > 1 else type_names[0]}, "
+                            f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")

         if isinstance(arg_type, type(mstype.tensor)):
             arg_type = arg_type.element_type()

-        if arg_type in valid_types:
-            return arg_type
-        type_names = [get_typename(t) for t in valid_types]
-        msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
-        if len(valid_types) == 1:
-            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
-                            f' but got {get_typename(arg_type)}.')
-        raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
-                        f' but got {get_typename(arg_type)}.')
+        if arg_type not in valid_types:
+            raise_error_msg()
+        return arg_type

     @staticmethod
     def check_reduce_shape(ori_shape, shape, axis, prim_name):
@@ -611,65 +600,6 @@ def check_output_data(data):
 once = _expand_tuple(1)
 twice = _expand_tuple(2)
 triple = _expand_tuple(3)
-valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
-                    np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
-                    np.float32, np.float64, bool, np.bool_)
-
-
-def check_type(arg_name, arg_value, valid_types):
-    """Check value type."""
-    # if input type is Tensor ,get element type
-    if isinstance(arg_value, type(mstype.tensor)):
-        arg_value = arg_value.element_type()
-
-    # First, check if arg_value has argvalid_types
-    if isinstance(arg_value, tuple(valid_types)):
-        return type(arg_value).__name__
-
-    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
-    if isinstance(arg_value, (list, tuple)):
-        arg_value = np.array(arg_value)
-
-    # Thirdly, check the data type by numpy's dtype api
-    valid = False
-    if isinstance(arg_value, np.ndarray):
-        valid = arg_value.dtype in valid_data_types
-
-    # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
-    #         `check_type('x', True, [bool, int])` will check pass
-    if isinstance(arg_value, bool) and bool not in tuple(valid_types):
-        valid = False
-
-    if not valid:
-        type_names = [t.__name__ for t in valid_types]
-        if len(valid_types) == 1:
-            raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
-                            f' but got {type(arg_value).__name__}.')
-        raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
-                        f' but got {type(arg_value).__name__}.')
-
-    return type(arg_value).__name__
-
-
-def check_typename(arg_name, arg_type, valid_types):
-    """Check type name."""
-
-    def get_typename(t):
-        return t.__name__ if hasattr(t, '__name__') else str(t)
-
-    if isinstance(arg_type, type(mstype.tensor)):
-        arg_type = arg_type.element_type()
-
-    if arg_type in valid_types:
-        return arg_type
-    if isinstance(arg_type, tuple(valid_types)):
-        return arg_type
-    type_names = [get_typename(t) for t in valid_types]
-    if len(valid_types) == 1:
-        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
-                        f' but got {get_typename(arg_type)}.')
-    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
-                    f' but got {get_typename(arg_type)}.')


 def args_type_check(*type_args, **type_kwargs):
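Note: taken together, the _checkparam.py changes fold the loose module-level helpers into Validator and rename the entry points. A minimal sketch of the new call sites (assuming MindSpore at this commit; the names and signatures follow the diff above, the argument values are illustrative):

    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    # was: check_type('keep_dims', keep_dims, [bool])  (module-level helper, now removed)
    keep_dims = validator.check_value_type('keep_dims', True, [bool], 'ReduceMean')

    # was: check_typename('dtype', dtype, valid)  (now a Validator static method)
    dtype = validator.check_type_name('dtype', mstype.float32,
                                      mstype.number_type + (mstype.bool_,), 'Tensor')

    # was: validator.check_tensor_type_same({...}, valid, name)
    args = {'x': mstype.tensor_type(mstype.float32), 'dout': mstype.tensor_type(mstype.float32)}
    validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, 'AtanGrad')

    # new single-tensor variant introduced by this PR
    validator.check_tensor_dtype_valid('x', mstype.tensor_type(mstype.int32), mstype.int_type, 'Range')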


mindspore/common/tensor.py  +18 -4

@@ -19,7 +19,7 @@ from mindspore import log as logger
 from mindspore.communication.management import get_rank, get_group_size
 from .._c_expression import Tensor as Tensor_
 from .._c_expression import MetaTensor as MetaTensor_
-from .._checkparam import check_type, check_typename
+from .._checkparam import Validator as validator
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
@@ -64,9 +64,19 @@ class Tensor(Tensor_):
             input_data = np.array(input_data)

         # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
-        check_type('tensor input_data', input_data, (Tensor_, float, int))
+        validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
+                                   'Tensor')
+        valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
+                        np.float16, np.float32, np.float64, np.bool_)
+        if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
+            raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is "
+                            f"{input_data.dtype} that is not supported to initialize a Tensor.")
+        if isinstance(input_data, (tuple, list)):
+            if np.array(input_data).dtype not in valid_dtypes:
+                raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
         if dtype is not None:
-            check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,))
+            validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor")

         if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
             input_data = np.ascontiguousarray(input_data)
         if dtype is None:
@@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_):
     Returns:
         Array, an array after being initialized.
     """
+
     def __init__(self, dtype, shape, init=None):
-        #check param
+        # check param
         self.init = init
         MetaTensor_.__init__(self, dtype, shape)
@@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_):
             msg = "Error shape={}".format(shape)
             logger.error(msg)
             raise ValueError(msg)
+
         class seed_context:
             '''set and restore seed'''
+
             def __init__(self, init):
                 self.init = init
                 from .seed import get_seed
@@ -482,4 +495,5 @@ def _vm_compare(*args):
         y = args[0]
     return Tensor(np.array(fn(y)))

+
 tensor_operator_registry.register('vm_compare', _vm_compare)
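Note: the practical effect of the stricter constructor checks, sketched under the assumption that MindSpore at this commit is installed (inputs illustrative):

    import numpy as np
    from mindspore import Tensor

    Tensor(np.ones((2, 2), dtype=np.float32))   # supported numpy dtype: accepted
    Tensor([1, 2, 3])                           # list with supported element types: accepted
    # Tensor(np.array(['a', 'b']))              # numpy dtype <U1: rejected with TypeError
    # Tensor([1, 'x'])                          # list with unsupported elements: rejected with TypeError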

mindspore/nn/graph_kernels/graph_kernels.py  +10 -8

@@ -21,7 +21,7 @@ from ...ops import operations as P
 from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
 from ...ops.composite import multitype_ops as C
 from ...ops.operations import _grad_ops as G
-from ..._checkparam import Validator
+from ..._checkparam import Validator as validator
 from ..cell import Cell, GraphKernel
@@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
                  use_locking=False,
                  gradient_scale=1.0):
         super(ApplyMomentum, self).__init__()
-        self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
+        self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__)
         self.fake_output_assign_1 = InplaceAssign()
         self.fake_output_assign_1.add_prim_attr("fake_output", True)
         self.fake_output_assign_2 = InplaceAssign()
@@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):

     def __init__(self, keep_dims=True):
         super(ReduceMean, self).__init__()
-        self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
+        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__)
         self.sum = P.ReduceSum(self.keep_dims)

     def construct(self, x, axis):
@@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel):
     """ Forward function of the LayerNorm operator. """
     def __init__(self, begin_norm_axis=1, begin_params_axis=1):
         super(LayerNormForward, self).__init__()
-        self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
-        self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
+        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int],
+                                                          type(self).__name__)
+        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int],
+                                                            type(self).__name__)
         self.mul = P.Mul()
         self.sum_keep_dims = P.ReduceSum(keep_dims=True)
         self.sub = P.Sub()
@@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel):

     def __init__(self, axis=-1):
         super(LogSoftmax, self).__init__()
-        self.axis = Validator.check_type('axis', axis, [int])
+        self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)
         self.max_keep_dims = P.ReduceMax(keep_dims=True)
         self.sub = P.Sub()
         self.exp = P.Exp()
@@ -952,13 +954,13 @@ class Softmax(GraphKernel):

     def __init__(self, axis):
         super(Softmax, self).__init__()
-        Validator.check_type("axis", axis, [int, tuple])
+        validator.check_value_type("axis", axis, [int, tuple], type(self).__name__)
         if isinstance(axis, int):
             self.axis = (axis,)
         else:
             self.axis = axis
         for item in self.axis:
-            Validator.check_type("item of axis", item, [int])
+            validator.check_value_type("item of axis", item, [int], type(self).__name__)
         self.max = P.ReduceMax(keep_dims=True)
         self.sub = P.Sub()
         self.exp = P.Exp()
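Note: all of these cells follow one pattern: check_value_type now also receives type(self).__name__, so the raised TypeError names the offending cell. A hypothetical stand-in class (not part of the PR) illustrating the pattern:

    from mindspore._checkparam import Validator as validator

    class _DemoCell:  # hypothetical class, for illustration only
        def __init__(self, axis=-1):
            # returns the value on success; on failure raises a TypeError
            # prefixed with "For '_DemoCell', ..." thanks to the new argument
            self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__)

    _DemoCell(axis=2)      # ok
    # _DemoCell(axis=2.0)  # TypeError naming _DemoCell and `axis`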


mindspore/nn/layer/normalization.py  +8 -8

@@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.ops.primitive import constexpr
 import mindspore.context as context
-from mindspore._checkparam import Validator, check_typename
+from mindspore._checkparam import Validator as validator
 from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
@@ -52,7 +52,7 @@ class _BatchNorm(Cell):

         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.use_batch_statistics = use_batch_statistics
@@ -67,7 +67,7 @@ class _BatchNorm(Cell):
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
-        self.group = Validator.check_positive_int(device_num_each_group)
+        self.group = validator.check_positive_int(device_num_each_group)
         self.is_global = False
         if self.group != 1:
             self.rank_id = get_rank()
@@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm):
                                               use_batch_statistics,
                                               device_num_each_group,
                                               input_dims='both')
-        self.group = Validator.check_positive_int(device_num_each_group)
+        self.group = validator.check_positive_int(device_num_each_group)
         if self.group <= 1:
             raise ValueError("the number of group must be greater than 1.")
@@ -607,12 +607,12 @@ class GroupNorm(Cell):

     def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
         super(GroupNorm, self).__init__()
-        self.num_groups = Validator.check_positive_int(num_groups)
-        self.num_channels = Validator.check_positive_int(num_channels)
+        self.num_groups = validator.check_positive_int(num_groups)
+        self.num_channels = validator.check_positive_int(num_channels)
         if num_channels % num_groups != 0:
             raise ValueError("num_channels should be divided by num_groups")
-        self.eps = check_typename('eps', eps, (float,))
-        self.affine = Validator.check_bool(affine)
+        self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
+        self.affine = validator.check_bool(affine)

         gamma = initializer(gamma_init, num_channels)
         beta = initializer(beta_init, num_channels)


mindspore/nn/layer/quant.py  +2 -2

@@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
         super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
                                                           symmetric=symmetric, narrow_range=narrow_range,
                                                           num_channels=num_channels)
-        Validator.check_type("min_init", min_init, [int, float])
-        Validator.check_type("max_init", max_init, [int, float])
+        Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__)
+        Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__)
         Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
         Validator.check_non_negative_int(quant_delay, 'quant_delay')
         self.min_init = min_init


mindspore/nn/probability/bijector/gumbel_cdf.py  +1 -1

@@ -68,7 +68,7 @@ class GumbelCDF(Bijector):
         """
         param = dict(locals())
         valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype)
         super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param)
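Note: the bijector and distribution constructors below all make the same one-line fix: the old calls passed the class name where check_type expected the argument name. check_type_name puts the argument name and the primitive name in the right slots. Sketch (assumes MindSpore at this commit; values illustrative):

    from mindspore._checkparam import Validator
    from mindspore.common import dtype as mstype

    valid_dtype = mstype.float_type
    Validator.check_type_name("dtype", mstype.float32, valid_dtype, "Normal")   # passes
    # Validator.check_type_name("dtype", mstype.int32, valid_dtype, "Normal")
    # would raise a TypeError naming 'Normal' and 'dtype'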




mindspore/nn/probability/distribution/bernoulli.py  +1 -1

@@ -119,7 +119,7 @@ class Bernoulli(Distribution):
         param = dict(locals())
         param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Bernoulli, self).__init__(seed, dtype, name, param)

         self._probs = self._add_parameter(probs, 'probs')


mindspore/nn/probability/distribution/categorical.py  +1 -1

@@ -109,7 +109,7 @@ class Categorical(Distribution):
         param = dict(locals())
         param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type
-        Validator.check_type("Categorical", dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Categorical, self).__init__(seed, dtype, name, param)
         self._probs = self._add_parameter(probs, 'probs')


mindspore/nn/probability/distribution/exponential.py  +1 -1

@@ -121,7 +121,7 @@ class Exponential(Distribution):
         param = dict(locals())
         param['param_dict'] = {'rate': rate}
         valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Exponential, self).__init__(seed, dtype, name, param)

         self._rate = self._add_parameter(rate, 'rate')


mindspore/nn/probability/distribution/geometric.py  +1 -1

@@ -122,7 +122,7 @@ class Geometric(Distribution):
         param = dict(locals())
         param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Geometric, self).__init__(seed, dtype, name, param)

         self._probs = self._add_parameter(probs, 'probs')


mindspore/nn/probability/distribution/gumbel.py  +1 -1

@@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution):
         Constructor of Gumbel distribution.
         """
         valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         gumbel_cdf = msb.GumbelCDF(loc, scale, dtype)
         super(Gumbel, self).__init__(
             distribution=msd.Uniform(0.0, 1.0, dtype=dtype),


mindspore/nn/probability/distribution/logistic.py  +1 -1

@@ -111,7 +111,7 @@ class Logistic(Distribution):
         param = dict(locals())
         param['param_dict'] = {'loc': loc, 'scale': scale}
         valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Logistic, self).__init__(seed, dtype, name, param)

         self._loc = self._add_parameter(loc, 'loc')


mindspore/nn/probability/distribution/normal.py  +1 -1

@@ -127,7 +127,7 @@ class Normal(Distribution):
         param = dict(locals())
         param['param_dict'] = {'mean': mean, 'sd': sd}
         valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Normal, self).__init__(seed, dtype, name, param)

         self._mean_value = self._add_parameter(mean, 'mean')


mindspore/nn/probability/distribution/uniform.py  +1 -1

@@ -126,7 +126,7 @@ class Uniform(Distribution):
         param = dict(locals())
         param['param_dict'] = {'low': low, 'high': high}
         valid_dtype = mstype.float_type
-        Validator.check_type(type(self).__name__, dtype, valid_dtype)
+        Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__)
         super(Uniform, self).__init__(seed, dtype, name, param)

         self._low = self._add_parameter(low, 'low')


mindspore/ops/operations/_cache_ops.py  +5 -8

@@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer):
         return [1]

     def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
-        args = {"indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
         return input_x_dtype
@@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer):

     def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
         return out_dtype
@@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer):
         return out_shape

     def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
-        args = {"miss_emb_idx": miss_emb_idx_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
         out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
         return out_dtype
@@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer):
         return miss_value_shape

     def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
-        args = {"swap_cache_idx": swap_cache_idx_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
         return miss_value_dtype
@@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer):

     def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensor_type_same(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype,
                      hashmap_dtype, hashmap_dtype)
         return out_dtype
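Note: for a single tensor input there is no longer a reason to build a one-entry dict just to reuse the same-dtype check; check_tensor_dtype_valid takes the name and dtype directly. Sketch (dtype value illustrative):

    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    indices_dtype = mstype.tensor_type(mstype.int32)
    # was: validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
    validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, "UpdateCache")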

mindspore/ops/operations/_grad_ops.py  +57 -51

@@ -14,6 +14,7 @@
 # ============================================================================

 """Operators for gradients."""
+from functools import partial

 from .. import signature as sig
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -23,6 +24,7 @@ from ...common import dtype as mstype
 from .. import functional as F
 from ... import context

+
 class AbsGrad(PrimitiveWithInfer):
     """Computes gradients for abs operation."""
@@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer):

     def infer_dtype(self, x, dout):
         args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return x
@@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer):

     def infer_dtype(self, x, dout):
         args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return x
@@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer):

     def infer_dtype(self, x, dout):
         args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return x
@@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer):

     def infer_dtype(self, x, dout):
         args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return x
@@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, dout_dtype):
         args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
         return x_dtype
@@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, dout_dtype):
         args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
+                                                      self.name)
         return x_dtype
@@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, dout_dtype):
         args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
         return x_dtype
@@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, dout_dtype):
         args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
         return x_dtype
@@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_type, y_type, doutput_type):
         args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
         return x_type, y_type
@@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
         args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
         if weight_type:
             validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
         return x_type
@@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
         for i, dim_len in enumerate(w_size_v):
             validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
         args = {"x": x['dtype'], "doutput": doutput['dtype']}
-        validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
+                                                      self.name)
         out = {
             'value': None,
             'shape': w_size_v,
@@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
     def __infer__(self, x, w_size, dout):
         w_size_v = w_size['value']
         args = {'x': x['dtype'], 'dout': dout['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         out = {
             'value': None,
             'shape': w_size_v,
@@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):

     def __infer__(self, x_size, w, dout):
         args = {'w': w['dtype'], 'dout': dout['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         x_size_v = x_size['value']
         out = {
             'value': None,
@@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer):
         return dy_shape

     def infer_dtype(self, dy_dtype, mask_dtype):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
         validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
-        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name)
         return dy_dtype
@@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
-        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ("y_backprop", "x", "y"),
+                  (y_backprop_dtype, x_dtype, y_dtype)))
         return x_dtype
@@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad):

     def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
         args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
         return x1_dtype
@@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad):

     def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
         args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
-        validator.check_tensor_type_same(args, [mstype.float16], self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
         return grad_dtype
@@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer):

     def infer_dtype(self, input_x, out, dout):
         args = {'input_x': input_x, 'out': out, 'dout': dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return input_x
@@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer):
     def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
                     hx_dtype, cx_dtype, reserve_dtype, state_dtype):
         args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
-        validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
         return (dy_dtype, dy_dtype, dy_dtype)
@@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
         args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
                 "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
                 "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
-        validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same(args, valid_types, self.name)
+        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
+        validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
+        validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
         if seq_dtype is not None:
-            validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name)
+            validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
         if mask_dtype is not None:
-            validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name)
+            validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
         return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
@@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer):
         return y_backprop_shape, w_shape

     def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ('y_backprop', "input_x", "weight"),
+                  (y_backprop_dtype, A_dtype, w_dtype)))
         return y_backprop_dtype, w_dtype
@@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
         return x_dtype
@@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer):
         return gradients_shape

     def infer_dtype(self, gradients_dtype, mask_dtype):
-        validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
-        validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
+        validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name)
         return gradients_dtype
@@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer):

     def infer_dtype(self, y_grad_dtype, x_dtype):
         args = {'y_grad': y_grad_dtype, 'x': x_dtype}
-        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
         return x_dtype
@@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer):

     def infer_dtype(self, out, dout):
         args = {'out': out, 'dout': dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return out
@@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
         return x_dtype
@@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, y_grad_dtype, x_dtype):
-        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
-        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
+        valid_dtypes = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
         return x_dtype
@@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
         args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return dout_dtype
@@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):

     def infer_dtype(self, prediction, target, dloss):
         args = {"prediction": prediction, "target": target, 'dloss': dloss}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return dloss
@@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

     def __infer__(self, dy, shapex, begin, end, strides):
-        args = {"dy": dy['dtype']}
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
+        validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)

         for idx, item in enumerate(shapex['value']):
             validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
@@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer):

     def infer_dtype(self, dout_dtype, x_dtype):
         args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
-        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
         return x_dtype
@@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer):

     def infer_dtype(self, out, dout):
         args = {"out": out, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return out
@@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer):

     def infer_dtype(self, x, dout):
         args = {"x": x, "dout": dout}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return x
@@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer):

     def infer_dtype(self, grads, x, y):
         args = {"grads": grads, "x": x, "y": y}
-        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
         return x

     def infer_shape(self, grads, x, y):
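Note: GeluGrad and PReLUGrad show why the file now imports functools.partial: several independently named tensor dtypes are validated against one whitelist by mapping the bound checker over (name, dtype) pairs. A standalone sketch of the idiom (dtypes illustrative):

    from functools import partial
    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    check = partial(validator.check_tensor_dtype_valid,
                    valid_dtypes=(mstype.float16, mstype.float32), prim_name="GeluGrad")
    dtypes = (mstype.tensor_type(mstype.float16),) * 3
    # map is lazy; wrapping in tuple() forces every check to run
    tuple(map(check, ("y_backprop", "x", "y"), dtypes))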


+ 20
- 19
mindspore/ops/operations/_inner_ops.py View File

@@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, ksizes, strides, rates, padding="valid"): def __init__(self, ksizes, strides, rates, padding="valid"):
"""init""" """init"""

def _check_tuple_or_list(arg_name, arg_val, prim_name): def _check_tuple_or_list(arg_name, arg_val, prim_name):
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
@@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer):


def infer_dtype(self, input_x): def infer_dtype(self, input_x):
"""infer dtype""" """infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name)
return input_x return input_x




@@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
return x_dtype return x_dtype




@@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer):
>>> dequant = P.Dequant(False, False) >>> dequant = P.Dequant(False, False)
>>> y = dequant(input_x) >>> y = dequant(input_x)
""" """

@prim_attr_register @prim_attr_register
def __init__(self, sqrt_mode=False, relu_flag=False): def __init__(self, sqrt_mode=False, relu_flag=False):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
@@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer):
return assist return assist


def infer_dtype(self, assist, start, stop, num): def infer_dtype(self, assist, start, stop, num):
args = {"num": num}
validator.check_tensor_type_same(args, (mstype.int32,), self.name)
validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
args = {"assist": assist, "start": start, "stop": stop} args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
return assist return assist
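
LinSpace illustrates the split the patch makes between the two tensor checkers: a lone tensor that only needs a dtype whitelist now goes through check_tensor_dtype_valid (name, dtype, whitelist, prim name), while a group of tensors that must also agree among themselves goes through check_tensors_dtypes_same_and_valid. A sketch of the two call shapes, under the same import assumptions as above:

    # one tensor against a whitelist
    validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
    # several tensors that must share one whitelisted dtype
    validator.check_tensors_dtypes_same_and_valid(
        {"assist": assist, "start": start, "stop": stop}, (mstype.float32,), self.name)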




@@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype): def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype} args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape, assist_shape): def infer_shape(self, x_shape, assist_shape):
validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
validator.check('rank of x', len(x_shape)+1,
validator.check('rank of x', len(x_shape) + 1,
'rank of assist', len(assist_shape), Rel.LE, self.name) 'rank of assist', len(assist_shape), Rel.LE, self.name)
validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
assist_shape[-1], Rel.EQ, self.name) assist_shape[-1], Rel.EQ, self.name)
@@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer):
while r_idx >= r_end_dim: while r_idx >= r_end_dim:
if x_shape[r_idx] != 1: if x_shape[r_idx] != 1:
validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
r_idx = r_idx - 1 r_idx = r_idx - 1


return assist_shape return assist_shape
@@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, assist_dtype): def infer_dtype(self, x_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "assist": assist_dtype} args = {"x": x_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape, assist_shape): def infer_shape(self, x_shape, assist_shape):
@@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape, diagonal_shape, assist_shape): def infer_shape(self, x_shape, diagonal_shape, assist_shape):
@@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer):
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape


def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name)
validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name)
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
b_dtype = mstype.float32 b_dtype = mstype.float32
if binput_dtype is not None: if binput_dtype is not None:
validator.check_tensor_type_same({"bias input dtype": binput_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype b_dtype = binput_dtype
elif bhidden_dtype is not None: elif bhidden_dtype is not None:
validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype b_dtype = bhidden_dtype
elif h_dtype is not None: elif h_dtype is not None:
validator.check_tensor_type_same({"init_h dtype": h_dtype},
(mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = h_dtype b_dtype = h_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
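
The bias-dtype logic above is a fallback chain: the result defaults to float32 and is overridden by the first optional input that is present, each candidate being validated before it wins. A condensed restatement of just the selection (validation elided):

    # first non-None source wins; fall back to float32 when all three are absent
    b_dtype = next((t for t in (binput_dtype, bhidden_dtype, h_dtype) if t is not None),
                   mstype.float32)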




+ 63  - 78    mindspore/ops/operations/_quant_ops.py

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================


"""Operators for quantization.""" """Operators for quantization."""
from functools import partial


import mindspore.context as context import mindspore.context as context
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
@@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
return min_shape, max_shape return min_shape, max_shape


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type return min_type, max_type
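
The tuple(map(partial(...))) idiom that replaces the three explicit calls is a fan-out: partial pins the shared keyword arguments, map pairs each argument name with its dtype, and the outer tuple() exists only to drain the lazy map iterator so the checks actually execute. Unrolled, it is equivalent to:

    from functools import partial

    check = partial(validator.check_tensor_dtype_valid,
                    valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name)
    for name, dtype in zip(("x", "min", "max"), (x_type, min_type, max_type)):
        check(name, dtype)  # positional args fill the name and dtype slots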




@@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
return min_shape, max_shape return min_shape, max_shape


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return min_type, max_type return min_type, max_type




@@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor) >>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
""" """

@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type return x_type




@@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
>>> min_gradient shape: (1,) data type: mstype.float32 >>> min_gradient shape: (1,) data type: mstype.float32
>>> max_gradient shape: (1,) data type: mstype.float32 >>> max_gradient shape: (1,) data type: mstype.float32
""" """

@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape return x_shape, min_shape, max_shape


def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
('dout', "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type return x_type, min_type, max_type




@@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
>>> input_tensor, min_tensor, max_tensor) >>> input_tensor, min_tensor, max_tensor)
>>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
""" """

@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type return x_type




@@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
>>> min_gradient shape: (4,) data type: mstype.float32 >>> min_gradient shape: (4,) data type: mstype.float32
>>> max_gradient shape: (4,) data type: mstype.float32 >>> max_gradient shape: (4,) data type: mstype.float32
""" """

@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
num_bits=8, num_bits=8,
@@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
return x_shape, min_shape, max_shape return x_shape, min_shape, max_shape


def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
tuple(map(partial(validator.check_tensor_dtype_valid,
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return x_type, min_type, max_type return x_type, min_type, max_type




@@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer):


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type return x_type
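
Here the whitelist itself is chosen at infer time from the global execution context: the GPU kernels for these fake-quant ops accept only float32, while other targets also take float16. The branch keys off the same context API a user script sets, for example:

    import mindspore.context as context
    context.set_context(device_target="GPU")  # narrows valid_dtypes here to (float32,)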




@@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):


def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type return dout_type




@@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):


def infer_dtype(self, x_type, min_type, max_type): def infer_dtype(self, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("x", "min", "max"),
(x_type, min_type, max_type)))
return x_type return x_type




@@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):


def infer_dtype(self, dout_type, x_type, min_type, max_type): def infer_dtype(self, dout_type, x_type, min_type, max_type):
if context.get_context('device_target') == "GPU": if context.get_context('device_target') == "GPU":
valid_types = (mstype.float32,)
valid_dtypes = (mstype.float32,)
else: else:
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"dout": dout_type}, valid_types, self.name)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
valid_dtypes = (mstype.float16, mstype.float32)
tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
("dout", "x", "min", "max"),
(dout_type, x_type, min_type, max_type)))
return dout_type return dout_type




@@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type) validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type} args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type, x_type, x_type, x_type return x_type, x_type, x_type, x_type




@@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
global_step_type): global_step_type):
args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
"batch_mean": batch_mean_type, "batch_std": batch_std_type} "batch_mean": batch_mean_type, "batch_std": batch_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type return x_type




@@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer):


def infer_dtype(self, x_type, batch_std_type, running_std_type): def infer_dtype(self, x_type, batch_std_type, running_std_type):
args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type return x_type




@@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):


def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
if context.get_context('device_target') == "Ascend": if context.get_context('device_target') == "Ascend":
return x_type, x_type return x_type, x_type
return x_type, gamma_type return x_type, gamma_type
@@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer):
running_mean_type, global_step_type): running_mean_type, global_step_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return x_type return x_type




@@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
"dout type", dout_type) "dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type




@@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "mean type", mean_type)
validator.check("input type", x_type, "variance type", variance_type) validator.check("input type", x_type, "variance type", variance_type)
args = {"x": x_type, "mean": mean_type, "variance": variance_type} args = {"x": x_type, "mean": mean_type, "variance": variance_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type, x_type, x_type, x_type, x_type, x_type, x_type return x_type, x_type, x_type, x_type, x_type, x_type, x_type




@@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
validator.check("input type", x_type, "batch_mean type", batch_mean_type) validator.check("input type", x_type, "batch_mean type", batch_mean_type)
validator.check("input type", x_type, "batch_std type", batch_std_type) validator.check("input type", x_type, "batch_std type", batch_std_type)
args = {"input type": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
return x_type return x_type




@@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer):
def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
"beta": beta_type, "gamma": gamma_type, "x": x_type} "beta": beta_type, "gamma": gamma_type, "x": x_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return x_type return x_type




@@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer):
"dout type", dout_type) "dout type", dout_type)
args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
"running_std": running_std_type, "dout": dout_type} "running_std": running_std_type, "dout": dout_type}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return gamma_type, gamma_type, gamma_type, gamma_type return gamma_type, gamma_type, gamma_type, gamma_type






+ 13  - 8    mindspore/ops/operations/_thor_ops.py

@@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
def infer_shape(self, data1_shape): def infer_shape(self, data1_shape):
ll = [] ll = []
if len(data1_shape) == 2: if len(data1_shape) == 2:
ll = [1,]
ll = [1]
else: else:
ll = [32, 64] ll = [32, 64]
return ll return ll
@@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer):
>>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2)
>>> output = img2col(input_x) >>> output = img2col(input_x)
""" """

@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
kernel_size, kernel_size,
@@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer):
return out_shape return out_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
args = {'x': x_dtype}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
return x_dtype return x_dtype




@@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer):
return x2_shape return x2_shape


def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
validator.check_tensors_dtypes_same_and_valid(
{'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype return x2_dtype



class Cholesky(PrimitiveWithInfer): class Cholesky(PrimitiveWithInfer):
""" """
Inner API for resnet50 THOR GPU backend Inner API for resnet50 THOR GPU backend
""" """

@prim_attr_register @prim_attr_register
def __init__(self, split_dim=0): def __init__(self, split_dim=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer):
return out_shape return out_shape


def infer_dtype(self, x1_dtype): def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype return x1_dtype



class DetTriangle(PrimitiveWithInfer): class DetTriangle(PrimitiveWithInfer):
""" """
Calculate the determinant of triangle matrices Calculate the determinant of triangle matrices
""" """

@prim_attr_register @prim_attr_register
def __init__(self, fill_mode=0): def __init__(self, fill_mode=0):
self.init_prim_io_names(inputs=['x1'], outputs=['y']) self.init_prim_io_names(inputs=['x1'], outputs=['y'])
@@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer):
return out_shape return out_shape


def infer_dtype(self, x1_dtype): def infer_dtype(self, x1_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
return x1_dtype return x1_dtype

+ 64  - 59    mindspore/ops/operations/array_ops.py

@@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype} args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return x_dtype return x_dtype




@@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp):
""" """
Defines _ScatterNd operators Defines _ScatterNd operators
""" """

def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
validator.check('the dimension of x', len(x_shape), validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE) 'the dimension of indices', indices_shape[-1], Rel.GE)
@@ -627,6 +628,7 @@ class Unique(Primitive):
>>> out = P.Unique()(x) >>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
""" """

@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
@@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck):
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])


def __check__(self, params, indices, axis): def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']
@@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer):
>>> out = P.Padding(pad_dim_size)(x) >>> out = P.Padding(pad_dim_size)(x)
[[8, 0, 0, 0], [10, 0, 0, 0]] [[8, 0, 0, 0], [10, 0, 0, 0]]
""" """

@prim_attr_register @prim_attr_register
def __init__(self, pad_dim_size=8): def __init__(self, pad_dim_size=8):
"""Initialize padding""" """Initialize padding"""
@@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer):
>>> out = P.UniqueWithPad()(x, pad_num) >>> out = P.UniqueWithPad()(x, pad_num)
([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
""" """

@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init UniqueWithPad""" """init UniqueWithPad"""


def __infer__(self, x, pad_num): def __infer__(self, x, pad_num):
validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
@@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer):
def __init__(self, seed=0, dtype=mstype.float32): def __init__(self, seed=0, dtype=mstype.float32):
"""Initialize TruncatedNormal""" """Initialize TruncatedNormal"""
validator.check_value_type('seed', seed, [int], self.name) validator.check_value_type('seed', seed, [int], self.name)
validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)


def __infer__(self, shape): def __infer__(self, shape):
shape_value = shape['value'] shape_value = shape['value']
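
TruncatedNormal shows the third rename in the family, check_types_same_and_valid, used when the value being validated is a plain mstype entry passed as an attribute (the dtype argument here) rather than a tensor's inferred dtype. A sketch of the call shape, assuming it accepts the same (args, whitelist, prim name) arguments seen at the call sites in this patch:

    # validates a raw mstype value against a whitelist; float32 passes
    validator.check_types_same_and_valid({'dtype': mstype.float32},
                                         mstype.number_type, "TruncatedNormal")
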
@@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer):
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
for i, item in enumerate(dims['value']): for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name) validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value']) x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.full(dims['value'], x['value'], x_nptype) ret = np.full(dims['value'], x['value'], x_nptype)
out = { out = {
@@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype




@@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype




@@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer):
"""Initialize Argmax""" """Initialize Argmax"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
validator.check_type_same({'output': output_type}, [mstype.int32], self.name)
validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name)
self.axis = axis self.axis = axis
self.add_prim_attr('output_type', output_type) self.add_prim_attr('output_type', output_type)


@@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""Initialize UnsortedSegmentSum""" """Initialize UnsortedSegmentSum"""
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("dynamic_shape_depends", [2,])
self.add_prim_attr("dynamic_shape_depends", [2])


def __infer__(self, x, segment_ids, num_segments): def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype'] x_type = x['dtype']
@@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
num_segments_type = num_segments['dtype'] num_segments_type = num_segments['dtype']
validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
if isinstance(num_segments_type, type(mstype.tensor)): if isinstance(num_segments_type, type(mstype.tensor)):
validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name)
shp = [-1] shp = [-1]
else: else:
validator.check_value_type('num_segments', num_segments_v, [int], self.name) validator.check_value_type('num_segments', num_segments_v, [int], self.name)
@@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
x_shape = x['shape'] x_shape = x['shape']
segment_ids_shape = segment_ids['shape'] segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
x_shape = x['shape'] x_shape = x['shape']
segment_ids_shape = segment_ids['shape'] segment_ids_shape = segment_ids['shape']
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
validator.check_subclass("input_x", x_type, mstype.tensor, self.name) validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_value_type("x_shape", x_shape, [list], self.name) validator.check_value_type("x_shape", x_shape, [list], self.name)
valid_type = [mstype.float16, mstype.float32, mstype.int32] valid_type = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
validator.check(f'first shape of input_x', x_shape[0], validator.check(f'first shape of input_x', x_shape[0],
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer):
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)


args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)


first_elem = x_shp[0] first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]): for i, elem in enumerate(x_shp[1:]):
@@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype




@@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
return x_dtype return x_dtype




@@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer):
self.add_prim_attr('T', x_type) self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor, self.name) validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name)
if x_type != y_type: if x_type != y_type:
raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type return x_type
@@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer):
validator.check_positive_int(n, "n", self.name) validator.check_positive_int(n, "n", self.name)
validator.check_positive_int(m, "m", self.name) validator.check_positive_int(m, "m", self.name)
args = {"dtype": t} args = {"dtype": t}
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
np_type = mstype.dtype_to_nptype(t) np_type = mstype.dtype_to_nptype(t)
ret = np.eye(n, m, dtype=np_type) ret = np.eye(n, m, dtype=np_type)
return Tensor(ret) return Tensor(ret)
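
Eye is constant-folded during inference: once the dtype clears the whitelist, the matrix is materialized with NumPy and handed back as a Tensor. A worked example with hypothetical arguments n=3, m=4, t=mstype.float32:

    import numpy as np

    np_type = np.float32                 # what mstype.dtype_to_nptype(mstype.float32) yields
    ret = np.eye(3, 4, dtype=np_type)    # 3x4 matrix with ones on the main diagonal
    # wrapped and returned as Tensor(ret)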
@@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape): def __infer__(self, indices, update, shape):
shp = shape['value'] shp = shape['value']
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name)
validator.check_value_type("shape", shp, [tuple], self.name) validator.check_value_type("shape", shp, [tuple], self.name)
for i, x in enumerate(shp): for i, x in enumerate(shp):
validator.check_positive_int(x, f'shape[{i}]', self.name) validator.check_positive_int(x, f'shape[{i}]', self.name)
@@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
validator.check_non_negative_int(value, f'{i}th value of size', self.name) validator.check_non_negative_int(value, f'{i}th value of size', self.name)
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])


def infer_shape(self, x):
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
return tuple(x)[:-2] + tuple(self.size)
def infer_shape(self, x_shape):
validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
return tuple(x_shape)[:-2] + tuple(self.size)


def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name)
return x
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
return x_dtype
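
Besides the validator swap, this hunk renames the infer parameters (x becomes x_shape and x_dtype) so the signatures say what they receive, and drops the separate check_subclass call, presumably because check_tensor_dtype_valid already rejects non-tensor inputs. The shape rule itself replaces the last two dimensions with size; a worked example with hypothetical values:

    x_shape = (1, 3, 8, 8)                           # NCHW input
    size = (4, 4)                                    # target spatial size
    out_shape = tuple(x_shape)[:-2] + tuple(size)    # -> (1, 3, 4, 4)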




class GatherNd(PrimitiveWithInfer): class GatherNd(PrimitiveWithInfer):
@@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer):
return indices_shape[:-1] + x_shape[indices_shape[-1]:] return indices_shape[:-1] + x_shape[indices_shape[-1]:]


def infer_dtype(self, x_dtype, indices_dtype): def infer_dtype(self, x_dtype, indices_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
return x_dtype return x_dtype




@@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_shape return x_shape


def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype




@@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])


def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype




@@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp):
[0.4 2.2 -3.2]] [0.4 2.2 -3.2]]
""" """



@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Initialize ScatterNdUpdate""" """Initialize ScatterNdUpdate"""
@@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])


def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype} args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype return x_dtype




@@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])


def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype} args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
return x_dtype return x_dtype




@@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer):
self.paddings = paddings self.paddings = paddings


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer):
self.crops = crops self.crops = crops


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
self.add_prim_attr("paddings", paddings_append) self.add_prim_attr("paddings", paddings_append)


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
self.add_prim_attr("crops", crops_append) self.add_prim_attr("crops", crops_append)


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
@@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer):
out_shape = tuple(tuple(shape_0) for _ in range(n)) out_shape = tuple(tuple(shape_0) for _ in range(n))
return out_shape return out_shape



def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
n = len(x_type) n = len(x_type)
@@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer):
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
return x_type return x_type



class InplaceUpdate(PrimitiveWithInfer): class InplaceUpdate(PrimitiveWithInfer):
r""" r"""
Updates specified rows with values in `v`. Updates specified rows with values in `v`.
@@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, v_dtype): def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype} args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32] valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype return x_dtype


def infer_shape(self, x_shape, v_shape): def infer_shape(self, x_shape, v_shape):
@@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer):
return x return x


def infer_dtype(self, x, seq_lengths): def infer_dtype(self, x, seq_lengths):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
return x return x




@@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer):
validator.check_const_input('truth_shape', truth_shape['value'], self.name) validator.check_const_input('truth_shape', truth_shape['value'], self.name)
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name)
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)


hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
@@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer):
Outputs: Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`.
""" """

@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.__setattr_flag__ = True self.__setattr_flag__ = True
@@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer):
def __infer__(self, x, shape): def __infer__(self, x, shape):
shp = shape['value'] shp = shape['value']
dtype = x['dtype'] dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp)) self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp, return {'shape': shp,
'dtype': dtype, 'dtype': dtype,
@@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer):
return x_shape, x_shape return x_shape, x_shape


def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name)
return x_dtype, mstype.tensor_type(mstype.int32) return x_dtype, mstype.tensor_type(mstype.int32)




@@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset) >>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
""" """

@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""Initialize index_select""" """Initialize index_select"""
@@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer):


def __infer__(self, params, indices, offset): def __infer__(self, params, indices, offset):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape'] params_shp = params['shape']
if len(params_shp) != 2: if len(params_shp) != 2:
@@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer):
         >>> out = P.GatherD()(x, dim, index)
         [[1, 1], [4, 3]]
     """
+
     @prim_attr_register
     def __init__(self):
         """Initialize GatherD"""
@@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer):

     def __infer__(self, x, dim, index):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
-        validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name)
+        validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name)
         validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name)
         x_shp = x['shape']
         idx_shp = index['shape']
@@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer):
         >>> y = P.Identity()(x)
         [1, 2, 3, 4]
     """
+
     @prim_attr_register
     def __init__(self):
         """Initialize identity"""


+7 -7  mindspore/ops/operations/comm_ops.py

@@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

@@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

     def __call__(self, tensor):
@@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

     def __call__(self, tensor):
@@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

     def __call__(self, tensor):
@@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

     def __call__(self, tensor):
@@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer):
         if not isinstance(x_dtype, tuple):
             raise TypeError(f"{self.name}'s input should be a tuple!")
         for _ele in x_dtype:
-            validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name)
+            validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
         return x_dtype

@@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
         return x_dtype

     def __call__(self, tensor):
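Every comm op above collapses to the same single-tensor check. A standalone sketch of check_tensor_dtype_valid, assuming the (arg_name, dtype, valid_dtypes, prim_name) positional order visible in these hunks; the value bound to target_dtypes here is illustrative, not taken from comm_ops:

    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
    # passes: float32 is in the allowed set
    validator.check_tensor_dtype_valid('x', mstype.tensor_type(mstype.float32),
                                       target_dtypes, "AllReduce")
    # a bool_ tensor would raise TypeError, since bool_ is outside target_dtypes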


+2 -3  mindspore/ops/operations/control_ops.py

@@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer):
     def infer_dtype(self, data_type, pred_type):
         validator.check_subclass(
             "data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
-        validator.check_tensor_type_same(
-            {"pred": pred_type}, [mstype.bool_], self.name)
+        validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
         return (data_type, data_type)

@@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer):
         for i, item in enumerate(inputs):
             args['inputs[%d]' % i] = item

-        validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
+        validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
         return (inputs[0], mstype.int32)
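Merge accepts scalars as well as tensors, so it uses the renamed mixed-form check. A sketch under the same assumptions as the earlier examples (dict-of-args signature as shown above; names illustrative):

    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    args = {"inputs[0]": mstype.tensor_type(mstype.float32),
            "inputs[1]": mstype.float32}  # a plain scalar dtype is accepted alongside tensors
    validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, "Merge")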

+1 -1  mindspore/ops/operations/debug_ops.py

@@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer):
         return [1]

     def infer_dtype(self, condition, inputs):
-        validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
+        validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
         for dtype in inputs:
             validator.check_subclass("input", dtype, [mstype.tensor], self.name)
         return mstype.int32

+5 -5  mindspore/ops/operations/image_ops.py

@@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer):
         box_index_dtype = box_index['dtype']
         crop_size_dtype = crop_size['dtype']
         # check dytpe
-        validator.check_tensor_type_same({"x": x_dtype},
-                                         [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
-                                          mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
-        validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
-        validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype,
+                                           [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
+                                            mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
+        validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name)
         validator.check_value_type("crop_size", crop_size_value, [tuple], self.name)
         # check input shape rank
         validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)


+90 -89  mindspore/ops/operations/math_ops.py

@@ -16,6 +16,8 @@
 """Operators for math."""

 import copy
+from functools import partial
+
 import numpy as np
 from ... import context
 from .. import signature as sig
@@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp):
     @staticmethod
     def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
         args_type = {"x": x_dtype, "y": y_dtype}
-        validator.check_tensor_type_same(args_type, valid_dtype, prim_name)
+        validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
         return x_dtype

     def infer_dtype(self, x_dtype, y_dtype):
@@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp):
     @staticmethod
     def _check_bitwise_op_input_type(x1_type, x2_type, prim):
         args = {'x1': x1_type, 'x2': x2_type}
-        valid_types = mstype.int_type + mstype.uint_type
-        validator.check_tensor_type_same(args, valid_types, prim)
+        valid_dtypes = mstype.int_type + mstype.uint_type
+        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim)
         return x1_type

     def infer_dtype(self, x1_type, x2_type):
@@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer):

     def infer_dtype(self, variable, value):
         args = {"variable": variable, "value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
         return value

@@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer):

     def infer_dtype(self, variable, value):
         args = {"variable": variable, "value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name)
         return value
@@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer):
         axis_v = axis['value']
         input_shp = input_x['shape']
         args = {'input_x': input_x['dtype']}
-        validator.check_tensor_type_same(args, valid_dtype, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name)

         if axis_v is None:
             raise ValueError(f"For {self.name}, axis must be const.")
@@ -504,6 +506,7 @@ class ReduceMax(_Reduce):
     def __infer__(self, input_x, axis):
         return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,))

+
 class ReduceMin(_Reduce):
     """
     Reduce a dimension of a tensor by the minimum value in the dimension.
@@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer):

     def infer_dtype(self, x_type, axis_type):
         cls_name = self.name
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
+        validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name)
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type

@@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer):

     def infer_dtype(self, x1, x2):
         args = {"x1": x1, "x2": x2}
-        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
         if x1.element_type() == mstype.int8:
             return mstype.tensor_type(mstype.int32)
         return x1
@@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer):
         self.axes = axes
         validator.check_value_type('axes', axes, [int, tuple, list], self.name)
         if not isinstance(self.axes, int):
-            self.axes = list(self.axes) # to avoid immutability issues
+            self.axes = list(self.axes)  # to avoid immutability issues
             if len(self.axes) != 2:
                 raise ValueError("Require two axes inputs, given less")
-            self.int_to_tuple_conv() # convert before length checks
+            self.int_to_tuple_conv()  # convert before length checks
             if len(self.axes[0]) != len(self.axes[1]):
                 raise ValueError("Axes have to be the same size/length")
             if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
@@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer):
         if isinstance(self.axes, int):
             if self.axes <= 0:
                 # outer product, no input validation required
-                self.axes = ([], []) # no axes selected for either
+                self.axes = ([], [])  # no axes selected for either
                 return
             if self.axes > len(x1_shape) or self.axes > len(x2_shape):
                 raise ValueError(
@@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer):

     def infer_dtype(self, x1, x2):
         args = {"x1": x1, "x2": x2}
-        valid_types = [mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_types, self.name)
+        valid_dtypes = [mstype.float16, mstype.float32]
+        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
         return x1
@@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer):
         if axis['value'] is None:
             raise ValueError(f"For {self.name}, axis must be const.")
         validator.check_value_type('axis', axis['value'], [int], cls_name)
-        valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
+        valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
+        validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
         return {'shape': x_shp,
                 'dtype': x['dtype'],
                 'value': None}
@@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer):
             if dtype == mstype.undetermined:
                 contains_undetermined = True
         if not contains_undetermined:
-            validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
+            validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
         return inputs[0]

     def infer_value(self, inputs):
@@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer):
         args = {}
         for i, dtype in enumerate(inputs):
             args[f"inputs[{i}]"] = dtype
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name)
         return inputs[0]
@@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer):
         """Initialize Neg"""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])

-    def infer_shape(self, input_x):
-        return input_x
+    def infer_shape(self, x_shape):
+        return x_shape

-    def infer_dtype(self, input_x):
-        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
-        return input_x
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
+        return x_dtype

     def infer_value(self, input_x):
         if input_x is not None:
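The Neg hunks also normalize argument names to the x_shape/x_dtype convention used across the file. A minimal sketch of a primitive written to that convention; the op itself is hypothetical, and the import paths are an assumption rather than something this diff shows:

    from mindspore.common import dtype as mstype
    from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
    from mindspore._checkparam import Validator as validator

    class IdentityLike(PrimitiveWithInfer):
        """Hypothetical element-wise op, shown only to illustrate the naming convention."""

        @prim_attr_register
        def __init__(self):
            self.init_prim_io_names(inputs=['x'], outputs=['y'])

        def infer_shape(self, x_shape):   # shape arguments end in _shape
            return x_shape

        def infer_dtype(self, x_dtype):   # dtype arguments end in _dtype
            validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
            return x_dtype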
@@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, v_dtype):
         args = {'x': x_dtype, 'v': v_dtype}
         valid_type = [mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
         return x_dtype

     def infer_shape(self, x_shape, v_shape):
@@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, v_dtype):
         args = {'x': x_dtype, 'v': v_dtype}
         valid_type = [mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
         return x_dtype

     def infer_shape(self, x_shape, v_shape):
@@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         return x_shape

-    def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
-        return x_type
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
+        return x_dtype

     def infer_value(self, x):
         if x is not None:
@@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         return x_shape

-    def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
-        return x_type
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
+        return x_dtype

     def infer_value(self, x):
         if x is not None:
@@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])

     def check_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name)

     def infer_value(self, x):
         if x is not None:
@@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor, self.name)
-        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
         return x_type

@@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer):
         return (self.nbins,)

     def infer_dtype(self, x_dtype, range_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
-        valid_types = (mstype.float16, mstype.float32, mstype.int32)
-        validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name)
+        valid_dtypes = (mstype.float16, mstype.float32, mstype.int32)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
+        validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name)
         y_dtype = mstype.int32
         return y_dtype
@@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer):
     def __init__(self):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])

-    def infer_shape(self, x):
-        return x
+    def infer_shape(self, x_shape):
+        return x_shape

-    def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.name)
-        validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name)
-        return x
+    def infer_dtype(self, x_dtype):
+        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
+        return x_dtype


 class Erf(PrimitiveWithInfer):
@@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         return x_shape

-    def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
-        return x_type
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
+        return x_dtype


 class Erfc(PrimitiveWithInfer):
@@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
         return x_type
@@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name)
         return x_dtype

@@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
         return x_dtype

@@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype
@@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp):
     @staticmethod
     def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
         args_dtype = {"x": x_dtype, "y": y_dtype}
-        validator.check_tensor_type_same(args_dtype, valid_type, prim_name)
+        validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name)
         return mstype.tensor_type(mstype.bool_)

     def infer_dtype(self, x_dtype, y_dtype):
@@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         args_dtype = {"x": x_dtype, "y": y_dtype}
         valid_type = [mstype.float32, mstype.float16]
-        validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name)
+        validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name)
         return mstype.tensor_type(mstype.bool_)

@@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype, y_dtype):
         args = {'x': x_dtype, 'y': y_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
         return x_dtype
@@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name)
         return mstype.tensor_type(mstype.bool_)

@@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
         return mstype.bool_
@@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer):
         return [1]

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
         return x_dtype

@@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         return [8]

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
         return mstype.float32

@@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         return [8]

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
         return mstype.float32
@@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype

@@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype
@@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer):
         return (bboxes_shape, (num,), (num,))

     def infer_dtype(self, bboxes_dtype):
-        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name)
         return (bboxes_dtype, mstype.int32, mstype.bool_)

@@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
         return x_type

     def infer_value(self, x):
@@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
         return x_dtype
@@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         return x_shape

-    def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
-        return x_type
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name)
+        return x_dtype


 class Tan(PrimitiveWithInfer):
@@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        valid_types = [mstype.float16, mstype.float32, mstype.int32]
-        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
+        valid_dtypes = [mstype.float16, mstype.float32, mstype.int32]
+        validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name)
         return x_type
@@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
         return x_type

@@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name)
         return x_type

@@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer):
         return [], []

     def infer_dtype(self, x_type, y_type):
-        validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name)
-        validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name)
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name)
+        validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name)
         return x_type, y_type
@@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
         return x

@@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
+        validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name)
         return x

@@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32,
+        validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32,
                                                                 mstype.int32], self.name)
         return x_dtype

@@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name)
+        validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name)
         return x_dtype
@@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['input_x'], outputs=['y'])

     def __infer__(self, input_x):
-        valid_types = [mstype.float16, mstype.float32]
-        validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name)
+        valid_dtypes = [mstype.float16, mstype.float32]
+        validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name)

         x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type())
         if x_nptype == np.float16:
@@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer):
         return (1,), (1,)

     def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
-        valid_types = [mstype.float32, mstype.float16]
-        validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name)
-        validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name)
+        tuple(map(partial(validator.check_tensor_dtype_valid,
+                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                  ("input_value", "input_min", "input_max"),
+                  (data_dtype, data_min_dtype, data_max_dtype)))
+        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
         return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
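The IFMR rewrite (which is why from functools import partial was added at the top of math_ops.py) folds three identical checks into one map over a partial. The same pattern in isolation, assuming check_tensor_dtype_valid accepts valid_dtypes and prim_name as keyword arguments, as the new call above implies:

    from functools import partial
    from mindspore._checkparam import Validator as validator
    from mindspore.common import dtype as mstype

    f16 = mstype.tensor_type(mstype.float16)
    check = partial(validator.check_tensor_dtype_valid,
                    valid_dtypes=(mstype.float16, mstype.float32), prim_name="IFMR")
    # map() is lazy; wrapping it in tuple() forces every check to actually run
    tuple(map(check, ("input_value", "input_min", "input_max"), (f16, f16, f16)))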

+214 -222  mindspore/ops/operations/nn_ops.py  (file diff suppressed because it is too large)


+10 -11  mindspore/ops/operations/other_ops.py

@@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck):

     def check_dtype(self, variable, value):
         if variable != mstype.type_refkey:
-            validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name)
-        validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name)
+            validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name)
+        validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)


 class BoundingBoxEncode(PrimitiveWithInfer):
@@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer):

     def infer_dtype(self, anchor_box, groundtruth_box):
         args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return anchor_box

@@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer):

     def infer_dtype(self, anchor_box, deltas):
         args = {"anchor_box": anchor_box, "deltas": deltas}
-        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
         return anchor_box

@@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer):

     def infer_dtype(self, bboxes_type, metas_type):
         valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
-        validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name)
-        validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name)
+        validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
+        validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
         return mstype.bool_

@@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer):

     def infer_dtype(self, anchor_boxes, gt_boxes):
         valid_type = [mstype.float32, mstype.float16]
-        validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name)
-        validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name)
+        validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
+        validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
         return anchor_boxes

@@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
         if weights is not None:
             validator.check_subclass('weights', weights, mstype.tensor, self.name)
         args = {"labels": labels, "predictions": predictions}
-        validator.check_tensor_type_same(args, (mstype.number_type), self.name)
+        validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
         return labels

@@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        args = {"x": x_dtype}
-        validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
+        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
         return mstype.tensor_type(mstype.uint8)


 class Push(PrimitiveWithInfer):


+9 -9  mindspore/ops/operations/random_ops.py

@@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer):
         Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
             Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
-        Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
-        Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
+        Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name)
+        Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
         broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
         out = {
@@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer):
         Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
             Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
-        Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
+        Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
         broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
         out = {
             'shape': broadcast_shape,
@@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer):
         Validator.check_value_type("shape", shape_v, [tuple], self.name)
         for i, shape_i in enumerate(shape_v):
             Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
-        Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
-        Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
+        Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name)
+        Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name)
         minval_shape = minval['shape']
         maxval_shape = maxval['shape']
         Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
@@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
         return ([self.count, len(x_shape)], [self.count])

     def infer_dtype(self, x_dtype):
-        Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
+        Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
         return (mstype.int32, mstype.bool_)

@@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer):

     def __infer__(self, logits, num_samples, seed):
         logits_dtype = logits['dtype']
-        valid_types = (mstype.float32, mstype.float16, mstype.float64)
-        Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
+        valid_dtypes = (mstype.float32, mstype.float16, mstype.float64)
+        Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name)
         num_samples_v = num_samples['value']
         seed_v = seed['value']
         Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
@@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer):
         input_shape = inputs["shape"]
         if len(input_shape) != 1 and len(input_shape) != 2:
             raise ValueError("input dim must be 1 or 2")
-        Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
+        Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
         num_samples_value = num_samples["value"]
         if num_samples_value is None:
             raise ValueError(f"For {self.name}, shape nust be const")


+2 -2  mindspore/train/serialization.py

@@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs):
     if quant_mode not in quant_mode_formats:
         raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.')

-    mean = Validator.check_type("mean", mean, (int, float))
-    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
+    mean = Validator.check_value_type("mean", mean, (int, float))
+    std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))

     if context.get_context('device_target') not in supported_device:
         raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))


+1 -1  tests/ut/python/ir/test_row_tensor.py

@@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer):

     def __infer__(self, params, indices, axis):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
-        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
         validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
         axis_v = axis['value']
         params_shp = params['shape']

