You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_autocast.py 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """multitype_ops directory test case"""
  16. import numpy as np
  17. from functools import partial, reduce
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore import dtype as mstype
  21. from mindspore.ops import functional as F, composite as C
  22. import mindspore.context as context
  23. import pytest
  24. class TensorIntAutoCast(nn.Cell):
  25. def __init__(self,):
  26. super(TensorIntAutoCast, self).__init__()
  27. self.i = 2
  28. def construct(self, t):
  29. z = F.tensor_mul(t, self.i)
  30. return z
  31. class TensorFPAutoCast(nn.Cell):
  32. def __init__(self,):
  33. super(TensorFPAutoCast, self).__init__()
  34. self.f = 1.2
  35. def construct(self, t):
  36. z = F.tensor_mul(t, self.f)
  37. return z
  38. class TensorBoolAutoCast(nn.Cell):
  39. def __init__(self,):
  40. super(TensorBoolAutoCast, self).__init__()
  41. self.f = True
  42. def construct(self, t):
  43. z = F.tensor_mul(t, self.f)
  44. return z
  45. class TensorAutoCast(nn.Cell):
  46. def __init__(self,):
  47. super(TensorAutoCast, self).__init__()
  48. def construct(self, t1, t2):
  49. z = F.tensor_mul(t1, t2)
  50. return z
  51. def test_tensor_auto_cast():
  52. context.set_context(mode=context.GRAPH_MODE)
  53. t0 = Tensor([True, False], mstype.bool_)
  54. t_uint8 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint8)
  55. t_int8 = Tensor(np.ones([2, 1, 2, 2]), mstype.int8)
  56. t_int16 = Tensor(np.ones([2, 1, 2, 2]), mstype.int16)
  57. t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
  58. t_int64 = Tensor(np.ones([2, 1, 2, 2]), mstype.int64)
  59. t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
  60. t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
  61. t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64)
  62. net = TensorAutoCast()
  63. rs = net(t_uint8, t_int8)
  64. assert rs.dtype() == mstype.int16
  65. rs = net(t_uint8, t_int16)
  66. assert rs.dtype() == mstype.int16
  67. rs = net(t_uint8, t_int32)
  68. assert rs.dtype() == mstype.int32
  69. rs = net(t_uint8, t_int64)
  70. assert rs.dtype() == mstype.int64
  71. rs = net(t_int8, t_int16)
  72. assert rs.dtype() == mstype.int16
  73. rs = net(t_int8, t_int32)
  74. assert rs.dtype() == mstype.int32
  75. rs = net(t_int8, t_int64)
  76. assert rs.dtype() == mstype.int64
  77. rs = net(t_int16, t_int32)
  78. assert rs.dtype() == mstype.int32
  79. rs = net(t_int16, t_int64)
  80. assert rs.dtype() == mstype.int64
  81. rs = net(t_int32, t_int64)
  82. assert rs.dtype() == mstype.int64
  83. rs = net(t_fp16, t_fp32)
  84. assert rs.dtype() == mstype.float32
  85. rs = net(t_fp16, t_fp64)
  86. assert rs.dtype() == mstype.float64
  87. rs = net(t_fp32, t_fp64)
  88. assert rs.dtype() == mstype.float64
  89. rs = net(t_uint8, t_fp16)
  90. assert rs.dtype() == mstype.float16
  91. rs = net(t_uint8, t_fp32)
  92. assert rs.dtype() == mstype.float32
  93. rs = net(t_uint8, t_fp64)
  94. assert rs.dtype() == mstype.float64
  95. rs = net(t_int8, t_fp64)
  96. assert rs.dtype() == mstype.float64
  97. rs = net(t_int16, t_fp64)
  98. assert rs.dtype() == mstype.float64
  99. rs = net(t_int32, t_fp64)
  100. assert rs.dtype() == mstype.float64
  101. rs = net(t_int64, t_fp64)
  102. assert rs.dtype() == mstype.float64
  103. rs = net(t_fp16, t_int8)
  104. assert rs.dtype() == mstype.float16
  105. rs = net(t_fp16, t_uint8)
  106. assert rs.dtype() == mstype.float16
  107. rs = net(t_fp16, t_int16)
  108. assert rs.dtype() == mstype.float16
  109. rs = net(t_fp16, t_int32)
  110. assert rs.dtype() == mstype.float16
  111. rs = net(t_fp16, t_int64)
  112. assert rs.dtype() == mstype.float16
  113. tint = TensorIntAutoCast()
  114. rs = tint(t_uint8)
  115. assert rs.dtype() == mstype.uint8
  116. rs = tint(t_int8)
  117. assert rs.dtype() == mstype.int8
  118. rs = tint(t_int16)
  119. assert rs.dtype() == mstype.int16
  120. rs = tint(t_int32)
  121. assert rs.dtype() == mstype.int32
  122. rs = tint(t_int64)
  123. assert rs.dtype() == mstype.int64
  124. rs = tint(t_fp16)
  125. assert rs.dtype() == mstype.float16
  126. rs = tint(t_fp32)
  127. assert rs.dtype() == mstype.float32
  128. rs = tint(t_fp64)
  129. assert rs.dtype() == mstype.float64
  130. tfp = TensorFPAutoCast()
  131. rs = tfp(t_uint8)
  132. assert rs.dtype() == mstype.float32
  133. rs = tfp(t_int8)
  134. assert rs.dtype() == mstype.float32
  135. rs = tfp(t_int16)
  136. assert rs.dtype() == mstype.float32
  137. rs = tfp(t_int32)
  138. assert rs.dtype() == mstype.float32
  139. rs = tfp(t_int64)
  140. assert rs.dtype() == mstype.float32
  141. rs = tfp(t_fp16)
  142. assert rs.dtype() == mstype.float32
  143. rs = tfp(t_fp32)
  144. assert rs.dtype() == mstype.float32
  145. rs = tfp(t_fp64)
  146. assert rs.dtype() == mstype.float64
  147. t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16)
  148. t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32)
  149. t_uint64 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint64)
  150. with pytest.raises(TypeError):
  151. net(t_uint16, t_uint8)
  152. with pytest.raises(TypeError):
  153. net(t_uint16, t_int8)
  154. with pytest.raises(TypeError):
  155. net(t_uint16, t_int16)
  156. with pytest.raises(TypeError):
  157. net(t_uint16, t_int32)
  158. with pytest.raises(TypeError):
  159. net(t_uint16, t_int64)
  160. with pytest.raises(TypeError):
  161. net(t_uint32, t_uint8)
  162. with pytest.raises(TypeError):
  163. net(t_uint32, t_int8)
  164. with pytest.raises(TypeError):
  165. net(t_uint32, t_int16)
  166. with pytest.raises(TypeError):
  167. net(t_uint32, t_int32)
  168. with pytest.raises(TypeError):
  169. net(t_uint32, t_int64)
  170. with pytest.raises(TypeError):
  171. net(t_uint64, t_uint8)
  172. with pytest.raises(TypeError):
  173. net(t_uint64, t_int8)
  174. with pytest.raises(TypeError):
  175. net(t_uint64, t_int16)
  176. with pytest.raises(TypeError):
  177. net(t_uint64, t_int32)
  178. with pytest.raises(TypeError):
  179. net(t_uint64, t_int64)
  180. with pytest.raises(TypeError):
  181. net(t_uint16, t_fp16)
  182. with pytest.raises(TypeError):
  183. net(t_uint16, t_fp32)
  184. with pytest.raises(TypeError):
  185. net(t_uint16, t_fp64)
  186. with pytest.raises(TypeError):
  187. net(t_uint32, t_fp16)
  188. with pytest.raises(TypeError):
  189. net(t_uint32, t_fp32)
  190. with pytest.raises(TypeError):
  191. net(t_uint32, t_fp64)
  192. with pytest.raises(TypeError):
  193. net(t_uint64, t_fp16)
  194. with pytest.raises(TypeError):
  195. net(t_uint64, t_fp32)
  196. with pytest.raises(TypeError):
  197. net(t_uint64, t_fp64)
  198. with pytest.raises(TypeError):
  199. tfp(t_uint16)
  200. with pytest.raises(TypeError):
  201. tfp(t_uint32)
  202. with pytest.raises(TypeError):
  203. tfp(t_uint64)
  204. with pytest.raises(TypeError):
  205. tint(t_uint16)
  206. with pytest.raises(TypeError):
  207. tint(t_uint32)
  208. with pytest.raises(TypeError):
  209. tint(t_uint64)
  210. bnet = TensorBoolAutoCast()
  211. with pytest.raises(TypeError):
  212. bnet(t_uint8)
  213. with pytest.raises(TypeError):
  214. bnet(t_int8)
  215. with pytest.raises(TypeError):
  216. bnet(t_int16)
  217. with pytest.raises(TypeError):
  218. bnet(t_int32)
  219. with pytest.raises(TypeError):
  220. bnet(t_int64)
  221. with pytest.raises(TypeError):
  222. bnet(t_fp16)
  223. with pytest.raises(TypeError):
  224. bnet(t_fp32)
  225. with pytest.raises(TypeError):
  226. bnet(t_fp64)