You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

test_eager_transforms.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Eager Tests for Transform Tensor ops
  17. """
  18. import numpy as np
  19. import mindspore.common.dtype as mstype
  20. import mindspore.dataset.transforms.c_transforms as data_trans
  def test_eager_concatenate():
      """
      Test Concatenate op is callable in eager mode.

      Concatenate is built with 0 as its first argument plus a prepend and an
      append array; applied to a 1-D input, the expected result is
      prepend + input + append joined into a single 1-D array.
      """
      # np.float was only an alias of the builtin float (deprecated in NumPy
      # 1.20, removed in 1.24); np.float64 is the identical dtype and keeps
      # this test runnable on current NumPy.
      prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float64)
      append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float64)
      concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
      expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                           11., 12.])
      assert np.array_equal(concatenate_op([5., 6., 7., 8.]), expected)
  def test_eager_fill():
      """
      Check that a Fill transform can be invoked directly on a Python list.

      Every element of the four-element input is expected to become 3.
      """
      source = [4, 5, 6, 7]
      filled = data_trans.Fill(3)(source)
      assert np.array_equal(filled, np.array([3, 3, 3, 3]))
  def test_eager_mask():
      """
      Check that a Mask transform can be invoked directly on a Python list.

      The mask compares each element for equality with 3 and emits booleans,
      so only the third position is expected to be True.
      """
      values = [1, 2, 3, 4, 5]
      equals_three = data_trans.Mask(data_trans.Relational.EQ, 3, mstype.bool_)
      want = np.array([False, False, True, False, False])
      got = equals_three(values)
      assert np.array_equal(got, want)
  def test_eager_pad_end():
      """
      Check that a PadEnd transform can be invoked directly on a Python list.

      A two-element input padded to shape [3] with fill value -1 is expected
      to gain a single trailing -1.
      """
      padder = data_trans.PadEnd([3], -1)
      padded = padder([1, 2])
      assert np.array_equal(padded, np.array([1, 2, -1]))
  def test_eager_slice():
      """
      Check that a Slice transform can be invoked directly on a nested list.

      Given the single-row input, index list [0] keeps row 0 and index list
      [0, 3] keeps the elements at positions 0 and 3, yielding [[1, 4]].
      """
      slicer = data_trans.Slice([0], [0, 3])
      sliced = slicer([[1, 2, 3, 4, 5]])
      assert np.array_equal(sliced, np.array([[1, 4]]))
  if __name__ == "__main__":
      # Run every eager-transform test in declaration order when executed as a script.
      for eager_test in (
              test_eager_concatenate,
              test_eager_fill,
              test_eager_mask,
              test_eager_pad_end,
              test_eager_slice,
      ):
          eager_test()