
optimizer_parallel.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
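"""Test optimizer parallel (optimizer state sharding) under AUTO_PARALLEL.

The same network is trained twice on 4 devices -- once with the parallel
optimizer disabled and once with it enabled -- and the checkpoints from the
two runs are loaded back and compared through their forward outputs.
"""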
import os
import numpy as np
from mindspore.communication.management import init
from mindspore.communication.management import release
from mindspore.communication.management import get_rank
from mindspore.communication.management import get_group_size
from mindspore.nn import Cell
from mindspore.nn import ReLU
from mindspore.nn import Dense
from mindspore.nn import Flatten
from mindspore.nn import Momentum
import mindspore.ops.operations as P
from mindspore.train.serialization import load_param_into_net
from mindspore.train.callback import CheckpointConfig
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.serialization import load_checkpoint
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train import Model
from mindspore.parallel import set_algo_parameters
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore import context
from mindspore.context import ParallelMode

# Auto-parallel requires graph mode; every test in this file targets Ascend.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')

def _count_unequal_element(data_expected, data_me, rtol, atol):
    # Fail only if the fraction of out-of-tolerance elements itself exceeds
    # rtol, and show the offending elements in the assertion message.
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert True

def clean_all_ckpt_files(folder_path):
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
                os.remove(os.path.join(folder_path, file_name))


def find_newest_ckpt_file(folder_path):
    ckpt_files = map(lambda f: os.path.join(folder_path, f),
                     filter(lambda f: f.endswith('.ckpt'),
                            os.listdir(folder_path)))
    return max(ckpt_files, key=os.path.getctime)

class FakeDataInitMode:
    RandomInit = 0
    OnesInit = 1
    UniqueInit = 2
    ZerosInit = 3
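
# FakeData is a stand-in for a MindSpore dataset: it implements the pieces of
# the dataset interface used by Model.train here (create_tuple_iterator,
# get_dataset_size, get_repeat_count) and generates deterministic per-rank
# image/label batches in memory.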
class FakeData:
    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode
        if use_parallel is True:
            init(backend_name='hccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()
        self.total_batch_size = self.rank_batch_size * self.rank_size
        assert (self.size % self.total_batch_size) == 0
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1):
        _ = num_epochs
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        # Seed with the batch index so every rank generates the same global
        # batch, then slice out this rank's shard; the RNG state is saved and
        # restored so callers' randomness is not disturbed.
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size = total_size * i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
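
# A small multi-branch network whose parallel strategies can be injected via
# strategy_dict. shared_weight feeds both add1 (on device) and mul1 (pinned
# to CPU through primitive_target), so one parameter is shared across
# heterogeneous execution targets.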
class OptimizerSemiAutoAndAutoParallel6Net(Cell):
    def __init__(self, strategy_dict=None):
        super().__init__()
        shared_np = np.full((16, 1, 32, 32), 0.5, dtype=np.float32)
        self.shared_weight = Parameter(Tensor(shared_np), name='shared_weight')
        self.fc1 = Dense(in_channels=1024,
                         out_channels=116,
                         weight_init='ones',
                         bias_init='ones',
                         has_bias=True)
        self.relu = ReLU()
        self.sigmoid = P.Sigmoid()
        self.add1 = P.TensorAdd()
        self.add2 = P.TensorAdd()
        self.mul1 = P.Mul().add_prim_attr('primitive_target', 'CPU')
        self.mul2 = P.Mul()
        self.mul3 = P.Mul()
        self.flatten = Flatten()
        mul2_weight_np = np.full((16, 116), 1, dtype=np.float32)
        self.mul2_weight = Parameter(Tensor(mul2_weight_np), name='mul2_weight')
        mul3_weight_np = np.full((16, 116), 1, dtype=np.float32)
        self.mul3_weight = Parameter(Tensor(mul3_weight_np), name='mul3_weight')
        if strategy_dict is not None:
            self.add1.shard(strategy_dict['add1'])
            self.mul1.shard(strategy_dict['mul1'])
            self.fc1.matmul.shard(strategy_dict['fc1_matmul'])
            self.fc1.bias_add.shard(strategy_dict['fc1_bias_add'])
            self.mul2.shard(strategy_dict['mul2'])
            self.mul3.shard(strategy_dict['mul3'])

    def construct(self, inputs):
        relu = self.relu(inputs)
        sigmoid = self.sigmoid(inputs)
        add1 = self.add1(relu, self.shared_weight)
        mul = self.mul1(sigmoid, self.shared_weight)
        add2 = self.add2(add1, mul)
        flatten = self.flatten(add2)
        dense = self.fc1(flatten)
        mul2 = self.mul2(dense, self.mul2_weight)
        out = self.mul3(mul2, self.mul3_weight)
        return out
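
# Drives the comparison: trains the same net once in plain AUTO_PARALLEL mode
# and once with enable_parallel_optimizer=True, saving a checkpoint for each
# run. It also owns HCCL init/release for the process.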
class OptimizerSemiAutoAndAutoParallelFactory:
    def __init__(self, net, strategy_dict=None):
        self.parallel_ckpt = None
        self.optimizer_parallel_ckpt = None
        self.net = net
        self.strategy_dict = strategy_dict
        self.global_rank_id = None
        self._set_parallel_env()
        self._init_parallel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __del__(self):
        self._release_parallel()

    def _set_parallel_env(self):
        if 'RANK_ID' in os.environ:
            self.global_rank_id = int(os.environ['RANK_ID'])

    def _init_parallel(self):
        self._init_parallel_flag = False
        init(backend_name='hccl')
        self._init_parallel_flag = True

    def _release_parallel(self):
        if self._init_parallel_flag:
            release()

    def _model_train_and_save_ckpt(self, net, dataset, epoch):
        # Train, keep only the newest checkpoint, and return its parameters.
        self.opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
        self.loss_fn = SoftmaxCrossEntropyWithLogits(reduction='mean')
        self.model = Model(network=net,
                           loss_fn=self.loss_fn,
                           optimizer=self.opt)
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_path = './rank_{}_ckpt'.format(self.global_rank_id)
        ckpt_callback = ModelCheckpoint(prefix='parallel', directory=ckpt_path,
                                        config=ckpt_config)
        clean_all_ckpt_files(ckpt_path)
        self.model.train(epoch=epoch,
                         train_dataset=dataset,
                         callbacks=[ckpt_callback],
                         dataset_sink_mode=False)
        newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
        return load_checkpoint(newest_ckpt_file)

    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
        set_algo_parameters(fully_use_devices=False)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num)
        parallel_mode_net = self.net(self.strategy_dict)
        self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net,
                                                             dataset=dataset, epoch=epoch)
        context.reset_auto_parallel_context()

    def mindspore_optimizer_auto_parallel_impl(self, dataset, epoch, device_num):
        set_algo_parameters(fully_use_devices=False)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num,
                                          enable_parallel_optimizer=True)
        parallel_mode_net = self.net(self.strategy_dict)
        self.optimizer_parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net,
                                                                       dataset=dataset, epoch=epoch)
        context.reset_auto_parallel_context()

    def checkpoint_cmp(self, inputs_np):
        # Load each run's checkpoint into a fresh net and compare forward outputs.
        optimizer_parallel_net = self.net(self.strategy_dict)
        load_param_into_net(optimizer_parallel_net, self.optimizer_parallel_ckpt)
        optimizer_parallel_out = optimizer_parallel_net(Tensor(inputs_np))
        parallel_net = self.net(self.strategy_dict)
        load_param_into_net(parallel_net, self.parallel_ckpt)
        parallel_out = parallel_net(Tensor(inputs_np))
        allclose_nparray(optimizer_parallel_out.asnumpy(), parallel_out.asnumpy(), 0.001, 0.001)
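
# 4-device case. A shard strategy tuple lists, for each input of an op, how
# many slices to cut along every dimension; here each op splits exactly one
# axis two ways, and fully_use_devices=False permits strategies that occupy
# fewer than all 4 devices.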
def test_optimizer_parallel_auto_4p_6_parameter_same_strategy_1_1_2_1_momentum():
    inputs_np = np.random.randn(16, 1, 32, 32).astype(np.float32)
    dataset = FakeData(size=32,
                       batch_size=4,
                       image_size=(1, 32, 32),
                       use_parallel=True,
                       num_classes=116)
    strategy_dict = {'add1': ((1, 1, 2, 1), (1, 1, 2, 1)),
                     'mul1': ((1, 1, 2, 1), (1, 1, 2, 1)),
                     'fc1_matmul': ((1, 2), (1, 2)),
                     'fc1_bias_add': ((1, 2), (2,)),
                     'mul2': ((1, 2), (1, 2)),
                     'mul3': ((1, 2), (1, 2))}
    fact = OptimizerSemiAutoAndAutoParallelFactory(net=OptimizerSemiAutoAndAutoParallel6Net,
                                                   strategy_dict=strategy_dict)
    fact.mindspore_auto_parallel_impl(dataset=dataset, epoch=2, device_num=4)
    fact.mindspore_optimizer_auto_parallel_impl(dataset=dataset, epoch=2, device_num=4)
    fact.checkpoint_cmp(inputs_np=inputs_np)
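
# Note: this is a distributed Ascend test. It is expected to be launched with
# one process per rank (RANK_ID/DEVICE_ID set and an HCCL rank table
# configured by the surrounding test harness); run standalone, the
# init(backend_name='hccl') calls will fail.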