You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tile_eliminate.py 853 B

123456789101112131415161718192021222324252627
  1. import numpy as np
  2. from mindspore import context
  3. from mindspore import ms_function, ops, Tensor, dtype
  4. @ms_function
  5. def expand_tensor(a, b):
  6. out = ops.tile(a, b)
  7. return out
  8. def test_tile_eliminate():
  9. """
  10. Feature: tile_eliminate
  11. Description: All value of multiplier is '1' but length of multiplier is greater than tensor dims, can't do eliminate
  12. Expectation: success
  13. """
  14. context.set_context(mode=context.PYNATIVE_MODE)
  15. tensor_ = Tensor(np.ndarray([1, 448, 448]), dtype=dtype.float32)
  16. out = ops.tile(tensor_, (1, 1, 1))
  17. assert out.shape == (1, 448, 448)
  18. out = ops.tile(tensor_, (1, 1, 1, 1))
  19. assert out.shape == (1, 1, 448, 448)
  20. out = expand_tensor(tensor_, (1, 1, 1))
  21. assert out.shape == (1, 448, 448)
  22. out = expand_tensor(tensor_, (1, 1, 1, 1))
  23. assert out.shape == (1, 1, 448, 448)