You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_auto_parallel_inference.py 1.3 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839
  1. import numpy as np
  2. import mindspore.nn as nn
  3. from mindspore import Tensor, context
  4. from mindspore.nn import Momentum
  5. from mindspore.nn import WithLossCell, TrainOneStepCell
  6. from mindspore.ops import operations as P
  7. from mindspore.parallel._cost_model_context import set_cost_model_context
  8. class Net(nn.Cell):
  9. def __init__(self, input_ch, out_ch):
  10. super(Net, self).__init__()
  11. self.dense = nn.Dense(input_ch, out_ch)
  12. self.relu = P.ReLU()
  13. def construct(self, x):
  14. x = self.dense(x)
  15. x = self.relu(x)
  16. return x
  17. def test_inference_phase():
  18. context.set_auto_parallel_context(device_num=8, global_rank=0)
  19. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  20. set_cost_model_context(run_phase=1)
  21. net = Net(512, 128)
  22. predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
  23. label = Tensor(np.ones([64, 128]).astype(np.float32))
  24. loss = nn.SoftmaxCrossEntropyWithLogits()
  25. optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  26. net_with_loss = WithLossCell(net, loss)
  27. train_network = TrainOneStepCell(net_with_loss, optimizer)
  28. train_network.set_train()
  29. train_network.set_auto_parallel()
  30. _ = train_network(predict, label)