You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_sequence_assign.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
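  """Test assignment to (nested) list elements by positive and negative index, including lists holding graph inputs and gradients through such lists."""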
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  # Global context switch applied at import time: every Cell built by the
  # tests in this module is run in MindSpore's GRAPH_MODE.
  context.set_context(mode=context.GRAPH_MODE)
  def test_list_index_1D():
      """Replacing the inner list at index 0 leaves the other elements untouched."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [3, 3, 3]]
              nested[0] = [100]
              return nested

      result = Net()()
      expected = ([100], [2, 2], [3, 3, 3])
      # Indexing (not zip) so a too-short result still raises, as before.
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_neg_index_1D():
      """Index -3 on a 3-element list addresses the first element."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [3, 3, 3]]
              nested[-3] = [100]
              return nested

      result = Net()()
      expected = ([100], [2, 2], [3, 3, 3])
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_index_2D():
      """Two-level index assignment rewrites both scalars of the middle list."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [3, 3, 3]]
              nested[1][0] = 200
              nested[1][1] = 201
              return nested

      result = Net()()
      expected = ([1], [200, 201], [3, 3, 3])
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_neg_index_2D():
      """Negative inner indices (-2, -1) rewrite both scalars of the middle list."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [3, 3, 3]]
              nested[1][-2] = 200
              nested[1][-1] = 201
              return nested

      result = Net()()
      expected = ([1], [200, 201], [3, 3, 3])
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_index_3D():
      """Three-level index assignment rewrites every scalar of the innermost list."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [[3, 3, 3]]]
              nested[2][0][0] = 300
              nested[2][0][1] = 301
              nested[2][0][2] = 302
              return nested

      result = Net()()
      expected = ([1], [2, 2], [[300, 301, 302]])
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_neg_index_3D():
      """Negative innermost indices (-3..-1) rewrite every scalar of the deepest list."""
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self):
              nested = [[1], [2, 2], [[3, 3, 3]]]
              nested[2][0][-3] = 300
              nested[2][0][-2] = 301
              nested[2][0][-1] = 302
              return nested

      result = Net()()
      expected = ([1], [2, 2], [[300, 301, 302]])
      for idx, want in enumerate(expected):
          assert result[idx] == want
  def test_list_index_1D_parameter():
      """Smoke test: overwrite the graph input stored in a 1-element list with a constant.

      Only checks that the network compiles and runs; the result is not inspected.
      """
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self, x):
              holder = [x]
              holder[0] = 100
              return holder

      Net()(Tensor(0))
  def test_list_index_2D_parameter():
      """Smoke test: overwrite one graph input inside a 2-level list with a constant.

      Only checks that the network compiles and runs; the result is not inspected.
      """
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self, x):
              holder = [[x, x]]
              holder[0][0] = 100
              return holder

      Net()(Tensor(0))
  def test_list_index_3D_parameter():
      """Smoke test: overwrite one graph input inside a 3-level list with a constant.

      Only checks that the network compiles and runs; the result is not inspected.
      """
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()

          def construct(self, x):
              holder = [[[x, x]]]
              holder[0][0][0] = 100
              return holder

      Net()(Tensor(0))
  def test_const_list_index_3D_bprop():
      """Backprop through a tensor written into a constant nested-list attribute.

      ``construct`` stores its input at ``[2][0][1]`` of ``self.value`` and applies
      ReLU to that slot; the gradient net requests gradients for all inputs with an
      explicit sensitivity. Only checks that the gradient graph compiles and runs.
      """
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()
              self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
              self.relu = P.ReLU()

          def construct(self, input_x):
              slots = self.value
              slots[2][0][1] = input_x
              return self.relu(slots[2][0][1])

      class GradNet(nn.Cell):
          def __init__(self, net):
              super(GradNet, self).__init__()
              self.net = net
              self.grad_op = C.GradOperation(get_all=True, sens_param=True)

          def construct(self, x, sens):
              return self.grad_op(self.net)(x, sens)

      grad_net = GradNet(Net())
      # Built as two separate arrays, exactly as before (no shared buffer).
      x = Tensor(np.arange(6).reshape(2, 3))
      sens = Tensor(np.arange(6).reshape(2, 3))
      grad_net(x, sens)
  def test_parameter_list_index_3D_bprop():
      """Backprop through a nested list of graph inputs after index assignment.

      ``construct`` builds a nested list filled with ``x``, overwrites slot
      ``[2][0][1]`` with ``value`` and applies ReLU to that slot, so only
      ``value`` contributes to the output. The gradient net requests gradients
      for all inputs with an explicit sensitivity. Only checks that the
      gradient graph compiles and runs.
      """
      class Net(nn.Cell):
          def __init__(self):
              super(Net, self).__init__()
              # No constant list attribute here: construct builds its own list
              # from the inputs and never reads instance state besides relu.
              self.relu = P.ReLU()

          def construct(self, x, value):
              list_value = [[x], [x, x], [[x, x], [x, x]]]
              list_value[2][0][1] = value
              return self.relu(list_value[2][0][1])

      class GradNet(nn.Cell):
          def __init__(self, net):
              super(GradNet, self).__init__()
              self.net = net
              self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

          def construct(self, x, value, sens):
              return self.grad_all_with_sens(self.net)(x, value, sens)

      net = Net()
      grad_net = GradNet(net)
      x = Tensor(np.arange(2 * 3).reshape(2, 3))
      value = Tensor(np.ones((2, 3), np.int64))
      sens = Tensor(np.arange(2 * 3).reshape(2, 3))
      grad_net(x, value, sens)