# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test global norm test """
import re
import os
import shutil
import glob
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Tensor, Parameter, Model
from mindspore.train import DynamicLossScaleManager
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell, MicroBatchInterleaved, PipelineCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context


class OneParameterNet(nn.Cell):
    """Net definition

    Minimal network holding exactly one trainable parameter, used to check
    that clip-by-global-norm needs only a single AllReduce in that case.
    """

    def __init__(self, param_type, strategy1, strategy2):
        super(OneParameterNet, self).__init__()
        # strategy1 shards the MatMul; strategy2 shards the trailing Sub.
        self.fc1 = P.MatMul().shard(strategy1)
        self.p1 = Parameter(Tensor(np.ones([48, 16]).astype(param_type)), name="weight1")
        self.sub = P.Sub().shard(strategy2)

    def construct(self, x, y):
        # NOTE(review): `y` is unused here; the signature keeps two inputs so the
        # cell matches the two-column ("inputs", "label") dataset fed by Model.train.
        x = P.Cast()(x, ms.float16)
        p1 = P.Cast()(self.p1, ms.float16)
        x = self.fc1(x, p1)
        # Subtract scalar 0 — presumably just to exercise the sharded Sub op.
        return self.sub(x, 0)


class Net(nn.Cell):
    """Net definition

    Two-layer MatMul network; the second weight opts out of the parallel
    optimizer (parallel_optimizer=False) so both code paths are covered.
    """

    def __init__(self, param_type, strategy1, strategy2):
        super(Net, self).__init__()
        self.fc1 = P.MatMul().shard(strategy1)
        self.fc2 = P.MatMul().shard(strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(param_type)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(param_type)),
                            name="weight2", parallel_optimizer=False)
        self.sub = P.Sub()

    def construct(self, x, y):
        # Compute in float16 regardless of the stored parameter dtype.
        x = P.Cast()(x, ms.float16)
        p1 = P.Cast()(self.p1, ms.float16)
        p2 = P.Cast()(self.p2, ms.float16)
        x = self.fc1(x, p1)
        x = self.fc2(x, p2)
        return self.sub(x, y)


class Net2(nn.Cell):
    """Net definition

    Two copies of Net assigned to pipeline stages 0 and 1, for the
    pipeline-parallel global-norm tests.
    """

    def __init__(self, param_type, strategy1, strategy2):
        super(Net2, self).__init__()
        self.net1 = Net(param_type, strategy1, strategy2)
        self.net2 = Net(param_type, strategy1, strategy2)
        # Place the two sub-nets on different pipeline stages.
        self.net1.pipeline_stage = 0
        self.net2.pipeline_stage = 1
        self.sub = P.Sub()

    def construct(self, x, y):
        out1 = self.net1(x, y)
        out2 = self.net2(x, y)
        return self.sub(out1, out2)


def get_dataset():
    """Build a tiny 10-step generator dataset of constant inputs/labels.

    Shapes match the nets above: inputs [64, 48], label [64, 16].
    """
    inputs = np.ones([64, 48]).astype(np.float32)
    label = np.zeros([64, 16]).astype(np.float32)

    def dataset_generator():
        for _ in range(10):
            yield inputs, label

    dataset = ds.GeneratorDataset(dataset_generator, column_names=["inputs", "label"])
    return dataset


class CustomOptimizer(AdamWeightDecay):
    """AdamWeightDecay that clips gradients by global norm before applying them.

    The parent's construct is captured in __init__ so this subclass can wrap
    it with C.clip_by_global_norm (default clip norm) in its own construct.
    """

    def __init__(self, params):
        super(CustomOptimizer, self).__init__(params)
        # Bind the parent's update step so construct() can delegate after clipping.
        self.optimizer = super(CustomOptimizer, self).construct

    def construct(self, gradients):
        # Insert the global-norm clip whose parallel expansion the tests inspect.
        grads = C.clip_by_global_norm(gradients)
        return self.optimizer(grads)


def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None, interleaved_batch=2,
                              stages=1, micro_size=1, param_type=np.float32, loss_scale_manager=None):
    """Compile and run one training step of `net` under auto-parallel.

    :param mode: parallel mode string, e.g. "semi_auto_parallel".
    :param dev_num: total logical device count for the parallel context.
    :param net: network class (not instance); called as net(param_type, strategy1, strategy2).
    :param strategy1/strategy2: shard strategies forwarded to the network.
    :param interleaved_batch: micro-batch interleave factor.
    :param stages: pipeline stage count; >1 wraps the net in PipelineCell.
    :param micro_size: pipeline micro-batch count (used only when stages > 1).
    :param param_type: numpy dtype for the network parameters.
    :param loss_scale_manager: optional loss-scale manager passed to Model.
    """
    context.set_context(mode=context.GRAPH_MODE)
    # Threshold 1 forces the parallel optimizer to shard even small parameters.
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
                                      parallel_optimizer_config={"parallel_optimizer_threshold": 1},
                                      pipeline_stages=stages)
    net = MicroBatchInterleaved(net(param_type, strategy1, strategy2), interleaved_batch)
    if stages > 1:
        net = PipelineCell(net, micro_size=micro_size)
    net = _VirtualDatasetCell(net).set_comm_fusion(4)
    # With pipelining, only the parameters of this rank's stage are optimized.
    parameters = net.trainable_params() if stages == 1 else net.infer_param_pipeline_stage()
    optimizer = CustomOptimizer(parameters)
    if loss_scale_manager:
        model = Model(net, optimizer=optimizer, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net, optimizer=optimizer)
    dataset = get_dataset()
    # One epoch is enough: the tests only inspect the saved compile-time IR.
    model.train(1, dataset)


class TestGlobalNormInserted:
    """Checks the communication ops inserted for clip-by-global-norm by
    counting PARALLEL_GLOBALNORM* markers in the saved step_parallel_end IR."""

    def setup_method(self):
        # Unique per-test-instance graph dump directory.
        self.output_path = './graphs' + self.__str__()
        context.set_context(save_graphs=True, save_graphs_path=self.output_path)

    def teardown_method(self):
        # Remove the IR dump so runs don't pollute each other.
        shutil.rmtree(self.output_path)

    def run_count_check(self, target_count, pattern):
        """
        This function will check the target_key counts with the golden one.
        :param target_count: The expected number of IR lines matching `pattern`.
        :param pattern: The regex searched for in the step_parallel_end IR file.
        """
        # Find the step_parallel_end ir
        files = glob.glob(os.path.join(self.output_path, 'rank_0', 'step_parallel_end*.ir'))
        assert len(files) == 1
        appear_count = 0
        with open(files[0], 'r') as fp:
            # Count lines (not total matches) containing the pattern.
            for line in fp:
                res = re.findall(pattern, line)
                if len(res) >= 1:
                    appear_count += 1
        assert appear_count == target_count

    def test_nonpipeline_global_norm_one_parameter(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm using one parameter, there should be only one allreduce
        Expectation: When there is no PARALLEL_GLOBALNORM_IN_STAGES inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, OneParameterNet, ((1, 8), (8, 1)), ((8, 1), ()),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"PARALLEL_GLOBALNORM_IN_STAGES")

    def test_nonpipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in semi auto parallel mode, scale for data parallel
                     should be 8
        Expectation: When there is no real div inserted or AllReduce inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=8.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=2, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 8
        Expectation: When there is no real div inserted or AllReduce inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 8
        Expectation: When there is no real div inserted or AllReduce inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 8
        Expectation: When there is no real div inserted or AllReduce inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 8
        Expectation: When there is no real div inserted or AllReduce inserted
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")