You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_function_vjp_pynative.py 5.0 kB

(line numbers 1–131)
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test vjp in pynative mode"""
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. import mindspore.context as context
  20. from mindspore import Tensor
  21. from mindspore.ops.functional import vjp
# Select PyNative (eager) execution for the tests in this module. This is a
# module-level call, so it takes effect as soon as this file is imported.
context.set_context(mode=context.PYNATIVE_MODE)
  23. class SingleInputNet(nn.Cell):
  24. def construct(self, x):
  25. return x**3
  26. class MultipleInputsOutputNet(nn.Cell):
  27. def construct(self, x, y):
  28. return 2*x, y**3
  29. @pytest.mark.level0
  30. @pytest.mark.platform_x86_cpu
  31. @pytest.mark.env_onecard
  32. def test_vjp_single_input_pynative():
  33. """
  34. Features: Function vjp
  35. Description: Test vjp with single input, single output and default v in pynative mode.
  36. Expectation: No exception.
  37. """
  38. x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  39. v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  40. net = SingleInputNet()
  41. expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
  42. expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
  43. primal, grad = vjp(net, x, v)
  44. assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
  45. assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
  46. @pytest.mark.level0
  47. @pytest.mark.platform_x86_cpu
  48. @pytest.mark.env_onecard
  49. def test_vjp_multiple_inputs_default_v_pynative():
  50. """
  51. Features: Function vjp
  52. Description: Test vjp with multiple inputs, multiple outputs and default v in pynative mode.
  53. Expectation: No exception.
  54. """
  55. x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  56. y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  57. v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  58. net = MultipleInputsOutputNet()
  59. expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
  60. expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
  61. expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
  62. expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
  63. primal, grad = vjp(net, (x, y), (v, v))
  64. assert isinstance(grad, tuple)
  65. assert len(grad) == 2
  66. assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
  67. assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
  68. assert isinstance(primal, tuple)
  69. assert len(primal) == 2
  70. assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
  71. assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
  72. @pytest.mark.level0
  73. @pytest.mark.platform_x86_cpu
  74. @pytest.mark.env_onecard
  75. def test_vjp_input_function_single_input_single_output_default_v_pynative():
  76. """
  77. Features: Function vjp
  78. Description: Test vjp with function, single input, single output and default v in pynative mode.
  79. Expectation: No exception.
  80. """
  81. x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  82. v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  83. def test_function(inputs):
  84. return inputs**3
  85. primal, grad = vjp(test_function, x, v)
  86. expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
  87. expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
  88. assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
  89. assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
  90. @pytest.mark.level0
  91. @pytest.mark.platform_x86_cpu
  92. @pytest.mark.env_onecard
  93. def test_vjp_construct_single_input_single_output_default_v_pynative():
  94. """
  95. Features: Function vjp
  96. Description: Test vjp with function, single input, single output and default v in pynative mode.
  97. Expectation: No exception.
  98. """
  99. x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  100. v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  101. class Net(nn.Cell):
  102. def __init__(self, network):
  103. super(Net, self).__init__()
  104. self.net = network
  105. def construct(self, inputs, vectors):
  106. net_out, vjp_out = vjp(self.net, inputs, vectors)
  107. return net_out, vjp_out
  108. test_net_pynative = Net(SingleInputNet())
  109. primal, grad = test_net_pynative(x, v)
  110. expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
  111. expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
  112. assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
  113. assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())