You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_taylor_differentiation_graph.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test taylor differentiation in graph mode"""
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. import mindspore.context as context
  20. from mindspore.ops import operations as P
  21. from mindspore import Tensor
  22. from mindspore.ops.functional import jet, derivative
  23. context.set_context(mode=context.GRAPH_MODE)
  24. class MultipleInputSingleOutputNet(nn.Cell):
  25. def __init__(self):
  26. super(MultipleInputSingleOutputNet, self).__init__()
  27. self.sin = P.Sin()
  28. self.cos = P.Cos()
  29. self.exp = P.Exp()
  30. def construct(self, x, y):
  31. out1 = self.sin(x)
  32. out2 = self.cos(y)
  33. out3 = out1 * out2 + out1 / out2
  34. out = self.exp(out3)
  35. return out
  36. class SingleInputSingleOutputNet(nn.Cell):
  37. def __init__(self):
  38. super(SingleInputSingleOutputNet, self).__init__()
  39. self.sin = P.Sin()
  40. self.cos = P.Cos()
  41. self.exp = P.Exp()
  42. def construct(self, x):
  43. out1 = self.sin(x)
  44. out2 = self.cos(out1)
  45. out3 = self.exp(out2)
  46. out = out1 + out2 - out3
  47. return out
  48. @pytest.mark.level0
  49. @pytest.mark.platform_arm_ascend_training
  50. @pytest.mark.platform_x86_ascend_training
  51. @pytest.mark.platform_x86_gpu_training
  52. @pytest.mark.platform_x86_cpu
  53. @pytest.mark.env_onecard
  54. def test_jet_single_input_single_output_graph_mode():
  55. """
  56. Features: Function jet
  57. Description: Test jet with single input in graph mode.
  58. Expectation: No exception.
  59. """
  60. primals = Tensor([1., 1.])
  61. series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
  62. net = SingleInputSingleOutputNet()
  63. expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
  64. expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
  65. out_primals, out_series = jet(net, primals, series)
  66. assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
  67. assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
  68. @pytest.mark.level0
  69. @pytest.mark.platform_arm_ascend_training
  70. @pytest.mark.platform_x86_ascend_training
  71. @pytest.mark.platform_x86_gpu_training
  72. @pytest.mark.platform_x86_cpu
  73. @pytest.mark.env_onecard
  74. def test_derivative_single_input_single_output_graph_mode():
  75. """
  76. Features: Function derivative
  77. Description: Test derivative with single input in graph mode.
  78. Expectation: No exception.
  79. """
  80. primals = Tensor([1., 1.])
  81. order = 3
  82. net = SingleInputSingleOutputNet()
  83. expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
  84. expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
  85. out_primals, out_series = derivative(net, primals, order)
  86. assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
  87. assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
  88. @pytest.mark.level0
  89. @pytest.mark.platform_arm_ascend_training
  90. @pytest.mark.platform_x86_ascend_training
  91. @pytest.mark.platform_x86_gpu_training
  92. @pytest.mark.platform_x86_cpu
  93. @pytest.mark.env_onecard
  94. def test_jet_multiple_input_single_output_graph_mode():
  95. """
  96. Features: Function jet
  97. Description: Test jet with multiple inputs in graph mode.
  98. Expectation: No exception.
  99. """
  100. primals = (Tensor([1., 1.]), Tensor([1., 1.]))
  101. series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
  102. net = MultipleInputSingleOutputNet()
  103. expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
  104. expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
  105. out_primals, out_series = jet(net, primals, series)
  106. assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
  107. assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
  108. @pytest.mark.level0
  109. @pytest.mark.platform_arm_ascend_training
  110. @pytest.mark.platform_x86_ascend_training
  111. @pytest.mark.platform_x86_gpu_training
  112. @pytest.mark.platform_x86_cpu
  113. @pytest.mark.env_onecard
  114. def test_derivative_multiple_input_single_output_graph_mode():
  115. """
  116. Features: Function derivative
  117. Description: Test derivative with multiple inputs in graph mode.
  118. Expectation: No exception.
  119. """
  120. primals = (Tensor([1., 1.]), Tensor([1., 1.]))
  121. order = 3
  122. net = MultipleInputSingleOutputNet()
  123. expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
  124. expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
  125. out_primals, out_series = derivative(net, primals, order)
  126. assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
  127. assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)