You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_param_transform.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. from mindspore import RowTensor
  18. from mindspore import context, nn, Tensor, ParameterTuple
  19. from mindspore.common import dtype as mstype
  20. from mindspore.common import ms_function
  21. from mindspore.ops import operations as P
  22. from mindspore.ops import composite as C
  23. def setup_module():
  24. context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
  25. class _Grad(nn.Cell):
  26. def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
  27. super().__init__()
  28. self.network = network
  29. self.grad = grad
  30. self.sens_param = self.grad.sens_param
  31. self.wrt_params = wrt_params
  32. self.real_inputs_count = real_inputs_count
  33. if self.wrt_params:
  34. self.params = ParameterTuple(self.network.trainable_params())
  35. def construct(self, *inputs):
  36. if self.wrt_params:
  37. if self.real_inputs_count is None or self.sens_param is False:
  38. return self.grad(self.network, self.params)(*inputs)
  39. real_inputs = inputs[:self.real_inputs_count]
  40. sense_param_inputs = inputs[self.real_inputs_count:]
  41. return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
  42. if self.real_inputs_count is None or self.sens_param is False:
  43. return self.grad(self.network)(*inputs)
  44. real_inputs = inputs[:self.real_inputs_count]
  45. sense_param_inputs = inputs[self.real_inputs_count:]
  46. return self.grad(self.network)(*real_inputs, sense_param_inputs)
  47. class GradOfFirstInput(_Grad):
  48. """
  49. get grad of first input
  50. """
  51. def __init__(self, network, sens_param=True, real_inputs_count=None):
  52. super().__init__(grad=C.GradOperation(sens_param=sens_param),
  53. network=network, real_inputs_count=real_inputs_count)
  54. class GradOfAllInputs(_Grad):
  55. """
  56. get grad of first input
  57. """
  58. def __init__(self, network, sens_param=True, real_inputs_count=None):
  59. super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
  60. network=network, real_inputs_count=real_inputs_count)
  61. @pytest.mark.level0
  62. @pytest.mark.platform_arm_ascend_training
  63. @pytest.mark.platform_x86_ascend_training
  64. @pytest.mark.env_onecard
  65. def test_row_tensor_in_while():
  66. class RowTensorValuesDouble(nn.Cell):
  67. def construct(self, x):
  68. indices = x.indices
  69. values = x.values * 2
  70. dense_shape = x.dense_shape
  71. return RowTensor(indices, values, dense_shape)
  72. class RowTensorValuesAdd2(nn.Cell):
  73. def construct(self, x):
  74. indices = x.indices
  75. values = x.values + 2
  76. dense_shape = x.dense_shape
  77. return RowTensor(indices, values, dense_shape)
  78. class RowTensorWithControlWhile(nn.Cell):
  79. def __init__(self, dense_shape):
  80. super().__init__()
  81. self.op1 = RowTensorValuesDouble()
  82. self.op2 = RowTensorValuesAdd2()
  83. self.dense_shape = dense_shape
  84. @ms_function
  85. def construct(self, a, b, indices, values):
  86. x = RowTensor(indices, values, self.dense_shape)
  87. x = self.op2(x)
  88. while a > b:
  89. x = self.op1(x)
  90. b = b + 1
  91. return x.indices, x.values, x.dense_shape
  92. a = Tensor(np.array(3).astype(np.int32))
  93. b = Tensor(np.array(0).astype(np.int32))
  94. indices = Tensor(np.array([0, 2]).astype(np.int32))
  95. values = Tensor(np.ones([2, 2]).astype(np.float32))
  96. dense_shape = (5, 2)
  97. net = RowTensorWithControlWhile(dense_shape)
  98. out = net(a, b, indices, values)
  99. assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0)
  100. assert np.allclose(values.asnumpy()*24, out[1].asnumpy(), .0, .0)
  101. assert dense_shape == out[2]
  102. @pytest.mark.level0
  103. @pytest.mark.platform_arm_ascend_training
  104. @pytest.mark.platform_x86_ascend_training
  105. @pytest.mark.env_onecard
  106. def test_parser_switch_layer_inputs_tuple():
  107. class Add(nn.Cell):
  108. def __init__(self):
  109. super().__init__()
  110. self.op = P.TensorAdd()
  111. def construct(self, x):
  112. y = self.op(x[0], x[1])
  113. return self.op(x[0], y)
  114. class Mul(nn.Cell):
  115. def __init__(self):
  116. super().__init__()
  117. self.op = P.Mul()
  118. def construct(self, x):
  119. y = self.op(x[0], x[1])
  120. return self.op(x[0], y)
  121. class MulTwoInput(nn.Cell):
  122. def __init__(self):
  123. super().__init__()
  124. self.op = P.Mul()
  125. @ms_function
  126. def construct(self, x, y):
  127. y = self.op(x, y)
  128. return self.op(x, y)
  129. class TwoInputTupleFinalNet(nn.Cell):
  130. def __init__(self, funcs):
  131. super().__init__()
  132. self.funcs = funcs
  133. @ms_function
  134. def construct(self, i, inputa, inputb):
  135. inputs = (inputa, inputb)
  136. x = self.funcs[i](inputs)
  137. return x
  138. func1 = Add()
  139. func2 = Mul()
  140. funcs = (func1, func2)
  141. net = TwoInputTupleFinalNet(funcs)
  142. input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  143. input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
  144. i = Tensor(1, mstype.int32)
  145. netout = net(i, input_data, input2)
  146. net_good = MulTwoInput()
  147. goodout = net_good(input_data, input2)
  148. assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0)
  149. @pytest.mark.level0
  150. @pytest.mark.platform_arm_ascend_training
  151. @pytest.mark.platform_x86_ascend_training
  152. @pytest.mark.env_onecard
  153. def test_imagenet():
  154. class ImageGradients(nn.Cell):
  155. def __init__(self):
  156. super().__init__()
  157. self.imagegradients = nn.ImageGradients()
  158. def construct(self, inputs):
  159. return self.imagegradients(inputs)
  160. net = ImageGradients()
  161. net_me = GradOfFirstInput(net, real_inputs_count=1)
  162. net_me.set_train()
  163. input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)
  164. output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32),
  165. Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32))
  166. net_me(input_data, *output_grad)