You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_optimizer.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test adam """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
  23. from mindspore.ops import operations as P
  24. from mindspore import context
  25. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  26. class Net(nn.Cell):
  27. """Net definition"""
  28. def __init__(self):
  29. super(Net, self).__init__()
  30. self.fc1 = nn.Dense(128, 768, activation='relu')
  31. self.fc2 = nn.Dense(128, 768, activation='relu')
  32. self.fc3 = nn.Dense(128, 768, activation='relu')
  33. self.fc4 = nn.Dense(768, 768, activation='relu')
  34. self.relu4 = nn.ReLU()
  35. self.relu5 = nn.ReLU()
  36. self.transpose = P.Transpose()
  37. self.matmul1 = P.MatMul()
  38. self.matmul2 = P.MatMul()
  39. def construct(self, x):
  40. q = self.fc1(x)
  41. k = self.fc2(x)
  42. v = self.fc3(x)
  43. k = self.transpose(k, (1, 0))
  44. c = self.relu4(self.matmul1(q, k))
  45. s = self.relu5(self.matmul2(c, v))
  46. s = self.fc4(s)
  47. return s
  48. def test_AdamWeightDecay():
  49. """ test_AdamWeightDecay """
  50. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  51. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  52. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  53. net = Net()
  54. net.set_train()
  55. loss = nn.SoftmaxCrossEntropyWithLogits()
  56. optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
  57. net_with_loss = WithLossCell(net, loss)
  58. train_network = TrainOneStepCell(net_with_loss, optimizer)
  59. _executor.compile(train_network, inputs, label)
  60. context.reset_auto_parallel_context()
  61. def test_lamb_compile():
  62. """ test_Lamb_compile """
  63. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  64. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  65. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  66. net = Net()
  67. net.set_train()
  68. loss = nn.SoftmaxCrossEntropyWithLogits()
  69. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  70. net_with_loss = WithLossCell(net, loss)
  71. train_network = TrainOneStepCell(net_with_loss, optimizer)
  72. _executor.compile(train_network, inputs, label)
  73. context.reset_auto_parallel_context()
  74. def test_lamb_split_fusion():
  75. """ test_Lamb_split_fusion """
  76. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  77. auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
  78. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  79. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  80. net = Net()
  81. net.set_train()
  82. loss = nn.SoftmaxCrossEntropyWithLogits()
  83. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  84. net_with_loss = WithLossCell(net, loss)
  85. train_network = TrainOneStepCell(net_with_loss, optimizer)
  86. _executor.compile(train_network, inputs, label)
  87. context.reset_auto_parallel_context()
  88. def test_edge_case():
  89. """ test_edge_case """
  90. context.set_auto_parallel_context(enable_parallel_optimizer=True)
  91. net = Net()
  92. with pytest.raises(RuntimeError):
  93. context.set_auto_parallel_context(parallel_mode="stand_alone")
  94. Lamb(net.trainable_params(), learning_rate=0.1)
  95. with pytest.raises(RuntimeError):
  96. Adam(net.trainable_params(), learning_rate=0.1)
  97. with pytest.raises(RuntimeError):
  98. context.set_auto_parallel_context(device_num=16)
  99. Lamb(net.trainable_params(), learning_rate=0.1)
  100. context.reset_auto_parallel_context()