|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- from functools import partial
- from typing import Tuple, List
- import mindspore.context as context
- import mindspore.nn as nn
- from mindspore import Tensor
- from mindspore.ops import PrimitiveWithInfer, prim_attr_register
- from mindspore._checkparam import Validator as validator
- from mindspore.common import dtype as mstype
- import numpy as np
- import pytest
-
# Run every test in this module through graph compilation on the CPU backend.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="CPU",
)
-
-
class Rolling(PrimitiveWithInfer):
    """
    Rolling op frontend implementation.

    Declares a primitive with a single input ``x`` and a single output ``output`` whose shape and dtype
    are those of ``x``. The window parameters are type-checked and registered as primitive attributes;
    the numeric computation itself is not defined in this file (presumably a backend kernel — confirm).
    """

    @prim_attr_register
    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        """Initialize Rolling, type-checking every attribute.

        Args:
            window: Window length.
            min_periods: Minimum number of positions a window must cover to produce a value.
            center: Whether windows are centred on each position.
            axis: Axis to roll along.
            closed: Which window endpoints are included (e.g. 'left', 'both', 'right', 'neither').
            method: Name of the reduction applied to each window (e.g. 'max', 'mean', 'std').
        """
        self.window = validator.check_value_type("window", window, [int], self.name)
        self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name)
        self.center = validator.check_value_type("center", center, [bool], self.name)
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.closed = validator.check_value_type("closed", closed, [str], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)

        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        """Return shape/dtype/value for the output: same shape and dtype as ``x``, no constant value."""
        # Because this class overrides __infer__, PrimitiveWithInfer's default dispatch to
        # infer_dtype does not run; call it explicitly so the dtype whitelist is enforced.
        out_dtype = self.infer_dtype(x['dtype'])
        return {
            'shape': tuple(x['shape']),
            'dtype': out_dtype,
            'value': None
        }

    def infer_dtype(self, x_dtype):
        """Check that ``x`` is a float32/float64/int32/int64 tensor and return its dtype unchanged."""
        valid_dtypes = [mstype.float32, mstype.float64, mstype.int32, mstype.int64]
        # Signature is (arg_name, arg_type, valid_dtypes, prim_name); the name "x" matches the
        # input registered in init_prim_io_names.
        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
        return x_dtype
-
-
class RollingNet(nn.Cell):
    """Cell that wraps one ``Rolling`` primitive so it can be compiled and run as a network."""

    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        super().__init__()
        # Forward the configuration positionally, in the order Rolling.__init__ declares it.
        self.rolling = Rolling(window, min_periods, center, axis, closed, method)

    def construct(self, x):
        # Apply the wrapped rolling primitive to the input tensor.
        result = self.rolling(x)
        return result
-
-
def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]:
    """Compute half-open ``[start, end)`` window bounds for each of ``num_values`` positions.

    Args:
        num_values: Length of the axis being rolled over.
        window_size: Nominal window length.
        center: If true, shift every window right by ``(window_size - 1) // 2``.
        closed: 'right' (default), 'left', 'both' or 'neither'. 'left'/'both' widen each window by one
            on the start side; 'left'/'neither' drop one from the end side.

    Returns:
        ``(start, end)`` as two lists of ``numpy.int64`` values, each clipped to ``[0, num_values]``.

    Raises:
        AssertionError: If ``closed`` is not one of the four accepted values.
    """
    assert closed in {'left', 'both', 'right', 'neither'}
    shift = (window_size - 1) // 2 if center else 0
    extend_start = closed in ('left', 'both')
    trim_end = closed in ('left', 'neither')

    upper = np.arange(1 + shift, num_values + 1 + shift, dtype=np.int64)
    # The start side is derived from the un-trimmed end, then optionally extended by one.
    lower = upper - window_size - int(extend_start)
    upper = upper - int(trim_end)

    clipped_lower = np.clip(lower, 0, num_values)
    clipped_upper = np.clip(upper, 0, num_values)
    return list(clipped_lower), list(clipped_upper)
