You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """st for scipy.utils"""
  16. import pytest
  17. import numpy as onp
  18. from mindspore import context, Tensor
  19. from mindspore.scipy.utils import _safe_normalize
  20. @pytest.mark.level0
  21. @pytest.mark.platform_x86_gpu_training
  22. @pytest.mark.platform_x86_cpu
  23. @pytest.mark.env_onecard
  24. @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
  25. @pytest.mark.parametrize('shape', [(10,), (10, 1)])
  26. @pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
  27. def test_safe_normalize(mode, shape, dtype):
  28. """
  29. Feature: ALL TO ALL
  30. Description: test cases for _safe_normalize
  31. Expectation: the result match scipy
  32. """
  33. context.set_context(mode=mode)
  34. x = onp.random.random(shape).astype(dtype)
  35. normalized_x, x_norm = _safe_normalize(Tensor(x))
  36. normalized_x = normalized_x.asnumpy()
  37. x_norm = x_norm.asnumpy()
  38. assert onp.allclose(onp.sum(normalized_x ** 2), 1)
  39. assert onp.allclose(x / x_norm, normalized_x)