You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_array_ops_check.py 6.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test ops """
  16. import functools
  17. import numpy as np
  18. import mindspore.nn as nn
  19. import mindspore.ops.composite as C
  20. from mindspore import Tensor
  21. from mindspore import ops
  22. from mindspore.common import dtype as mstype
  23. from mindspore.common.api import _executor
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.ops import functional as F
  26. from mindspore.ops import operations as P
  27. from mindspore.ops.operations import _grad_ops as G
  28. from ..ut_filter import non_graph_engine
  29. from ....mindspore_test_framework.mindspore_test import mindspore_test
  30. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  31. import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
  32. pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
  33. from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
  34. import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
  35. class ExpandDimsNet(nn.Cell):
  36. def __init__(self, axis):
  37. super(ExpandDimsNet, self).__init__()
  38. self.axis = axis
  39. self.op = P.ExpandDims()
  40. def construct(self, x):
  41. return self.op(x, self.axis)
  42. class IsInstanceNet(nn.Cell):
  43. def __init__(self, inst):
  44. super(IsInstanceNet, self).__init__()
  45. self.inst = inst
  46. self.op = P.IsInstance()
  47. def construct(self, t):
  48. return self.op(self.inst, t)
  49. class ReshapeNet(nn.Cell):
  50. def __init__(self, shape):
  51. super(ReshapeNet, self).__init__()
  52. self.shape = shape
  53. self.op = P.Reshape()
  54. def construct(self, x):
  55. return self.op(x, self.shape)
  56. raise_set = [
  57. # input is scala, not Tensor
  58. ('ExpandDims0', {
  59. 'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
  60. 'desc_inputs': [5.0, 1],
  61. 'skip': ['backward']}),
  62. # axis is as a parameter
  63. ('ExpandDims1', {
  64. 'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
  65. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
  66. 'skip': ['backward']}),
  67. # axis as an attribute, but less then lower limit
  68. ('ExpandDims2', {
  69. 'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
  70. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
  71. 'skip': ['backward']}),
  72. # axis as an attribute, but greater then upper limit
  73. ('ExpandDims3', {
  74. 'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
  75. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
  76. 'skip': ['backward']}),
  77. # input is scala, not Tensor
  78. ('DType0', {
  79. 'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
  80. 'desc_inputs': [5.0],
  81. 'skip': ['backward']}),
  82. # input x scala, not Tensor
  83. ('SameTypeShape0', {
  84. 'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
  85. 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
  86. 'skip': ['backward']}),
  87. # input y scala, not Tensor
  88. ('SameTypeShape1', {
  89. 'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
  90. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
  91. 'skip': ['backward']}),
  92. # type of x and y not match
  93. ('SameTypeShape2', {
  94. 'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
  95. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
  96. 'skip': ['backward']}),
  97. # shape of x and y not match
  98. ('SameTypeShape3', {
  99. 'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
  100. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
  101. 'skip': ['backward']}),
  102. # sub_type is None
  103. ('IsSubClass0', {
  104. 'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
  105. 'desc_inputs': [None, mstype.number],
  106. 'skip': ['backward']}),
  107. # type_ is None
  108. ('IsSubClass1', {
  109. 'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
  110. 'desc_inputs': [mstype.number, None],
  111. 'skip': ['backward']}),
  112. # inst is var
  113. ('IsInstance0', {
  114. 'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
  115. 'desc_inputs': [5.0, mstype.number],
  116. 'skip': ['backward']}),
  117. # t is not mstype.Type
  118. ('IsInstance1', {
  119. 'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
  120. 'desc_inputs': [None],
  121. 'skip': ['backward']}),
  122. # input x is scalar, not Tensor
  123. ('Reshape0', {
  124. 'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
  125. 'desc_inputs': [5.0, (1, 2)],
  126. 'skip': ['backward']}),
  127. # input shape is var
  128. ('Reshape1', {
  129. 'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
  130. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
  131. 'skip': ['backward']}),
  132. # element of shape is not int
  133. ('Reshape3', {
  134. 'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
  135. 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
  136. 'skip': ['backward']}),
  137. ]
  138. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
  139. def test_check_exception():
  140. return raise_set