-
-
def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str,
                  method: str) -> np.ndarray:
    """Reference rolling-window reduction written with plain NumPy.

    For every position along ``axis``, the window ``[start, end)`` comes from ``get_window_bounds``. The
    window is reduced with ``method`` (keeping the axis), and the per-position results are concatenated
    back along ``axis``.

    Args:
        array: Input data.
        window: Window length; must be positive.
        min_periods: Minimum number of positions a clipped window must cover to be reduced; shorter
            windows yield the fill value. This counts positions, not non-NaN values.
        center: Whether windows are centred on each position.
        axis: Axis to roll along; negative values count from the end.
        closed: Which window endpoints are included: 'left', 'both', 'right' or 'neither'.
        method: One of 'max', 'min', 'mean', 'sum', 'std', 'var' (std/var use ddof=1).

    Returns:
        Array with ``array.shape`` cast to ``array.dtype``. Under-filled windows are NaN for inexact
        (float/complex) dtypes and 0 otherwise. For integer inputs, fractional results (mean/std/var)
        are truncated by the final cast.

    Raises:
        AssertionError: If an argument is out of range or ``method`` is unknown.
    """
    assert window > 0
    assert 0 < min_periods <= window
    assert axis in range(-array.ndim, array.ndim)
    reduce_map = {'max': np.max, 'min': np.min, 'mean': np.mean, 'sum': np.sum, 'std': partial(np.std, ddof=1),
                  'var': partial(np.var, ddof=1)}
    assert method in reduce_map
    reduce_fn = reduce_map[method]

    size = array.shape[axis]
    if size == 0:
        # No positions means no windows; np.concatenate would reject an empty list of pieces.
        return array.copy()
    start, end = get_window_bounds(size, window, center, closed)

    # Each per-position piece keeps the rolled axis with length 1.
    piece_shape = list(array.shape)
    piece_shape[axis] = 1
    # Build the fill piece in the input dtype: it is always initialised (np.empty was not for other
    # dtypes), and integer results are not routed through float64 (which loses precision beyond 2**53).
    fill_value = np.nan if np.issubdtype(array.dtype, np.inexact) else 0
    missing = np.full(piece_shape, fill_value, dtype=array.dtype)

    pieces = []
    for lo, hi in zip(start, end):
        if hi - lo < min_periods:
            pieces.append(missing)
            continue
        index = [slice(None)] * array.ndim
        index[axis] = slice(lo, hi)
        pieces.append(reduce_fn(array[tuple(index)], axis=axis, keepdims=True))

    return np.concatenate(pieces, axis=axis).astype(array.dtype)
-
-
@pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)])
@pytest.mark.parametrize('center', [True, False])
@pytest.mark.parametrize('axis', [2, 3, -1])
@pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither'])
@pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var'])
def test_two_way(shape: Tuple[int, ...], dtype, window: int, min_periods: int, center: bool, axis: int,
                 closed: str, method: str) -> None:
    """Check the MindSpore Rolling op against the NumPy reference ``numpy_rolling``.

    Random input of the requested ``shape`` and ``dtype`` is fed to both implementations. The results
    must agree within ``rtol=1e-5, atol=1e-8``, with NaNs in matching positions.
    """
    # A fixed seed makes any failing parametrization reproducible.
    rng = np.random.default_rng(0)
    if np.issubdtype(dtype, np.integer):
        # Generate in the requested dtype; otherwise int32 and int64 both get the platform default int.
        arr = rng.integers(0, 100, size=shape, dtype=dtype)
    else:
        arr = rng.random(shape).astype(dtype)

    expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis,
                                  closed=closed, method=method)
    rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                         method=method)
    actual_result = rolling(Tensor(arr)).asnumpy()

    np.testing.assert_allclose(
        actual_result, expect_result, rtol=1e-5, atol=1e-8, equal_nan=True,
        err_msg=(f'dtype={arr.dtype}, shape={arr.shape}, window={window}, min_periods={min_periods}, '
                 f'center={center}, axis={axis}, closed={closed}, method={method}'))
|