|
- import numpy as np
- from mindspore import context
- from mindspore import ms_function, ops, Tensor, dtype
-
-
@ms_function
def expand_tensor(a, b):
    """Return ``ops.tile(a, b)`` evaluated through the ``@ms_function`` path.

    The test below uses this wrapper so the same tile call is checked both
    eagerly (plain ``ops.tile``) and via ``ms_function``.

    Args:
        a: Tensor to tile.
        b: Tuple of per-dimension multiples passed straight to ``ops.tile``.

    Returns:
        Whatever ``ops.tile(a, b)`` returns.
    """
    return ops.tile(a, b)
-
-
def test_tile_eliminate():
    """
    Feature: tile_eliminate
    Description: All values of the multiplier are '1' but the multiplier is longer than
        the tensor rank, so the tile can't be eliminated: the output must gain leading
        unit dimensions while keeping the element values unchanged.
    Expectation: success
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    # Deterministic contents: np.ndarray(shape) would allocate *uninitialized* memory,
    # making the input (and any value check) arbitrary from run to run.
    data = np.arange(448 * 448, dtype=np.float32).reshape(1, 448, 448)
    tensor_ = Tensor(data, dtype=dtype.float32)
    # (multiples, expected output shape). Multiples longer than the input rank
    # prepend unit dimensions instead of being a no-op.
    cases = (
        ((1, 1, 1), (1, 448, 448)),
        ((1, 1, 1, 1), (1, 1, 448, 448)),
    )
    # Same call order as before: eager ops.tile first, then the @ms_function wrapper.
    for tile_fn in (ops.tile, expand_tensor):
        for multiples, expected_shape in cases:
            out = tile_fn(tensor_, multiples)
            assert out.shape == expected_shape
            # An all-ones multiplier must not alter the element values.
            assert np.array_equal(out.asnumpy().reshape(data.shape), data)
|