# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test global norm test """
import re
import os
import shutil
import glob
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Tensor, Parameter, Model
from mindspore.train import DynamicLossScaleManager
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell, MicroBatchInterleaved, PipelineCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context


class OneParameterNet(nn.Cell):
    """Net definition with a single trainable parameter."""
    def __init__(self, param_type, strategy1, strategy2):
        super(OneParameterNet, self).__init__()
        self.fc1 = P.MatMul().shard(strategy1)
        self.p1 = Parameter(Tensor(np.ones([48, 16]).astype(param_type)), name="weight1")
        self.sub = P.Sub().shard(strategy2)

    def construct(self, x, y):
        x = P.Cast()(x, ms.float16)
        p1 = P.Cast()(self.p1, ms.float16)
        x = self.fc1(x, p1)
        return self.sub(x, 0)


class Net(nn.Cell):
    """Net definition with two parameters."""
    def __init__(self, param_type, strategy1, strategy2):
        super(Net, self).__init__()
        self.fc1 = P.MatMul().shard(strategy1)
        self.fc2 = P.MatMul().shard(strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(param_type)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(param_type)), name="weight2", parallel_optimizer=False)
        self.sub = P.Sub()

    def construct(self, x, y):
        x = P.Cast()(x, ms.float16)
        p1 = P.Cast()(self.p1, ms.float16)
        p2 = P.Cast()(self.p2, ms.float16)
        x = self.fc1(x, p1)
        x = self.fc2(x, p2)
        return self.sub(x, y)


class Net2(nn.Cell):
    """Two-stage net definition; net1 and net2 are placed on pipeline stages 0 and 1."""
    def __init__(self, param_type, strategy1, strategy2):
        super(Net2, self).__init__()
        self.net1 = Net(param_type, strategy1, strategy2)
        self.net2 = Net(param_type, strategy1, strategy2)
        self.net1.pipeline_stage = 0
        self.net2.pipeline_stage = 1
        self.sub = P.Sub()

    def construct(self, x, y):
        out1 = self.net1(x, y)
        out2 = self.net2(x, y)
        return self.sub(out1, out2)


def get_dataset():
    inputs = np.ones([64, 48]).astype(np.float32)
    label = np.zeros([64, 16]).astype(np.float32)

    def dataset_generator():
        for _ in range(10):
            yield inputs, label

    dataset = ds.GeneratorDataset(dataset_generator, column_names=["inputs", "label"])
    return dataset


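# The generator above yields ten identical (64, 48) / (64, 16) batches that
# match the MatMul shapes of the nets; a single epoch is enough because these
# tests only check the compiled graph, not training quality.

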
class CustomOptimizer(AdamWeightDecay):
    """AdamWeightDecay that clips gradients by global norm before applying them."""
    def __init__(self, params):
        super(CustomOptimizer, self).__init__(params)
        # Keep a bound reference to the parent update step.
        self.optimizer = super(CustomOptimizer, self).construct

    def construct(self, gradients):
        grads = C.clip_by_global_norm(gradients)
        return self.optimizer(grads)


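# For reference, global-norm clipping computes
#     global_norm = sqrt(sum_i ||g_i||^2)
#     g_i <- g_i * clip_norm / max(global_norm, clip_norm)
# with clip_norm defaulting to 1.0. A minimal NumPy sketch of the same math
# (illustration only; the tests themselves rely on the MindSpore op above):
def _numpy_clip_by_global_norm(grads, clip_norm=1.0):
    """Hypothetical reference helper documenting the clipping math."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

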
def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None,
                              interleaved_batch=2, stages=1, micro_size=1, param_type=np.float32,
                              loss_scale_manager=None):
    """Compile and run one training step under the given parallel configuration."""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1},
                                      pipeline_stages=stages)
    net = MicroBatchInterleaved(net(param_type, strategy1, strategy2), interleaved_batch)
    if stages > 1:
        net = PipelineCell(net, micro_size=micro_size)
    net = _VirtualDatasetCell(net).set_comm_fusion(4)
    parameters = net.trainable_params() if stages == 1 else net.infer_param_pipeline_stage()
    optimizer = CustomOptimizer(parameters)
    if loss_scale_manager:
        model = Model(net, optimizer=optimizer, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net, optimizer=optimizer)
    dataset = get_dataset()
    model.train(1, dataset)


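# Example (illustrative only; it mirrors the first test case below):
#     auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet,
#                               ((1, 8), (8, 1)), ((8, 1), ()),
#                               interleaved_batch=1)
# compiles a single training step on 8 devices and, because save_graphs is
# enabled in setup_method, dumps the step_parallel_end IR that
# run_count_check scans afterwards.

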
class TestGlobalNormInserted:
    def setup_method(self):
        self.output_path = './graphs' + self.__str__()
        context.set_context(save_graphs=True,
                            save_graphs_path=self.output_path)

    def teardown_method(self):
        shutil.rmtree(self.output_path)

    def run_count_check(self, target_count, pattern):
        """
        Check how many lines of the generated IR file match the pattern.
        :param target_count: The expected number of matching lines in the IR file.
        :param pattern: The regular expression searched for in the IR file.
        """
        # Find the step_parallel_end IR file dumped for rank 0.
        ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', 'step_parallel_end*.ir'))
        assert len(ir_files) == 1
        appear_count = 0
        with open(ir_files[0], 'r') as fp:
            for line in fp:
                res = re.findall(pattern, line)
                if len(res) >= 1:
                    appear_count += 1
        assert appear_count == target_count

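    # Note on the patterns below (an assumption about the dumped IR text):
    # "PARALLEL_GLOBALNORM" matches the communication nodes inserted for the
    # global-norm computation, while a prefix such as "=8" or "=16" is meant
    # to match the constant divisor of the inserted PARALLEL_GLOBALNORM_DIV.
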
    def test_nonpipeline_global_norm_one_parameter(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm with a single parameter; only one AllReduce should be needed.
        Expectation: Exactly one PARALLEL_GLOBALNORM_IN_STAGES node is inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet, ((1, 8), (8, 1)), ((8, 1), ()),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"PARALLEL_GLOBALNORM_IN_STAGES")

    def test_nonpipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm in semi auto parallel mode; the data-parallel scale should be 8.
        Expectation: One RealDiv scaling by 8 and two PARALLEL_GLOBALNORM nodes are inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=8.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=2, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm in pipeline mode (2 stages, micro size 2) on 32 devices.
        Expectation: One RealDiv scaling by 16 and three PARALLEL_GLOBALNORM nodes are inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Same pipeline configuration as above, with a DynamicLossScaleManager.
        Expectation: One RealDiv scaling by 16 and three PARALLEL_GLOBALNORM nodes are inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Same pipeline configuration as above, with float16 parameters.
        Expectation: One RealDiv scaling by 16 and three PARALLEL_GLOBALNORM nodes are inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Same pipeline configuration as above, with float16 parameters and a DynamicLossScaleManager.
        Expectation: One RealDiv scaling by 16 and three PARALLEL_GLOBALNORM nodes are inserted.
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")
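

# These cases are compile-only graph checks intended to be collected by pytest,
# e.g. `pytest test_global_norm.py` (assumed invocation; running them requires
# a MindSpore build that supports the parallel features exercised above).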