You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rolling_op.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. from functools import partial
  16. from typing import Tuple, List
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.ops import PrimitiveWithInfer, prim_attr_register
  21. from mindspore._checkparam import Validator as validator
  22. from mindspore.common import dtype as mstype
  23. import numpy as np
  24. import pytest
  25. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  26. class Rolling(PrimitiveWithInfer):
  27. """
  28. Shift op frontend implementation
  29. """
  30. @prim_attr_register
  31. def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
  32. method: str):
  33. """Initialize Sort"""
  34. self.window = validator.check_value_type("window", window, [int], self.name)
  35. self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name)
  36. self.center = validator.check_value_type("center", center, [bool], self.name)
  37. self.axis = validator.check_value_type("axis", axis, [int], self.name)
  38. self.closed = validator.check_value_type("closed", closed, [str], self.name)
  39. self.method = validator.check_value_type("method", method, [str], self.name)
  40. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  41. def __infer__(self, x):
  42. out_shapes = x['shape']
  43. return {
  44. 'shape': tuple(out_shapes),
  45. 'dtype': x['dtype'],
  46. 'value': None
  47. }
  48. def infer_dtype(self, x_dtype):
  49. validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64, mstype.int32, mstype.int64],
  50. self.name, True)
  51. return x_dtype
  52. class RollingNet(nn.Cell):
  53. def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
  54. method: str):
  55. super(RollingNet, self).__init__()
  56. self.rolling = Rolling(window, min_periods, center, axis, closed, method)
  57. def construct(self, x):
  58. return self.rolling(x)
  59. def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]:
  60. assert closed in {'left', 'both', 'right', 'neither'}
  61. offset = (window_size - 1) // 2 if center else 0
  62. end = np.arange(offset + 1, num_values + 1 + offset, dtype=np.int64)
  63. start = end - window_size
  64. if closed in {'left', 'both'}:
  65. start -= 1
  66. if closed in {'left', 'neither'}:
  67. end -= 1
  68. end = np.clip(end, 0, num_values)
  69. start = np.clip(start, 0, num_values)
  70. return list(start), list(end)
  71. def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str,
  72. method: str) -> np.ndarray:
  73. assert window > 0
  74. assert 0 < min_periods <= window
  75. assert axis in range(-array.ndim, array.ndim)
  76. reduce_map = {'max': np.max, 'min': np.min, 'mean': np.mean, 'sum': np.sum, 'std': partial(np.std, ddof=1),
  77. 'var': partial(np.var, ddof=1)}
  78. assert method in reduce_map
  79. size = array.shape[axis]
  80. start, end = get_window_bounds(size, window, center, closed)
  81. rolling_indices = [[slice(None)] * array.ndim for _ in range(len(start))]
  82. for i, j, indice in zip(start, end, rolling_indices):
  83. indice[axis] = None if j - i < min_periods else slice(i, j)
  84. # print(f'i={i}, j={j}, index={index}, indice={rolling_indices[index][axis]}')
  85. shape = list(array.shape)
  86. shape[axis] = 1
  87. nan_array = np.empty(shape)
  88. if array.dtype == np.float32 or array.dtype == np.float64:
  89. nan_array[:] = np.nan
  90. elif array.dtype == np.int32 or array.dtype == np.int64:
  91. nan_array[:] = 0
  92. arrays = [
  93. nan_array.copy() if not indice[axis]
  94. else reduce_map[method](array[tuple(indice)], axis=axis, keepdims=True).reshape(shape)
  95. for indice in rolling_indices]
  96. return np.stack(arrays, axis=axis).reshape(array.shape).astype(array.dtype)
  97. @pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)])
  98. @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
  99. @pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)])
  100. @pytest.mark.parametrize('center', [True, False])
  101. @pytest.mark.parametrize('axis', [2, 3, -1])
  102. @pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither'])
  103. @pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var'])
  104. def test_two_way(shape: List[int], dtype, window: int, min_periods: int, center: bool, axis: int, closed: str,
  105. method: str) -> np.ndarray:
  106. if dtype in (np.int32, np.int64):
  107. arr = np.random.randint(0, 100, size=shape)
  108. else:
  109. arr = np.random.random(shape).astype(dtype)
  110. expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
  111. method=method)
  112. rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
  113. method=method)
  114. actual_result = rolling(Tensor(arr)).asnumpy()
  115. print('arr: \n', arr, arr.dtype, arr.shape)
  116. print('np: \n', expect_result, expect_result.dtype, expect_result.shape)
  117. print('mine: \n', actual_result, actual_result.dtype, actual_result.shape)
  118. print(f'center: {center}, axis: {axis}, method: {method}')
  119. assert np.allclose(expect_result, actual_result, equal_nan=True)