|
|
|
@@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell |
|
|
|
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore import context |
|
|
|
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context |
|
|
|
|
|
|
|
class Net(nn.Cell): |
|
|
|
"""Net definition""" |
|
|
|
@@ -64,6 +64,7 @@ def test_AdamWeightDecay(): |
|
|
|
net_with_loss = WithLossCell(net, loss) |
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer) |
|
|
|
_executor.compile(train_network, inputs, label) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
|
|
|
|
|
|
|
|
def test_lamb_compile(): |
|
|
|
@@ -79,7 +80,24 @@ def test_lamb_compile(): |
|
|
|
net_with_loss = WithLossCell(net, loss) |
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer) |
|
|
|
_executor.compile(train_network, inputs, label) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
|
|
|
|
|
|
|
|
def test_lamb_split_fusion(): |
|
|
|
""" test_Lamb_split_fusion """ |
|
|
|
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8]) |
|
|
|
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) |
|
|
|
label = Tensor(np.zeros([32, 768]).astype(np.float32)) |
|
|
|
net = Net() |
|
|
|
net.set_train() |
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
optimizer = Lamb(net.trainable_params(), learning_rate=0.1) |
|
|
|
|
|
|
|
net_with_loss = WithLossCell(net, loss) |
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer) |
|
|
|
_executor.compile(train_network, inputs, label) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
|
|
|
|
def test_edge_case(): |
|
|
|
""" test_edge_case """ |
|
|
|
@@ -93,3 +111,4 @@ def test_edge_case(): |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
context.set_auto_parallel_context(device_num=16) |
|
|
|
Lamb(net.trainable_params(), learning_rate=0.1) |
|
|
|
context.reset_auto_parallel_context() |