You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hcom_sparsetensor.py 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import numpy as np
  17. from mindspore.communication.management import get_rank
  18. from mindspore import Tensor
  19. from mindspore import Parameter
  20. from mindspore import context
  21. from mindspore.ops import operations as P
  22. import mindspore.nn as nn
  23. from mindspore.train import Model
  24. from mindspore.context import ParallelMode
  25. from mindspore.communication.management import init
  26. from mindspore.communication.management import get_group_size
  27. class FakeDataInitMode:
  28. RandomInit = 0
  29. OnesInit = 1
  30. UniqueInit = 2
  31. ZerosInit = 3
  32. class FakeData:
  33. def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224), num_class=10,
  34. random_offset=0, use_parallel=False, fakedata_mode=FakeDataInitMode.RandomInit):
  35. self.size = size
  36. self.rank_batch_size = batch_size
  37. self.total_batch_size = self.rank_batch_size
  38. self.random_offset = random_offset
  39. self.image_size = image_size
  40. self.num_class = num_class
  41. self.rank_size = 1
  42. self.rank_id = 0
  43. self.batch_index = 0
  44. self.image_data_type = np.float32
  45. self.label_data_type = np.float32
  46. self.is_onehot = True
  47. self.fakedata_mode = fakedata_mode
  48. if use_parallel:
  49. if 'CONTEXT_DEVICE_TARGET' in os.environ and os.environ['CONTEXT_DEVICE_TARGET'] == 'GPU':
  50. init(backend_name='nccl')
  51. else:
  52. init(backend_name='hccl')
  53. self.rank_size = get_group_size()
  54. self.rank_id = get_rank()
  55. self.total_batch_size = self.rank_batch_size * self.rank_size
  56. assert self.size % self.total_batch_size == 0
  57. self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size
  58. def get_dataset_size(self):
  59. return int(self.size / self.total_batch_size)
  60. def get_reeat_count(self):
  61. return 1
  62. def set_image_data_type(self, data_type):
  63. self.image_data_type = data_type
  64. def set_label_data_type(self, data_type):
  65. self.label_data_type = data_type
  66. def set_label_onehot(self, is_onehot=True):
  67. self.is_onehot = is_onehot
  68. def create_tuple_iterator(self, num_epochs=-1, do_copy=False):
  69. return self
  70. def __getitem__(self, batch_index):
  71. if batch_index * self.total_batch_size >= len(self):
  72. raise IndexError("{} index out of range".format(self.__class__.__name__))
  73. rng_state = np.random.get_state()
  74. np.random.seed(batch_index + self.random_offset)
  75. if self.fakedata_mode == FakeDataInitMode.OnesInit:
  76. img = np.ones(self.total_batch_data_size)
  77. elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
  78. img = np.zeros(self.total_batch_data_size)
  79. elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
  80. total_size = 1
  81. for i in self.total_batch_data_size:
  82. total_size = total_size* i
  83. img = np.reshape(np.arange(total_size)*0.0001, self.total_batch_data_size)
  84. else:
  85. img = np.random.randn(*self.total_batch_data_size)
  86. target = np.random.randint(0, self.num_class, size=(self.rank_size, self.rank_batch_size))
  87. np.random.set_state(rng_state)
  88. img = img[self.rank_id]
  89. target = target[self.rank_id]
  90. img_ret = img.astype(self.image_data_type)
  91. target_ret = target.astype(self.label_data_type)
  92. if self.is_onehot:
  93. target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_class))
  94. target_onehot[np.arange(self.rank_batch_size), target] = 1
  95. target_ret = target_onehot.astype(self.label_data_type)
  96. return Tensor(img_ret), Tensor(target_ret)
  97. def __len__(self):
  98. return self.size
  99. def __iter__(self):
  100. self.batch_index = 0
  101. return self
  102. def reset(self):
  103. self.batch_index = 0
  104. def __next__(self):
  105. if self.batch_index * self.total_batch_size < len(self):
  106. data = self[self.batch_index]
  107. self.batch_index += 1
  108. return data
  109. raise StopIteration
  110. class NetWithSparseGatherV2(nn.Cell):
  111. def __init__(self, strategy=None, sparse=True):
  112. super(NetWithSparseGatherV2, self).__init__()
  113. self.axis = 0
  114. self.sparse = sparse
  115. if sparse:
  116. self.weight = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="weight")
  117. self.gather = P.SparseGatherV2()
  118. else:
  119. self.weight = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="weight")
  120. self.gather = P.Gather()
  121. if strategy is not None:
  122. self.gather.shard(strategy)
  123. def construct(self, indices):
  124. x = self.gather(self.weight, indices, self.axis)
  125. return x
  126. def train_mindspore_impl(self, indices, epoch, batch_size, use_parallel=True):
  127. ds = FakeData(size=8, batch_size=batch_size, num_class=8, image_size=(), use_parallel=use_parallel)
  128. ds.set_image_data_type(np.int32)
  129. net = self
  130. net.set_train()
  131. loss = nn.SoftmaxCrossEntropyWithLogits()
  132. optimizer = nn.Adam(net.trainable_params())
  133. optimizer.target = "CPU"
  134. model = Model(net, loss, optimizer)
  135. for _ in range(epoch):
  136. model.train(1, ds, dataset_sink_mode=False)
  137. output = net(indices)
  138. return output
  139. def test_allreduce_sparsegatherv2_adam_auto_parallel():
  140. context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
  141. init(backend_name='hccl')
  142. context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8, gradients_mean=True)
  143. indices = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]).astype(np.int32))
  144. epoch = 3
  145. batch_size = 1
  146. context.set_context(enable_sparse=True)
  147. net = NetWithSparseGatherV2(sparse=True)
  148. output_sparse = net.train_mindspore_impl(indices, epoch, batch_size)
  149. net = NetWithSparseGatherV2(sparse=False)
  150. output = net.train_mindspore_impl(indices, epoch, batch_size)
  151. assert np.allclose(output.asnumpy(), output_sparse.asnumpy(), 0.001, 0.001)