You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_squeeze_op.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.context as context
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore.ops import operations as P
  20. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  21. class Net(nn.Cell):
  22. def __init__(self):
  23. super(Net, self).__init__()
  24. self.squeeze = P.Squeeze()
  25. def construct(self, tensor):
  26. return self.squeeze(tensor)
  27. def test_net_bool():
  28. x = np.random.randn(1, 16, 1, 1).astype(np.bool)
  29. net = Net()
  30. output = net(Tensor(x))
  31. print(output.asnumpy())
  32. assert np.all(output.asnumpy() == x.squeeze())
  33. def test_net_uint8():
  34. x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
  35. net = Net()
  36. output = net(Tensor(x))
  37. print(output.asnumpy())
  38. assert np.all(output.asnumpy() == x.squeeze())
  39. def test_net_int16():
  40. x = np.random.randn(1, 16, 1, 1).astype(np.int16)
  41. net = Net()
  42. output = net(Tensor(x))
  43. print(output.asnumpy())
  44. assert np.all(output.asnumpy() == x.squeeze())
  45. def test_net_int32():
  46. x = np.random.randn(1, 16, 1, 1).astype(np.int32)
  47. net = Net()
  48. output = net(Tensor(x))
  49. print(output.asnumpy())
  50. assert np.all(output.asnumpy() == x.squeeze())
  51. def test_net_float16():
  52. x = np.random.randn(1, 16, 1, 1).astype(np.float16)
  53. net = Net()
  54. output = net(Tensor(x))
  55. print(output.asnumpy())
  56. assert np.all(output.asnumpy() == x.squeeze())
  57. def test_net_float32():
  58. x = np.random.randn(1, 16, 1, 1).astype(np.float32)
  59. net = Net()
  60. output = net(Tensor(x))
  61. print(output.asnumpy())
  62. assert np.all(output.asnumpy() == x.squeeze())