You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_interface.py 5.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore as ms
  16. import mindspore.nn as nn
  17. from mindspore import Tensor
  18. from mindspore import context
  19. from mindspore.common.parameter import Parameter, ParameterTuple
  20. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  21. from mindspore.nn.optim.momentum import Momentum
  22. from mindspore.ops import composite as C, functional as F, operations as P
  23. from mindspore.train import Model
  24. from mindspore.context import ParallelMode
  25. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  26. from tests.dataset_mock import MindData
# Module-level side effect: executed at import time, so every Cell built and run
# by the tests below is compiled in MindSpore GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
  28. class Dataset(MindData):
  29. def __init__(self, predict, label, length=3):
  30. super(Dataset, self).__init__(size=length)
  31. self.predict = predict
  32. self.label = label
  33. self.index = 0
  34. self.length = length
  35. def __iter__(self):
  36. return self
  37. def __next__(self):
  38. if self.index >= self.length:
  39. raise StopIteration
  40. self.index += 1
  41. return self.predict, self.label
  42. def reset(self):
  43. self.index = 0
  44. class AllToAllNet(nn.Cell):
  45. def __init__(self, strategy1):
  46. super(AllToAllNet, self).__init__()
  47. self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
  48. self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
  49. self.transpose1 = P.Transpose().shard(strategy1)
  50. def construct(self, x):
  51. x = self.matmul(x, self.matmul_weight)
  52. x = self.transpose1(x, (1, 0))
  53. return x
  54. def all_to_all_net(strategy1):
  55. return AllToAllNet(strategy1=strategy1)
  56. def loss_scale_manager_common(strategy1):
  57. learning_rate = 0.1
  58. momentum = 0.9
  59. epoch_size = 2
  60. context.reset_auto_parallel_context()
  61. context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8)
  62. predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
  63. label = Tensor(np.ones([32]), dtype=ms.int32)
  64. dataset = Dataset(predict, label, 2)
  65. net = all_to_all_net(strategy1)
  66. loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  67. loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
  68. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  69. scale_manager = DynamicLossScaleManager(32, 2, 2000)
  70. model = Model(net, loss, opt, loss_scale_manager=scale_manager)
  71. # if no GE exists, outputs = self._train_network(*next_element) outputs inputs tensor.
  72. try:
  73. model.train(epoch_size, dataset, dataset_sink_mode=False)
  74. except TypeError:
  75. pass
  76. else:
  77. assert False
  78. def fixme_test_dataset_interface_sens_scalar():
  79. # With error: "The type of sens node is not Tensor or Parameter, it is unsupported now."
  80. strategy1 = ((8, 1),)
  81. loss_scale_manager_common(strategy1)
  82. class TrainOneStepCell(nn.Cell):
  83. def __init__(self, network, optimizer):
  84. super(TrainOneStepCell, self).__init__(auto_prefix=False)
  85. self.network = network
  86. self.network.add_flags(defer_inline=True)
  87. self.weights = ParameterTuple(network.trainable_params())
  88. self.optimizer = optimizer
  89. self.grad = C.GradOperation(get_by_list=True, sens_param=True)
  90. def construct(self, data, sens):
  91. weights = self.weights
  92. loss = self.network(data)
  93. grads = self.grad(self.network, weights)(data, sens)
  94. return F.depend(loss, self.optimizer(grads))
  95. def loss_scale_manager_sens(strategy1, sens):
  96. learning_rate = 0.1
  97. momentum = 0.9
  98. device_num = 8
  99. context.reset_auto_parallel_context()
  100. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)
  101. predict = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
  102. net = all_to_all_net(strategy1)
  103. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  104. train_net = TrainOneStepCell(net, opt)
  105. train_net.set_train()
  106. train_net(predict, sens)
  107. def test_dataset_interface_sens_shape_not_equal_loss():
  108. strategy1 = ((8, 1),)
  109. sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
  110. try:
  111. loss_scale_manager_sens(strategy1, sens)
  112. except ValueError:
  113. pass
  114. except TypeError:
  115. pass
  116. except RuntimeError:
  117. pass
  118. def test_dataset_interface_sens_shape_equal_loss():
  119. strategy1 = ((4, 2),)
  120. sens = Tensor(np.ones([256, 256]), dtype=ms.float32)
  121. loss_scale_manager_sens(strategy1, sens)
  122. def test_input_not_in_parameter_layotu_dict():
  123. class Net(nn.Cell):
  124. def __init__(self, strategy1):
  125. super(Net, self).__init__()
  126. self.matmul = P.MatMul().shard(((1, 1), (1, 8)))
  127. self.matmul_weight = Parameter(Tensor(np.ones([128, 256]), dtype=ms.float32), name="weight")
  128. self.transpose1 = P.Transpose().shard(strategy1)
  129. def construct(self, x):
  130. x = self.matmul(x, self.matmul_weight)
  131. x = self.transpose1(x, (1, 0))
  132. return x
  133. strategy1 = ((8, 1),)
  134. device_num = 8
  135. context.reset_auto_parallel_context()
  136. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)
  137. predict = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
  138. net = Net(strategy1)
  139. net.set_train()
  140. net(predict)