You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_parallel_optimizer.py 3.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test adam """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore.common.api import _executor
  21. from mindspore.nn import TrainOneStepCell, WithLossCell
  22. from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
  23. from mindspore.ops import operations as P
  24. from mindspore import context
  25. class Net(nn.Cell):
  26. """Net definition"""
  27. def __init__(self):
  28. super(Net, self).__init__()
  29. self.fc1 = nn.Dense(128, 768, activation='relu')
  30. self.fc2 = nn.Dense(128, 768, activation='relu')
  31. self.fc3 = nn.Dense(128, 768, activation='relu')
  32. self.fc4 = nn.Dense(768, 768, activation='relu')
  33. self.relu4 = nn.ReLU()
  34. self.relu5 = nn.ReLU()
  35. self.transpose = P.Transpose()
  36. self.matmul1 = P.MatMul()
  37. self.matmul2 = P.MatMul()
  38. def construct(self, x):
  39. q = self.fc1(x)
  40. k = self.fc2(x)
  41. v = self.fc3(x)
  42. k = self.transpose(k, (1, 0))
  43. c = self.relu4(self.matmul1(q, k))
  44. s = self.relu5(self.matmul2(c, v))
  45. s = self.fc4(s)
  46. return s
  47. def test_AdamWeightDecay():
  48. """ test_AdamWeightDecay """
  49. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  50. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  51. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  52. net = Net()
  53. net.set_train()
  54. loss = nn.SoftmaxCrossEntropyWithLogits()
  55. optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
  56. net_with_loss = WithLossCell(net, loss)
  57. train_network = TrainOneStepCell(net_with_loss, optimizer)
  58. _executor.compile(train_network, inputs, label)
  59. def test_lamb_compile():
  60. """ test_Lamb_compile """
  61. context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
  62. inputs = Tensor(np.ones([32, 128]).astype(np.float32))
  63. label = Tensor(np.zeros([32, 768]).astype(np.float32))
  64. net = Net()
  65. net.set_train()
  66. loss = nn.SoftmaxCrossEntropyWithLogits()
  67. optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
  68. net_with_loss = WithLossCell(net, loss)
  69. train_network = TrainOneStepCell(net_with_loss, optimizer)
  70. _executor.compile(train_network, inputs, label)
  71. def test_edge_case():
  72. """ test_edge_case """
  73. context.set_auto_parallel_context(enable_parallel_optimizer=True)
  74. net = Net()
  75. with pytest.raises(RuntimeError):
  76. context.set_auto_parallel_context(parallel_mode="stand_alone")
  77. Lamb(net.trainable_params(), learning_rate=0.1)
  78. with pytest.raises(RuntimeError):
  79. Adam(net.trainable_params(), learning_rate=0.1)
  80. with pytest.raises(RuntimeError):
  81. context.set_auto_parallel_context(device_num=16)
  82. Lamb(net.trainable_params(), learning_rate=0.1)