You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

parallel_strategy_search.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import numpy as np
  17. from mindspore.communication.management import init
  18. from mindspore.communication.management import release
  19. from mindspore.communication.management import get_rank
  20. from mindspore.communication.management import get_group_size
  21. from mindspore.nn import Cell
  22. from mindspore.nn import Conv2d
  23. from mindspore.nn import ReLU
  24. from mindspore.nn import Dense
  25. from mindspore.nn import Softmax
  26. import mindspore.ops.operations as P
  27. from mindspore.train.serialization import load_param_into_net
  28. from mindspore.train.callback import CheckpointConfig
  29. from mindspore.train.callback import ModelCheckpoint
  30. from mindspore.train.serialization import load_checkpoint
  31. from mindspore.nn import Momentum
  32. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  33. from mindspore.train import Model
  34. from mindspore.parallel import set_algo_parameters
  35. from mindspore.common.initializer import initializer
  36. from mindspore.common import dtype as mstype
  37. from mindspore import Tensor
  38. from mindspore.common.parameter import Parameter
  39. from mindspore import context
  40. from mindspore.context import ParallelMode
# Compile in graph mode and target Ascend devices for the whole test module.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
  42. def _count_unequal_element(data_expected, data_me, rtol, atol):
  43. assert data_expected.shape == data_me.shape
  44. total_count = len(data_expected.flatten())
  45. error = np.abs(data_expected - data_me)
  46. greater = np.greater(error, atol + np.abs(data_me) * rtol)
  47. loss_count = np.count_nonzero(greater)
  48. assert (loss_count / total_count) < rtol, \
  49. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
  50. format(data_expected[greater], data_me[greater], error[greater])
  51. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  52. if np.any(np.isnan(data_expected)):
  53. assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
  54. elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
  55. _count_unequal_element(data_expected, data_me, rtol, atol)
  56. else:
  57. assert True
  58. def clean_all_ckpt_files(folder_path):
  59. if os.path.exists(folder_path):
  60. for file_name in os.listdir(folder_path):
  61. if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
  62. os.remove(os.path.join(folder_path, file_name))
  63. def find_newest_ckpt_file(folder_path):
  64. ckpt_files = map(lambda f: os.path.join(folder_path, f),
  65. filter(lambda f: f.endswith('.ckpt'),
  66. os.listdir(folder_path)))
  67. return max(ckpt_files, key=os.path.getctime)
class FakeDataInitMode:
    """Enumeration of strategies FakeData uses to synthesize image batches."""
    RandomInit = 0  # seeded np.random.randn per batch
    OnesInit = 1    # all-ones images
    UniqueInit = 2  # each element gets a distinct value (index * 1e-4)
    ZerosInit = 3   # all-zeros images
