You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_autocast.py 8.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """multitype_ops directory test case"""
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore import dtype as mstype
  21. from mindspore.ops import functional as F
  22. import mindspore.context as context
  23. class TensorIntAutoCast(nn.Cell):
  24. def __init__(self,):
  25. super(TensorIntAutoCast, self).__init__()
  26. self.i = 2
  27. def construct(self, t):
  28. z = F.tensor_mul(t, self.i)
  29. return z
  30. class TensorFPAutoCast(nn.Cell):
  31. def __init__(self,):
  32. super(TensorFPAutoCast, self).__init__()
  33. self.f = 1.2
  34. def construct(self, t):
  35. z = F.tensor_mul(t, self.f)
  36. return z
  37. class TensorBoolAutoCast(nn.Cell):
  38. def __init__(self,):
  39. super(TensorBoolAutoCast, self).__init__()
  40. self.f = True
  41. def construct(self, t):
  42. z = F.tensor_mul(t, self.f)
  43. return z
  44. class TensorAutoCast(nn.Cell):
  45. def __init__(self,):
  46. super(TensorAutoCast, self).__init__()
  47. def construct(self, t1, t2):
  48. z = F.tensor_mul(t1, t2)
  49. return z
  50. def test_tensor_auto_cast():
  51. context.set_context(mode=context.GRAPH_MODE)
  52. Tensor([True, False], mstype.bool_)
  53. t_uint8 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint8)
  54. t_int8 = Tensor(np.ones([2, 1, 2, 2]), mstype.int8)
  55. t_int16 = Tensor(np.ones([2, 1, 2, 2]), mstype.int16)
  56. t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
  57. t_int64 = Tensor(np.ones([2, 1, 2, 2]), mstype.int64)
  58. t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
  59. t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
  60. t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64)
  61. net = TensorAutoCast()
  62. rs = net(t_uint8, t_int8)
  63. assert rs.dtype == mstype.int16
  64. rs = net(t_uint8, t_int16)
  65. assert rs.dtype == mstype.int16
  66. rs = net(t_uint8, t_int32)
  67. assert rs.dtype == mstype.int32
  68. rs = net(t_uint8, t_int64)
  69. assert rs.dtype == mstype.int64
  70. rs = net(t_int8, t_int16)
  71. assert rs.dtype == mstype.int16
  72. rs = net(t_int8, t_int32)
  73. assert rs.dtype == mstype.int32
  74. rs = net(t_int8, t_int64)
  75. assert rs.dtype == mstype.int64
  76. rs = net(t_int16, t_int32)
  77. assert rs.dtype == mstype.int32
  78. rs = net(t_int16, t_int64)
  79. assert rs.dtype == mstype.int64
  80. rs = net(t_int32, t_int64)
  81. assert rs.dtype == mstype.int64
  82. rs = net(t_fp16, t_fp32)
  83. assert rs.dtype == mstype.float32
  84. rs = net(t_fp16, t_fp64)
  85. assert rs.dtype == mstype.float64
  86. rs = net(t_fp32, t_fp64)
  87. assert rs.dtype == mstype.float64
  88. rs = net(t_uint8, t_fp16)
  89. assert rs.dtype == mstype.float16
  90. rs = net(t_uint8, t_fp32)
  91. assert rs.dtype == mstype.float32
  92. rs = net(t_uint8, t_fp64)
  93. assert rs.dtype == mstype.float64
  94. rs = net(t_int8, t_fp64)
  95. assert rs.dtype == mstype.float64
  96. rs = net(t_int16, t_fp64)
  97. assert rs.dtype == mstype.float64
  98. rs = net(t_int32, t_fp64)
  99. assert rs.dtype == mstype.float64
  100. rs = net(t_int64, t_fp64)
  101. assert rs.dtype == mstype.float64
  102. rs = net(t_fp16, t_int8)
  103. assert rs.dtype == mstype.float16
  104. rs = net(t_fp16, t_uint8)
  105. assert rs.dtype == mstype.float16
  106. rs = net(t_fp16, t_int16)
  107. assert rs.dtype == mstype.float16
  108. rs = net(t_fp16, t_int32)
  109. assert rs.dtype == mstype.float16
  110. rs = net(t_fp16, t_int64)
  111. assert rs.dtype == mstype.float16
  112. tint = TensorIntAutoCast()
  113. rs = tint(t_uint8)
  114. assert rs.dtype == mstype.uint8
  115. rs = tint(t_int8)
  116. assert rs.dtype == mstype.int8
  117. rs = tint(t_int16)
  118. assert rs.dtype == mstype.int16
  119. rs = tint(t_int32)
  120. assert rs.dtype == mstype.int32
  121. rs = tint(t_int64)
  122. assert rs.dtype == mstype.int64
  123. rs = tint(t_fp16)
  124. assert rs.dtype == mstype.float16
  125. rs = tint(t_fp32)
  126. assert rs.dtype == mstype.float32
  127. rs = tint(t_fp64)
  128. assert rs.dtype == mstype.float64
  129. tfp = TensorFPAutoCast()
  130. rs = tfp(t_uint8)
  131. assert rs.dtype == mstype.float32
  132. rs = tfp(t_int8)
  133. assert rs.dtype == mstype.float32
  134. rs = tfp(t_int16)
  135. assert rs.dtype == mstype.float32
  136. rs = tfp(t_int32)
  137. assert rs.dtype == mstype.float32
  138. rs = tfp(t_int64)
  139. assert rs.dtype == mstype.float32
  140. rs = tfp(t_fp16)
  141. assert rs.dtype == mstype.float32
  142. rs = tfp(t_fp32)
  143. assert rs.dtype == mstype.float32
  144. rs = tfp(t_fp64)
  145. assert rs.dtype == mstype.float64
  146. t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16)
  147. t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32)
  148. t_uint64 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint64)
  149. with pytest.raises(TypeError):
  150. net(t_uint16, t_uint8)
  151. with pytest.raises(TypeError):
  152. net(t_uint16, t_int8)
  153. with pytest.raises(TypeError):
  154. net(t_uint16, t_int16)
  155. with pytest.raises(TypeError):
  156. net(t_uint16, t_int32)
  157. with pytest.raises(TypeError):
  158. net(t_uint16, t_int64)
  159. with pytest.raises(TypeError):
  160. net(t_uint32, t_uint8)
  161. with pytest.raises(TypeError):
  162. net(t_uint32, t_int8)
  163. with pytest.raises(TypeError):
  164. net(t_uint32, t_int16)
  165. with pytest.raises(TypeError):
  166. net(t_uint32, t_int32)
  167. with pytest.raises(TypeError):
  168. net(t_uint32, t_int64)
  169. with pytest.raises(TypeError):
  170. net(t_uint64, t_uint8)
  171. with pytest.raises(TypeError):
  172. net(t_uint64, t_int8)
  173. with pytest.raises(TypeError):
  174. net(t_uint64, t_int16)
  175. with pytest.raises(TypeError):
  176. net(t_uint64, t_int32)
  177. with pytest.raises(TypeError):
  178. net(t_uint64, t_int64)
  179. with pytest.raises(TypeError):
  180. net(t_uint16, t_fp16)
  181. with pytest.raises(TypeError):
  182. net(t_uint16, t_fp32)
  183. with pytest.raises(TypeError):
  184. net(t_uint16, t_fp64)
  185. with pytest.raises(TypeError):
  186. net(t_uint32, t_fp16)
  187. with pytest.raises(TypeError):
  188. net(t_uint32, t_fp32)
  189. with pytest.raises(TypeError):
  190. net(t_uint32, t_fp64)
  191. with pytest.raises(TypeError):
  192. net(t_uint64, t_fp16)
  193. with pytest.raises(TypeError):
  194. net(t_uint64, t_fp32)
  195. with pytest.raises(TypeError):
  196. net(t_uint64, t_fp64)
  197. with pytest.raises(TypeError):
  198. tfp(t_uint16)
  199. with pytest.raises(TypeError):
  200. tfp(t_uint32)
  201. with pytest.raises(TypeError):
  202. tfp(t_uint64)
  203. with pytest.raises(TypeError):
  204. tint(t_uint16)
  205. with pytest.raises(TypeError):
  206. tint(t_uint32)
  207. with pytest.raises(TypeError):
  208. tint(t_uint64)
  209. bnet = TensorBoolAutoCast()
  210. with pytest.raises(TypeError):
  211. bnet(t_uint8)
  212. with pytest.raises(TypeError):
  213. bnet(t_int8)
  214. with pytest.raises(TypeError):
  215. bnet(t_int16)
  216. with pytest.raises(TypeError):
  217. bnet(t_int32)
  218. with pytest.raises(TypeError):
  219. bnet(t_int64)
  220. with pytest.raises(TypeError):
  221. bnet(t_fp16)
  222. with pytest.raises(TypeError):
  223. bnet(t_fp32)
  224. with pytest.raises(TypeError):
  225. bnet(t_fp64)
  226. def test_bool_tensor_and_float():
  227. context.set_context(mode=context.GRAPH_MODE)
  228. t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool), mstype.bool_)
  229. t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
  230. t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
  231. t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
  232. net = TensorFPAutoCast()
  233. out = net(t_bool)
  234. assert out.dtype == mstype.float32
  235. net = TensorIntAutoCast()
  236. out = net(t_bool)
  237. assert out.dtype == mstype.int32
  238. out = net(t_fp16)
  239. assert out.dtype == mstype.float16
  240. out = net(t_fp32)
  241. assert out.dtype == mstype.float32
  242. out = net(t_int32)
  243. assert out.dtype == mstype.int32