You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_auto_parallel_inference.py 1.3 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. import numpy as np
  2. import mindspore.nn as nn
  3. from mindspore import Tensor, context
  4. from mindspore.ops import operations as P
  5. from mindspore.nn import WithLossCell, TrainOneStepCell
  6. from mindspore.nn import Momentum
  7. from mindspore.parallel._cost_model_context import set_cost_model_context
  8. class Net(nn.Cell):
  9. def __init__(self, input_ch, out_ch):
  10. super(Net, self).__init__()
  11. self.dense = nn.Dense(input_ch, out_ch)
  12. self.relu = P.ReLU()
  13. def construct(self, x):
  14. x = self.dense(x)
  15. x = self.relu(x)
  16. return x
  17. def test_inference_phase():
  18. context.set_auto_parallel_context(device_num=8, global_rank=0)
  19. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  20. set_cost_model_context(run_phase=1)
  21. net = Net(512, 128)
  22. predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
  23. label = Tensor(np.ones([64, 128]).astype(np.float32))
  24. loss = nn.SoftmaxCrossEntropyWithLogits()
  25. optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  26. net_with_loss = WithLossCell(net, loss)
  27. train_network = TrainOneStepCell(net_with_loss, optimizer)
  28. train_network.set_train()
  29. train_network.set_auto_parallel()
  30. output = train_network(predict, label)