
onehot_model_parallel.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
import mindspore.context as context
import mindspore.communication.management as distributedTool

device_num = 2
device_id = int(os.getenv('DEVICE_ID'))
rank_id = 0
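

# setup_module / teardown_module are the pytest module-level hooks for this
# file: setup_module configures the Ascend graph-mode context, initializes the
# HCCL communication group, overwrites the module-level device_num and rank_id
# defaults with the values reported by the communication layer, and feeds them
# into the auto-parallel context; teardown_module releases the group once all
# tests have run.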
def setup_module():
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_hccl=True)
    context.set_context(enable_task_sink=True,
                        device_id=device_id)
    context.set_context(enable_ir_fusion=True)
    context.set_context(enable_loop_sink=False)
    distributedTool.init()
    device_num = distributedTool.get_group_size()
    rank_id = distributedTool.get_rank()
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=rank_id)


def teardown_module():
    distributedTool.release()
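

# The Onehot cell chains three primitives: OneHot expands the int32 labels into
# a (batch, depth) one-hot matrix, Transpose flips it to (depth, batch) so it
# matches the layout of the dense input, and Sub subtracts it from that input.
# Each primitive carries a set_strategy() annotation; the transpose reuses the
# first element of the OneHot strategy so both operators describe the same
# sharding of the one-hot tensor.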
class Onehot(Cell):
    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None):
        super(Onehot, self).__init__()
        trans_stra = None
        if strategy:
            trans_stra = (strategy[0],)
        self.onehot = P.OneHot().set_strategy(strategy=strategy)
        self.depth = depth
        self.on_value = Tensor(on_value, ms.float32)
        self.off_value = Tensor(off_value, ms.float32)
        self.transpose = P.Transpose().set_strategy(strategy=trans_stra)
        self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1)))

    def construct(self, input, indices):
        x = self.onehot(indices, self.depth, self.on_value, self.off_value)
        x = self.transpose(x, (1, 0))
        x = self.sub(input, x)
        return x
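

# DataGenerator returns each tensor twice: the full array and the block owned
# by this rank. get_parallel_blocks splits the array along successive axes
# according to the given strategy (here only the first axis is split
# device_num ways), so datas[rank_id] is the slice a single device would hold
# under model parallelism.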
class DataGenerator():
    def get_parallel_blocks(self, input_, strategy):
        blocks = [input_]
        i = 0
        for stra in strategy:
            temp = []
            while len(blocks) > 0:
                block = blocks.pop(0)
                temp.extend(np.split(block, stra, axis=i))
            blocks.extend(temp)
            i += 1
        return blocks

    def generate_data(self, shape):
        data = np.random.rand(*shape)
        return data

    def input_data(self, shape):
        data = (self.generate_data(shape)*2).astype(np.float32)
        stra = [1]*len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])

    def label_data(self, shape, classes):
        data = (self.generate_data(shape)*(classes-1)).astype(np.int32)
        stra = [1]*len(shape)
        stra[0] = device_num
        datas = self.get_parallel_blocks(data, stra)
        return Tensor(data), Tensor(datas[rank_id])
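

# OneHotFactory generates one dataset per test configuration and runs the same
# network twice: stand-alone in forward_mindspore_single_impl, and under
# semi_auto_parallel with the requested sharding strategy in
# forward_mindspore_parallel_impl. forward_cmp asserts that the two outputs
# agree to within 1e-4 (rtol and atol), i.e. that model parallelism does not
# change the numerical result.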
class OneHotFactory:
    def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None):
        dataGen = DataGenerator()
        self.input_full, self.input_part = dataGen.input_data((classes, batch_size))
        self.label_full, self.label_part = dataGen.label_data((batch_size,), classes)
        self.depth = classes
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.strategy = strategy

    def forward_mindspore_single_impl(self):
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value)
        out = net(self.input_full, self.label_full)
        return out

    def forward_mindspore_parallel_impl(self):
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = Onehot(axis=self.axis,
                     depth=self.depth,
                     on_value=self.on_value,
                     off_value=self.off_value, strategy=self.strategy)
        out = net.compile_and_run(self.input_full, self.label_full)
        return out

    def forward_cmp(self):
        out_mindspore_single = self.forward_mindspore_single_impl().asnumpy()
        context.reset_auto_parallel_context()
        out_mindspore_parallel = self.forward_mindspore_parallel_impl().asnumpy()
        context.reset_auto_parallel_context()
        assert np.allclose(out_mindspore_single, out_mindspore_parallel, 0.0001, 0.0001)
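

# The two cases below only swap batch_size and depth. In both, the strategy
# ((1, device_num), (), ()) leaves the batch dimension of the one-hot output
# intact and splits its depth (class) dimension across all devices; the empty
# tuples are the scalar on_value and off_value inputs, which are left
# unsharded.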
def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
    fact = OneHotFactory(batch_size=128,
                         classes=1024,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()


def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
    fact = OneHotFactory(batch_size=1024,
                         classes=128,
                         on_value=1.000000,
                         off_value=0.000000,
                         axis=-1,
                         strategy=((1, device_num), (), ()))
    fact.forward_cmp()
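

# This file is meant to run on a multi-device Ascend setup: every process must
# see its own DEVICE_ID environment variable and HCCL must be available,
# because setup_module calls distributedTool.init() before any test executes.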