You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_auto_parallel_onehot.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. import mindspore.nn as nn
  17. from mindspore.ops import operations as P
  18. from mindspore import Tensor
  19. from tests.ut.python.ops.test_math_ops import VirtualLoss
  20. import mindspore as ms
  21. from mindspore.common.api import _executor
  22. from mindspore.ops import composite as C
  23. from mindspore.common.parameter import Parameter
  24. from tests.dataset_mock import MindData
  25. from mindspore.train import Model, ParallelMode
  26. from mindspore.nn.optim.momentum import Momentum
  # Module-level setup: put the MindSpore context into graph mode at import time.
  27. context.set_context(mode=context.GRAPH_MODE)
  28. class Dataset(MindData):
  29. def __init__(self, predict, label, length=3):
  30. super(Dataset, self).__init__(size=length)
  31. self.predict = predict
  32. self.label = label
  33. self.index = 0
  34. self.length = length
  35. def __iter__(self):
  36. return self
  37. def __next__(self):
  38. if self.index >= self.length:
  39. raise StopIteration
  40. self.index += 1
  41. return self.predict, self.label
  42. def reset(self):
  43. self.index = 0
  44. class NetWithLoss(nn.Cell):
  45. def __init__(self, network):
  46. super(NetWithLoss, self).__init__()
  47. self.loss = VirtualLoss()
  48. self.network = network
  49. def construct(self, x, y, b):
  50. predict = self.network(x, y, b)
  51. return self.loss(predict)
  52. class GradWrap(nn.Cell):
  53. def __init__(self, network):
  54. super(GradWrap, self).__init__()
  55. self.network = network
  56. def construct(self, x, y, b):
  57. return C.grad_all(self.network)(x, y, b)
  58. def test_auto_parallel_arithmetic():
  59. class Net(nn.Cell):
  60. def __init__(self):
  61. super().__init__()
  62. self.matmul = P.MatMul()
  63. self.one_hot = P.OneHot()
  64. self.on_value = Tensor(1.0, ms.float32)
  65. self.off_value = Tensor(0.0, ms.float32)
  66. self.matmul2 = P.MatMul()
  67. def construct(self, x, y, b):
  68. out = self.matmul(x, y)
  69. out1 = self.one_hot(b, 64, self.on_value, self.off_value)
  70. out2 = self.matmul2(out, out1)
  71. return out2
  72. context.set_auto_parallel_context(device_num=8, global_rank=0)
  73. net = GradWrap(NetWithLoss(Net()))
  74. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  75. x = Tensor(np.ones([64, 32]), dtype=ms.float32)
  76. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  77. b = Tensor(np.ones([64]), dtype=ms.int32)
  78. _executor.compile(net, x, y, b)
  79. def test_auto_parallel_arithmetic_model():
  80. class NetOneHot(nn.Cell):
  81. def __init__(self):
  82. super().__init__()
  83. self.matmul = P.MatMul()
  84. self.one_hot = P.OneHot().set_strategy(((1, 8), (), ()))
  85. self.on_value = Tensor(1.0, ms.float32)
  86. self.off_value = Tensor(0.0, ms.float32)
  87. self.matmul2 = P.MatMul()
  88. self.w = Parameter(Tensor(np.zeros([32, 64]).astype(np.float32)), "weight", requires_grad=True)
  89. def construct(self, x, b):
  90. out = self.matmul(x, self.w)
  91. out1 = self.one_hot(b, 64, self.on_value, self.off_value)
  92. out2 = self.matmul2(out, out1)
  93. return out2
  94. context.reset_auto_parallel_context()
  95. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  96. net = NetOneHot()
  97. x = Tensor(np.ones([8, 32]), dtype=ms.float32)
  98. b = Tensor(np.ones([8]), dtype=ms.int32)
  99. dataset = Dataset(x, b, 2)
  100. opt = Momentum(net.trainable_params(), 0.1, 0.9)
  101. model = Model(net, optimizer=opt)
  102. model.train(2, dataset, dataset_sink_mode=False)