class FakeData:
    """In-memory stand-in for a MindSpore dataset that synthesizes batches.

    Implements the subset of the dataset interface that ``Model.train`` uses
    with ``dataset_sink_mode=False``: iteration, indexing,
    ``get_dataset_size``, ``get_repeat_count`` and ``create_tuple_iterator``.
    When ``use_parallel`` is True the HCCL collective is initialized and each
    rank slices its own shard out of every deterministically generated
    global batch, so all ranks see consistent data.
    """

    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode
        if use_parallel is True:
            # Join the HCCL collective to learn this process's rank and the
            # total device count for data-parallel sharding.
            init(backend_name='hccl')
            self.rank_size = get_group_size()
            self.rank_id = get_rank()
        self.total_batch_size = self.rank_batch_size * self.rank_size
        # The dataset must split evenly into global batches.
        assert (self.size % self.total_batch_size) == 0
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        # Number of global batches per epoch.
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        # Model.train requests a tuple iterator; this object is already
        # iterable, so return itself (num_epochs/do_copy are ignored).
        _ = num_epochs
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        # Seed deterministically per batch so every rank generates the same
        # global batch; restore the RNG state afterwards so callers' RNG use
        # is unaffected.
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size = total_size * i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        # Keep only this rank's shard of the global batch.
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
class ParallelStrategySearchNet(Cell):
    """Mixed-operator network used to exercise auto-parallel strategy search.

    Deliberately chains convolution, elementwise ops, reductions, matmul,
    strided slicing and a CPU-pinned ReLU so the strategy-search pass has a
    wide variety of operators and layouts to propagate through.
    """

    def __init__(self, in_channel, out_channel, axis, input_shape, mul_size,
                 test_size, prelu_size, transpose_b, matmul_size, num_class):
        super().__init__()
        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
        bias_np = np.full((12,), 7.1, dtype=np.float32)
        self.bias = Parameter(Tensor(bias_np), name="bias")
        prelu_np = np.full(prelu_size, 0.8, dtype=np.float32)
        self.prelu_weight = Parameter(Tensor(prelu_np), name="prelu_weight")
        matmul_np = np.full(matmul_size, 1.1, dtype=np.float32)
        self.matmul_weight = Parameter(Tensor(matmul_np), name="matmul_weight")
        self.mul = P.Mul()
        self.conv = Conv2d(in_channels=in_channel, out_channels=out_channel,
                           kernel_size=5, has_bias=True,
                           weight_init='ones', bias_init='ones',
                           pad_mode='valid')
        self.scalar = 0.5
        self.parameter = Parameter(
            initializer(0.5, test_size, dtype=mstype.float32),
            name='parameter')
        self.tensor = Tensor(np.full(test_size, 0.05, dtype=np.float32))
        self.softmax = Softmax(axis=axis)
        self.relu = ReLU()
        # Pin this ReLU to CPU so strategy search must cope with a
        # heterogeneous (non-Ascend) operator in the graph.
        self.relu.relu.add_prim_attr("primitive_target", "CPU")
        self.reshape = P.Reshape()
        self.input_shape = input_shape
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=1)
        self.reduce_sum = P.ReduceSum()
        self.bias_add = P.BiasAdd()
        self.cos = P.Cos()
        self.prelu = P.PReLU()
        self.matmul = P.MatMul(transpose_b=transpose_b)
        self.l2norm = P.L2Normalize(axis=(1 - axis))
        self.tensoradd = P.TensorAdd()
        self.strided_slice = P.StridedSlice()
        self.dense = Dense(in_channels=6,
                           out_channels=num_class,
                           weight_init='ones',
                           bias_init='ones',
                           has_bias=True)

    def construct(self, inputs):
        x = self.conv(inputs)
        x = self.softmax(x)
        x = self.relu(x)
        x = self.mul(x, self.mul_weight)
        x = self.reshape(x, self.input_shape)
        # Two side branches (y, z) derived from the learnable parameter,
        # concatenated back onto the main path along axis 1.
        y = self.parameter * self.tensor * self.scalar
        z = self.equal(self.parameter, self.scalar)
        # Round-trip cast exercises mixed-precision layout propagation.
        z = self.cast(z, mstype.float16)
        z = self.cast(z, mstype.float32)
        x = self.concat((x, y, z))
        x = self.reduce_sum(x, (2, 3))
        x = self.bias_add(x, self.bias)
        y = self.cos(x)
        y = self.prelu(y, self.prelu_weight)
        z = self.matmul(x, self.matmul_weight)
        z = self.l2norm(z)
        x = self.tensoradd(y, z)
        x = self.strided_slice(x, (0, 0), (32, 6), (1, 1))
        x = self.dense(x)
        return x
class ParallelStrategySearchFactory:
    """Trains the same architecture standalone and under auto-parallel, then
    compares the resulting checkpoints numerically.

    Construction reads RANK_ID from the environment and initializes HCCL;
    the collective is released when the factory is garbage-collected.
    """

    def __init__(self, standalone_mode_net, parallel_mode_net):
        self.standalone_mode_net = standalone_mode_net
        self.parallel_mode_net = parallel_mode_net
        self.parallel_ckpt = None
        self.standalone_ckpt = None
        self.global_rank_id = None
        self._set_parallel_env()
        self._init_parallel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return

    def __del__(self):
        # Ensure HCCL resources are freed even without a `with` block.
        self._release_parallel()

    def _set_parallel_env(self):
        # RANK_ID is set by the distributed launcher; used to give each rank
        # its own checkpoint directory.
        if 'RANK_ID' in os.environ:
            self.global_rank_id = int(os.environ['RANK_ID'])

    def _init_parallel(self):
        # Set the flag before init() so _release_parallel only calls
        # release() if init actually succeeded.
        self._init_parallel_flag = False
        init(backend_name='hccl')
        self._init_parallel_flag = True

    def _release_parallel(self):
        if self._init_parallel_flag:
            release()

    def _model_train_and_save_ckpt(self, net, dataset, epoch):
        """Train *net* on *dataset*, checkpoint each epoch, and return the
        parameter dict loaded from the newest checkpoint file."""
        self.opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
        self.loss_fn = SoftmaxCrossEntropyWithLogits(reduction='mean')
        self.model = Model(network=net,
                           loss_fn=self.loss_fn,
                           optimizer=self.opt)
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_path = './rank_{}_ckpt'.format(self.global_rank_id)
        ckpt_callback = ModelCheckpoint(prefix='parallel', directory=ckpt_path,
                                        config=ckpt_config)
        # Start from a clean directory so find_newest_ckpt_file cannot pick
        # up a stale checkpoint from a previous run.
        clean_all_ckpt_files(ckpt_path)
        self.model.train(epoch=epoch,
                         train_dataset=dataset,
                         callbacks=[ckpt_callback],
                         dataset_sink_mode=False)
        newest_ckpt_file = find_newest_ckpt_file(ckpt_path)
        return load_checkpoint(newest_ckpt_file)

    def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, auto_parallel_search_mode="dynamic_programming"):
        """Train the parallel net under AUTO_PARALLEL with the given strategy
        search mode and store its checkpoint in self.parallel_ckpt."""
        parallel_mode_net = self.parallel_mode_net
        set_algo_parameters(fully_use_devices=False)
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                          device_num=device_num,
                                          auto_parallel_search_mode=auto_parallel_search_mode)
        self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net,
                                                             dataset=dataset, epoch=epoch)
        context.reset_auto_parallel_context()

    def mindspore_standalone_impl(self, dataset, epoch):
        """Train the standalone net in STAND_ALONE mode and store its
        checkpoint in self.standalone_ckpt."""
        standalone_mode_net = self.standalone_mode_net
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE)
        self.standalone_ckpt = self._model_train_and_save_ckpt(net=standalone_mode_net,
                                                               dataset=dataset, epoch=epoch)
        context.reset_auto_parallel_context()

    def checkpoint_cmp(self, inputs_np):
        """Load both checkpoints and assert the two inferences are allclose.

        NOTE(review): the parallel checkpoint is deliberately(?) loaded into
        the standalone-architecture net (self.standalone_mode_net) so both
        forward passes run in the same standalone graph — confirm this reuse
        is intended rather than a copy-paste of the standalone net.
        """
        standalone_net = self.standalone_mode_net
        load_param_into_net(standalone_net, self.standalone_ckpt)
        standalone_out = standalone_net(Tensor(inputs_np))
        parallel_net = self.standalone_mode_net
        load_param_into_net(parallel_net, self.parallel_ckpt)
        parallel_out = parallel_net(Tensor(inputs_np))
        allclose_nparray(standalone_out.asnumpy(), parallel_out.asnumpy(),
                         0.001, 0.001)
