You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_variable.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test variable"""
  16. import numpy as np
  17. from mindspore.ops.composite import GradOperation
  18. from mindspore.common.variable import Variable
  19. from mindspore.common.api import _CellGraphExecutor
  20. from mindspore.ops import operations as P
  21. import mindspore.nn as nn
  22. import mindspore.common.dtype as mstype
  23. from mindspore import Tensor
  24. from mindspore import Parameter
  25. def test_variable_scalar_mul_grad_first():
  26. """
  27. Feature: Set Constants mutable.
  28. Description: Get gradient with respect to the first scalar input.
  29. Expectation: Get the correct gradient.
  30. """
  31. class Net(nn.Cell):
  32. def construct(self, x, y):
  33. return x * y
  34. class GradNet(nn.Cell):
  35. def __init__(self, net):
  36. super(GradNet, self).__init__()
  37. self.net = net
  38. self.grad_op = GradOperation()
  39. def construct(self, x, y):
  40. gradient_function = self.grad_op(self.net)
  41. return gradient_function(x, y)
  42. x = Variable(2)
  43. output = GradNet(Net())(x, 3)
  44. assert output == 3
  45. def test_variable_scalar_mul_grad_all():
  46. """
  47. Feature: Set Constants mutable.
  48. Description: Get gradient with respect to all scalar inputs.
  49. Expectation: Get the correct gradients.
  50. """
  51. class Net(nn.Cell):
  52. def construct(self, x, y):
  53. return x * y
  54. class GradNet(nn.Cell):
  55. def __init__(self, net):
  56. super(GradNet, self).__init__()
  57. self.net = net
  58. self.grad_op = GradOperation(get_all=True)
  59. def construct(self, x, y):
  60. gradient_function = self.grad_op(self.net)
  61. return gradient_function(x, y)
  62. x = Variable(2)
  63. y = Variable(3)
  64. output = GradNet(Net())(x, y)
  65. assert output == (3, 2)
  66. def test_variable_tuple_or_list_scalar_mul_grad():
  67. """
  68. Feature: Set Constants mutable.
  69. Description: Get gradient with respect to the tuple or list scalar input.
  70. Expectation: Get the correct gradients.
  71. """
  72. class Net(nn.Cell):
  73. def construct(self, x):
  74. return x[0] * x[1]
  75. class GradNet(nn.Cell):
  76. def __init__(self, net):
  77. super(GradNet, self).__init__()
  78. self.net = net
  79. self.grad_op = GradOperation()
  80. def construct(self, x):
  81. gradient_function = self.grad_op(self.net)
  82. return gradient_function(x)
  83. x = Variable((2, 3))
  84. output = GradNet(Net())(x)
  85. assert output == (3, 2)
  86. x = Variable([2, 3])
  87. output = GradNet(Net())(x)
  88. assert output == (3, 2)
  89. def test_variable_dict_scalar_mul_grad():
  90. """
  91. Feature: Set Constants mutable.
  92. Description: Get gradient with respect to the dict scalar input.
  93. Expectation: Get the correct gradients.
  94. """
  95. class Net(nn.Cell):
  96. def construct(self, x):
  97. return x['a'] * x['b']
  98. class GradNet(nn.Cell):
  99. def __init__(self, net):
  100. super(GradNet, self).__init__()
  101. self.net = net
  102. self.grad_op = GradOperation()
  103. def construct(self, x):
  104. gradient_function = self.grad_op(self.net)
  105. return gradient_function(x)
  106. x = Variable({'a': 2, 'b': 3})
  107. output = GradNet(Net())(x)
  108. assert output == (3, 2)
  109. def test_variable_mix_scalar_mul_grad_all():
  110. """
  111. Feature: Set Constants mutable.
  112. Description: Get gradient with respect to the mix scalar input including dict and tuple.
  113. Expectation: Get the correct gradients.
  114. """
  115. class Net(nn.Cell):
  116. def construct(self, x, y):
  117. return x['a'] * x['b'] * y[0]
  118. class GradNet(nn.Cell):
  119. def __init__(self, net):
  120. super(GradNet, self).__init__()
  121. self.net = net
  122. self.grad_op = GradOperation(get_all=True)
  123. def construct(self, x, y):
  124. gradient_function = self.grad_op(self.net)
  125. return gradient_function(x, y)
  126. x = Variable({'a': 2, 'b': 3})
  127. y = Variable((4, 5))
  128. output = GradNet(Net())(x, y)
  129. assert output == ((12, 8), (6, 0))
  130. def test_tuple_inputs_compile_phase():
  131. """
  132. Feature: Set Constants mutable.
  133. Description: Test whether the compilation phase for tuple input twice are the same.
  134. Expectation: The phases are the same.
  135. """
  136. class Net(nn.Cell):
  137. def __init__(self):
  138. super(Net, self).__init__()
  139. self.matmul = P.MatMul()
  140. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  141. def construct(self, tuple_input):
  142. x = tuple_input[0]
  143. y = tuple_input[1]
  144. x = x * self.z
  145. out = self.matmul(x, y)
  146. return out
  147. x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
  148. y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
  149. p = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
  150. q = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
  151. net = Net()
  152. _cell_graph_executor = _CellGraphExecutor()
  153. phase1, _ = _cell_graph_executor.compile(net, (x, y))
  154. phase2, _ = _cell_graph_executor.compile(net, (p, q))
  155. assert phase1 != phase2
  156. phase1, _ = _cell_graph_executor.compile(net, Variable((x, y)))
  157. phase2, _ = _cell_graph_executor.compile(net, Variable((p, q)))
  158. assert phase1 == phase2