
multifieldembeddinglookup_parallel.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
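"""Test MultiFieldEmbeddingLookup under auto-parallel.

Trains the same multi-field embedding network once stand-alone and once
across 8 GPUs in auto-parallel mode, reloads both checkpoints, and checks
that the two networks produce numerically close outputs.
"""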
import os
import numpy as np
import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore.nn import Adam
from mindspore.nn import MultiFieldEmbeddingLookup as embedding
from mindspore import Tensor
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import release
from mindspore.communication.management import get_rank
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
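

# Numeric comparison helpers: allclose_nparray() asserts that two arrays
# are elementwise close; on mismatch, _count_unequal_element() reports the
# offending elements and tolerates only a small fraction of outliers.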
def _count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert True
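

# Checkpoint housekeeping: remove stale .ckpt/.meta files before a run and
# locate the most recently written checkpoint afterwards.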
def clean_all_ckpt_files(folder_path):
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
                os.remove(os.path.join(folder_path, file_name))


def find_newest_ckpt_file(folder_path):
    ckpt_files = map(lambda f: os.path.join(folder_path, f),
                     filter(lambda f: f.endswith('.ckpt'),
                            os.listdir(folder_path)))
    return max(ckpt_files, key=os.path.getctime)
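

# FakeData is a minimal in-memory dataset that mimics the MindSpore dataset
# iterator interface (get_dataset_size, create_tuple_iterator, ...). Each
# batch is generated from a seed derived from the batch index, and each rank
# slices out its own shard, so all ranks see consistent data in parallel mode.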
class FakeDataInitMode:
    RandomInit = 0
    OnesInit = 1
    UniqueInit = 2
    ZerosInit = 3


class FakeData:
    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode
        if use_parallel is True:
            init(backend_name='nccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()
        self.total_batch_size = self.rank_batch_size * self.rank_size
        assert (self.size % self.total_batch_size) == 0
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        _ = num_epochs
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size = total_size * i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
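

# MultiHotNet feeds fixed indices/field_ids through MultiFieldEmbeddingLookup
# and applies a ReLU whose shard strategy is chosen to match the embedding's
# slice_mode.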
class MultiHotNet(Cell):
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        super().__init__()
        self.embedding = embedding(vocab_size=vocab_size,
                                   embedding_size=embedding_size, field_size=field_size,
                                   param_init=param_init, target=target, slice_mode=slice_mode,
                                   sparse=sparse, operator=operator)
        self.relu = P.ReLU()
        self.indices = Tensor(indices)
        self.field_ids = Tensor(field_ids)
        if slice_mode == "table_column_slice":
            self.relu.shard(((1, 1, 8),))
        elif slice_mode == "table_row_slice":
            self.relu.shard(((8, 1, 1),))
        elif slice_mode == "batch_slice":
            self.relu.shard(((8, 1, 1),))

    def construct(self, values, label):
        x = self.embedding(self.indices, values, self.field_ids)
        output = self.relu(x)
        return output
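

# ParallelMultiHotFactory drives the test: it trains the network in
# stand-alone and auto-parallel modes, saves per-rank checkpoints, reloads
# them into fresh networks, and compares the forward outputs.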
class ParallelMultiHotFactory:
    def __init__(self, vocab_size, embedding_size, field_size,
                 param_init, target, slice_mode, sparse, operator, indices, field_ids):
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.field_size = field_size
        self.param_init = param_init
        self.target = target
        self.slice_mode = slice_mode
        self.sparse = sparse
        self.operator = operator
        self.indices = indices
        self.field_ids = field_ids
        self.global_rank_id = None
        self.opt = None
        self.model = None
        self.standalone_ckpt = None
        self.parallel_ckpt = None
        self.loss_fn = None
        self._init_parallel()
        self._set_parallel_env()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __del__(self):
        self._release_parallel()

    def _set_parallel_env(self):
        self.global_rank_id = get_rank()

    def _init_parallel(self):
        self._init_parallel_flag = False
        init(backend_name='nccl')
        self._init_parallel_flag = True

    def _release_parallel(self):
        release()

    def _model_train_and_save_ckpt(self, net, dataset, epoch):
        self.opt = Adam(params=net.get_parameters())
        if self.target == 'CPU':
            self.opt.target = self.target
        if self.sparse:
            context.set_context(enable_sparse=True)
        self.model = Model(network=net,
                           loss_fn=self.loss_fn,
                           optimizer=self.opt)
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_path = './rank_{}_ckpt'.format(self.global_rank_id)
        ckpt_callback = ModelCheckpoint(prefix='parallel', directory=ckpt_path,
                                        config=ckpt_config)
        clean_all_ckpt_files(ckpt_path)
        self.model.train(epoch=epoch,
                         train_dataset=dataset,
                         callbacks=[ckpt_callback],
                         dataset_sink_mode=False)
        newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
        return load_checkpoint(newest_ckpt_file)

    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num)
        parallel_mode_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                        field_size=self.field_size, param_init=self.param_init, target=self.target,
                                        slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                        indices=self.indices, field_ids=self.field_ids)
        self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net, epoch=epoch, dataset=dataset)

    def mindspore_standalone_impl(self, epoch, dataset):
        context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE)
        stand_alone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                      field_size=self.field_size, param_init=self.param_init, target=self.target,
                                      slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                      indices=self.indices, field_ids=self.field_ids)
        self.standalone_ckpt = self._model_train_and_save_ckpt(net=stand_alone_net,
                                                               epoch=epoch, dataset=dataset)

    def checkpoint_cmp(self, inputs_np, label):
        standalone_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                     field_size=self.field_size, param_init=self.param_init, target=self.target,
                                     slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                     indices=self.indices, field_ids=self.field_ids)
        parallel_net = MultiHotNet(vocab_size=self.vocab_size, embedding_size=self.embedding_size,
                                   field_size=self.field_size, param_init=self.param_init, target=self.target,
                                   slice_mode=self.slice_mode, sparse=self.sparse, operator=self.operator,
                                   indices=self.indices, field_ids=self.field_ids)
        load_param_into_net(standalone_net, self.standalone_ckpt)
        load_param_into_net(parallel_net, self.parallel_ckpt)
        standalone_out = standalone_net(Tensor(inputs_np), Tensor(label))
        parallel_out = parallel_net(Tensor(inputs_np), Tensor(label))
        allclose_nparray(standalone_out.asnumpy(), parallel_out.asnumpy(), 0.001, 0.001)
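

# Scenario under test: embedding table on DEVICE, column-sliced across
# 8 devices, dense (non-sparse) lookup with 'MEAN' field aggregation.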
def test_auto_parallel_multifieldembeddinglookup_device_table_column_slice_mean():
    inputs_np = 10 * np.random.randn(64, 64).astype(np.float32)
    label = 10 * np.random.randn(64, 64).astype(np.float32)
    indices = np.random.randint(0, 9, (64, 64), np.int32)
    field_ids = np.random.randint(0, 20, (64, 64), np.int32)
    fact = ParallelMultiHotFactory(vocab_size=32, embedding_size=64, field_size=64, param_init='one', target='DEVICE',
                                   slice_mode='table_column_slice', sparse=False, operator='MEAN',
                                   indices=indices, field_ids=field_ids)
    # stand alone
    standalone_dataset = FakeData(size=64, batch_size=64, image_size=(64,))
    fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2)
    # auto parallel
    parallel_dataset = FakeData(size=64, batch_size=8, image_size=(64,), use_parallel=True)
    fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, epoch=2, device_num=8)
    # compare
    fact.checkpoint_cmp(inputs_np=inputs_np, label=label)
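

# Running note (an assumption; the launch command is not part of this file):
# the parallel path expects 8 NCCL ranks, so the test is typically driven by
# an MPI launcher, e.g. `mpirun -n 8 pytest -sv multifieldembeddinglookup_parallel.py`.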