def test_auto_parallel_strategy_search_axis_1_basic():
    """End-to-end check of dynamic-programming strategy search (axis=1).

    Trains identical nets standalone and under AUTO_PARALLEL on 8 devices
    (with two manually sharded ops) and asserts the resulting forward
    outputs agree within tolerance.  Requires an Ascend HCCL environment.
    """
    inputs_np = np.random.randn(32, 3, 224, 224).astype(np.float32)
    standalone_mode_net = ParallelStrategySearchNet(in_channel=3,
                                                    out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
                                                    mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880),
                                                    prelu_size=(1,), transpose_b=True, matmul_size=(1, 12),
                                                    num_class=12)
    # Switch to AUTO_PARALLEL before constructing the parallel net so its
    # manual shard() calls below are honored.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)
    parallel_mode_net = ParallelStrategySearchNet(in_channel=3,
                                                  out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
                                                  mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880),
                                                  prelu_size=(1,), transpose_b=True, matmul_size=(1, 12),
                                                  num_class=12)
    # Pin strategies for two ops; the search must fill in the rest.
    parallel_mode_net.cos.shard(((2, 4),))
    parallel_mode_net.matmul.shard(((1, 2), (1, 2)))
    standalone_dataset = FakeData(size=128, batch_size=32,
                                  image_size=(3, 224, 224), num_classes=12)
    fact = ParallelStrategySearchFactory(standalone_mode_net=standalone_mode_net,
                                         parallel_mode_net=parallel_mode_net)
    fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2)
    # Per-rank batch of 4 across 8 devices matches the standalone global
    # batch of 32.
    parallel_dataset = FakeData(size=128, batch_size=4,
                                image_size=(3, 224, 224), use_parallel=True,
                                num_classes=12)
    fact.mindspore_auto_parallel_impl(dataset=parallel_dataset,
                                      epoch=2, device_num=8)
    fact.checkpoint_cmp(inputs_np=inputs_np)
def test_auto_parallel_recursive_strategy_search_axis_1_basic():
    """End-to-end check of recursive-programming strategy search (axis=1).

    Same standalone-vs-parallel comparison as the dynamic-programming test,
    but with no manual shard() hints and
    auto_parallel_search_mode="recursive_programming".  Requires an Ascend
    HCCL environment with 8 devices.
    """
    inputs_np = np.random.randn(32, 3, 224, 224).astype(np.float32)
    standalone_mode_net = ParallelStrategySearchNet(in_channel=3,
                                                    out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
                                                    mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880),
                                                    prelu_size=(1,), transpose_b=True, matmul_size=(1, 12),
                                                    num_class=12)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)
    parallel_mode_net = ParallelStrategySearchNet(in_channel=3,
                                                  out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
                                                  mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880),
                                                  prelu_size=(1,), transpose_b=True, matmul_size=(1, 12),
                                                  num_class=12)
    standalone_dataset = FakeData(size=128, batch_size=32,
                                  image_size=(3, 224, 224), num_classes=12)
    fact = ParallelStrategySearchFactory(standalone_mode_net=standalone_mode_net,
                                         parallel_mode_net=parallel_mode_net)
    fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2)
    # Per-rank batch of 4 across 8 devices matches the standalone global
    # batch of 32.
    parallel_dataset = FakeData(size=128, batch_size=4,
                                image_size=(3, 224, 224), use_parallel=True,
                                num_classes=12)
    fact.mindspore_auto_parallel_impl(dataset=parallel_dataset,
                                      epoch=2, device_num=8, auto_parallel_search_mode="recursive_programming")
    fact.checkpoint_cmp(inputs_np=inputs_np)