You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_l1_regularizer_op.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ Test L1Regularizer """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. import mindspore.context as context
  20. from mindspore import Tensor, ms_function
  # Put the global MindSpore context into GRAPH_MODE at import time,
  # so every test in this module runs with that setting.
  context.set_context(mode=context.GRAPH_MODE)


  class Net_l1_regularizer(nn.Cell):
      """Minimal cell that forwards its input to an ``nn.L1Regularizer``.

      Args:
          scale: passed unchanged to the ``nn.L1Regularizer`` constructor.
      """

      def __init__(self, scale):
          super().__init__()
          # Wrapped regularizer; construct() delegates to it.
          self.l1_regularizer = nn.L1Regularizer(scale)

      @ms_function
      def construct(self, weights):
          # Pure delegation: the wrapper adds no computation of its own.
          return self.l1_regularizer(weights)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_l1_regularizer01():
      """Evaluate L1Regularizer through the ``Net_l1_regularizer`` wrapper cell.

      Uses scale 0.5 and weights [[1, -2], [-3, 4]] (float32). The expected
      result is 5.0, consistent with 0.5 * (|1| + |-2| + |-3| + |4|).
      """
      # float32 2x2 input containing both positive and negative entries.
      weight_tensor = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]], dtype=np.float32))
      wrapper = Net_l1_regularizer(0.5)
      result = wrapper(weight_tensor)
      # Convert to numpy once; the same array is printed and asserted.
      result_np = result.asnumpy()
      print("After l1_regularizer01 is: ", result_np)
      print("output.shape: ", result.shape)
      print("output.dtype: ", result.dtype)
      assert (result_np == 5.0).all()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_l1_regularizer08():
      """Call ``nn.L1Regularizer`` directly, without a wrapper cell.

      Same input as test_l1_regularizer01 (scale 0.5, weights
      [[1, -2], [-3, 4]] as float32); the expected result is again 5.0.
      """
      regularizer = nn.L1Regularizer(0.5)
      weight_tensor = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]], dtype=np.float32))
      result_np = regularizer(weight_tensor).asnumpy()
      print("output : ", result_np)
      assert (result_np == 5.0).all()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_l1_regularizer_input_int():
      """Check that ``nn.L1Regularizer`` rejects a plain Python int with TypeError.

      The previous ``try: ... except TypeError: assert True`` form passed
      even when the call succeeded, so it could never detect a regression.
      ``pytest.raises`` fails the test if no TypeError is raised.
      """
      scale = 0.5
      net = nn.L1Regularizer(scale)
      weights = 2
      # Keep only the call that is expected to raise inside the context manager.
      with pytest.raises(TypeError):
          net(weights)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_l1_regularizer_input_tuple():
      """Check that ``nn.L1Regularizer`` rejects a Python tuple with TypeError.

      The previous ``try: ... except TypeError: assert True`` form passed
      even when the call succeeded, so it could never detect a regression.
      ``pytest.raises`` fails the test if no TypeError is raised.
      """
      scale = 0.5
      net = nn.L1Regularizer(scale)
      weights = (1, 2, 3, 4)
      # Keep only the call that is expected to raise inside the context manager.
      with pytest.raises(TypeError):
          net(weights)