PLUG finetune: 已在 DuReader-robust 数据集上回归至最佳结果
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10916382
master^2
| @@ -338,6 +338,7 @@ class Trainers(object): | |||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | ||||
| text_generation_trainer = 'text-generation-trainer' | text_generation_trainer = 'text-generation-trainer' | ||||
| nlp_plug_trainer = 'nlp-plug-trainer' | |||||
| # audio trainers | # audio trainers | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| @@ -500,6 +501,9 @@ class Hooks(object): | |||||
| # CLIP logit_scale clamp | # CLIP logit_scale clamp | ||||
| ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' | ClipClampLogitScaleHook = 'ClipClampLogitScaleHook' | ||||
| # train | |||||
| DeepspeedHook = 'DeepspeedHook' | |||||
| class LR_Schedulers(object): | class LR_Schedulers(object): | ||||
| """learning rate scheduler is defined here | """learning rate scheduler is defined here | ||||
| @@ -0,0 +1,88 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Learning-rate scheduler with linear warmup followed by linear or cosine annealing.""" | |||||
| import math | |||||
| import torch | |||||
| from torch.optim.lr_scheduler import _LRScheduler | |||||
class AnnealingLR(_LRScheduler):
    """Learning-rate schedule: linear warmup, then an optional decay.

    During the first ``warmup_iter`` steps the rate grows linearly from 0 to
    ``start_lr``.  Afterwards it follows ``decay_style``:

    * ``'linear'`` -- straight line from ``start_lr`` down to 0 over
      ``num_iters`` post-warmup steps;
    * ``'cosine'`` -- half cosine from ``start_lr`` down to 0 over
      ``num_iters`` post-warmup steps;
    * anything else (including ``'exponential'``, which is not implemented,
      ``'constant'`` and ``None``) -- the rate stays at ``start_lr``.

    Once the decay horizon is exhausted the rate is held at its final value
    instead of turning negative (linear) or rising again (cosine).

    Args:
        optimizer: Optimizer whose ``param_groups`` receive the rate.
        start_lr: Peak learning rate reached at the end of warmup.
        warmup_iter: Number of warmup steps; ``0`` disables warmup.
        num_iters: Length of the decay phase, in steps.
        decay_style: Name of the decay curve (case-insensitive) or ``None``.
        last_iter: Index of the last completed step; ``-1`` starts fresh.
    """

    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']

    def __init__(self,
                 optimizer,
                 start_lr,
                 warmup_iter,
                 num_iters,
                 decay_style=None,
                 last_iter=-1):
        # The base-class constructor is not invoked; this scheduler keeps
        # its own minimal state and overrides ``step``/``state_dict``.
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.warmup_iter = warmup_iter
        self._step_count = last_iter + 1
        self.end_iter = num_iters
        self.decay_style = decay_style.lower() if isinstance(
            decay_style, str) else None
        # Push the initial rate into the optimizer right away.
        self.step(self._step_count)
        if self._is_main_process():
            print('learning rate decaying', decay_style)

    @staticmethod
    def _is_main_process():
        """Return True on rank 0, or when no process group is initialised."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True

    def get_lr(self):
        """Return the learning rate for the current ``_step_count``."""
        # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
        if self.warmup_iter > 0 and self._step_count <= self.warmup_iter:
            return float(self.start_lr) * self._step_count / self.warmup_iter

        # Clamp so that stepping past the horizon holds the final value.
        decay_steps = min(self._step_count - self.warmup_iter, self.end_iter)
        if self.decay_style == self.DECAY_STYLES[0]:
            return self.start_lr * (
                (self.end_iter - decay_steps) / self.end_iter)
        if self.decay_style == self.DECAY_STYLES[1]:
            return self.start_lr / 2.0 * (
                math.cos(math.pi * decay_steps / self.end_iter) + 1)
        # TODO: implement exponential decay; until then it behaves like
        # 'constant' and returns the peak rate.
        return self.start_lr

    def step(self, step_num=None):
        """Advance to ``step_num`` (default: next step) and apply the rate."""
        if step_num is None:
            step_num = self._step_count + 1
        self._step_count = step_num
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def state_dict(self):
        """Return the scheduler hyper-parameters and step counter."""
        return {
            'start_lr': self.start_lr,
            'warmup_iter': self.warmup_iter,
            '_step_count': self._step_count,
            'decay_style': self.decay_style,
            'end_iter': self.end_iter
        }

    def load_state_dict(self, sd):
        """Restore state produced by ``state_dict`` and re-apply the rate."""
        self.start_lr = sd['start_lr']
        self.warmup_iter = sd['warmup_iter']
        self._step_count = sd['_step_count']
        self.end_iter = sd['end_iter']
        self.decay_style = sd['decay_style']
        self.step(self._step_count)
| @@ -1009,6 +1009,118 @@ class PlugModel(torch.nn.Module): | |||||
| sequence_output=sequence_output, | sequence_output=sequence_output, | ||||
| parallel_output=parallel_output) | parallel_output=parallel_output) | ||||
| @staticmethod | |||||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||||
| # This function has been mostly taken from huggingface conversational ai code at | |||||
| # https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||||
| # conversational-ai-with-transfer-learning-2d818ac26313 | |||||
| if top_k > 0: | |||||
| # Remove all tokens with a probability less than the last token of the top-k | |||||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||||
| None] | |||||
| logits[indices_to_remove] = filter_value | |||||
| if top_p > 0.0: | |||||
| # convert to 1D | |||||
| logits = logits.view(logits.size()[1]).contiguous() | |||||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||||
| cumulative_probs = torch.cumsum( | |||||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
| # Remove tokens with cumulative probability above the threshold | |||||
| sorted_indices_to_remove = cumulative_probs > top_p | |||||
| # Shift the indices to the right to keep also the first token above the threshold | |||||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||||
| ..., :-1].clone() | |||||
| sorted_indices_to_remove[..., 0] = 0 | |||||
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||||
| logits[indices_to_remove] = filter_value | |||||
| # going back to 2D | |||||
| logits = logits.view(1, -1).contiguous() | |||||
| return logits | |||||
def generate(self, input, out_length=128, model_cfg=None, *kwargs):
    """Greedily decode up to ``out_length`` tokens on the current CUDA device.

    Args:
        input: Mapping with tensors under ``'input_ids'``,
            ``'dec_input_ids'`` and ``'attention_mask'``.  Only a batch
            size of 1 is supported.
        out_length: Maximum number of decoding iterations.
        model_cfg: Mapping providing ``'temperature'``, ``'top_k'`` and
            ``'top_p'``.  Required despite the ``None`` default.
        *kwargs: Extra positional arguments; unused.

    Returns:
        ``{'generate_context': [...]}`` -- the generated token ids, with
        consecutive repetitions of id 100 collapsed into one.

    Raises:
        ValueError: If ``model_cfg`` is not provided.
    """
    if model_cfg is None:
        # Fail fast: otherwise the lookup below raises an opaque TypeError
        # only after a full forward pass has already been executed.
        raise ValueError(
            'generate() requires `model_cfg` providing `temperature`, '
            '`top_k` and `top_p`')

    device = torch.cuda.current_device()
    batch_size = input['input_ids'].shape[0]
    tokens = input['input_ids'].view(1, -1).contiguous().to(device)
    dec_input_ids = input['dec_input_ids'].to(device)
    attention_mask = input['attention_mask'].to(device)
    self.model.eval()
    with torch.no_grad():
        # Only supports batch_size=1
        all_generate_tokens = []
        generate_tokens = []
        counter = 0
        sequence_output = None
        vocab_size = self.config.original_vocab_size
        sep_token_idx = 102  # index of [SEP] token in BertTokenizer
        while counter < out_length:
            if counter % 128 == 0 and counter != 0:
                # Sliding window: every 128 iterations fold the tokens
                # generated so far back into the encoder input and restart
                # the decoder from its initial prompt.
                generate_tokens.append(sep_token_idx)
                start = (tokens == sep_token_idx).nonzero(
                    as_tuple=True)[-1]
                if start + len(generate_tokens) >= 512:
                    # NOTE(review): `tokens` is 2-D (1, L) here, so
                    # `tokens[:start]` slices the first axis and the result
                    # is concatenated with a 1-D tensor -- looks
                    # unintended; confirm before relying on this branch.
                    tokens = torch.cat([
                        tokens[:start],
                        torch.cuda.LongTensor(generate_tokens)
                    ], -1)[-512:]
                else:
                    tokens[0][start:start
                              + len(generate_tokens)] = torch.cuda.LongTensor(
                                  generate_tokens)

                attention_mask = (tokens != 0)
                dec_input_ids = input['dec_input_ids'].to(device)
                generate_tokens = []
                # Force the encoder output to be recomputed.
                sequence_output = None

            position_ids = torch.full([batch_size, 1],
                                      len(generate_tokens),
                                      dtype=torch.long,
                                      device=device)
            _, logits, sequence_output = self.model(
                tokens,
                None,
                attention_mask,
                dec_input_ids,
                attention_mask,
                position_ids,
                is_infer=True,
                sequence_output=sequence_output,
                parallel_output=False)
            logits = logits[:, -1, :]
            logits = logits / model_cfg['temperature']
            logits = self.top_k_logits(
                logits, top_k=model_cfg['top_k'], top_p=model_cfg['top_p'])
            # Despite its name this holds probabilities, not log-probs.
            log_probs = F.softmax(logits, dim=-1)
            # Greedy pick of the most likely token (no sampling).
            prev = torch.argmax(log_probs, 1).unsqueeze(1)
            prev_token = prev[0].item()
            if prev_token >= vocab_size:
                # Ids outside the original vocabulary are replaced by id
                # 100 (presumably [UNK] -- TODO confirm).
                prev_token = 100
                prev[0] = 100
            if prev_token == 102 and len(all_generate_tokens) > int(
                    max(1, out_length) * 0.8):
                # [SEP] after more than 80% of the budget ends generation.
                break
            if prev_token == 102:
                # An early [SEP] is dropped but still consumes one step.
                counter += 1
                continue
            dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
            generate_tokens.append(prev_token)
            all_generate_tokens.append(prev_token)
            counter += 1

    # Collapse runs of consecutive id-100 tokens into a single one.
    generate_context = []
    for token in all_generate_tokens:
        if generate_context and generate_context[
                -1] == 100 and token == 100:
            continue
        else:
            generate_context.append(token)
    return {'generate_context': generate_context}
| def state_dict(self, destination=None, prefix='', keep_vars=False): | def state_dict(self, destination=None, prefix='', keep_vars=False): | ||||
| return self.model.state_dict( | return self.model.state_dict( | ||||
| destination=destination, prefix=prefix, keep_vars=keep_vars) | destination=destination, prefix=prefix, keep_vars=keep_vars) | ||||
| @@ -225,7 +225,7 @@ class PlugNLGConfig(PlugNLUConfig): | |||||
| fp32_layernorm=True, | fp32_layernorm=True, | ||||
| fp32_embedding=False, | fp32_embedding=False, | ||||
| fp32_tokentypes=False, | fp32_tokentypes=False, | ||||
| layernorm_epsilon=1e-5, | |||||
| layernorm_epsilon=1e-12, | |||||
| attn_separate=False, | attn_separate=False, | ||||
| **kwargs): | **kwargs): | ||||
| super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | ||||
| @@ -75,7 +75,7 @@ class DistributedPlug(TorchModel): | |||||
| seed = 42 if 'seed' not in kwargs else kwargs['seed'] | seed = 42 if 'seed' not in kwargs else kwargs['seed'] | ||||
| set_random_seed_mpu(seed) | set_random_seed_mpu(seed) | ||||
| self.iteration = 0 | self.iteration = 0 | ||||
| self.dist_model = self.initialize_model(path_load_tag='model') | |||||
| self.model = self.initialize_model(path_load_tag='model') | |||||
| def initialize_model(self, path_load_tag='model'): | def initialize_model(self, path_load_tag='model'): | ||||
| """Build the model.""" | """Build the model.""" | ||||
| @@ -120,115 +120,28 @@ class DistributedPlug(TorchModel): | |||||
| model.module.model.load_state_dict(load_model, strict=False) | model.module.model.load_state_dict(load_model, strict=False) | ||||
| return model | return model | ||||
| @staticmethod | |||||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||||
| # This function has been mostly taken from huggingface conversational ai code at | |||||
| # https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||||
| # conversational-ai-with-transfer-learning-2d818ac26313 | |||||
| if top_k > 0: | |||||
| # Remove all tokens with a probability less than the last token of the top-k | |||||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||||
| None] | |||||
| logits[indices_to_remove] = filter_value | |||||
| if top_p > 0.0: | |||||
| # convert to 1D | |||||
| logits = logits.view(logits.size()[1]).contiguous() | |||||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||||
| cumulative_probs = torch.cumsum( | |||||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
| # Remove tokens with cumulative probability above the threshold | |||||
| sorted_indices_to_remove = cumulative_probs > top_p | |||||
| # Shift the indices to the right to keep also the first token above the threshold | |||||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||||
| ..., :-1].clone() | |||||
| sorted_indices_to_remove[..., 0] = 0 | |||||
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||||
| logits[indices_to_remove] = filter_value | |||||
| # going back to 2D | |||||
| logits = logits.view(1, -1).contiguous() | |||||
| return logits | |||||
def forward(self,
            input_tokens,
            token_type_ids=None,
            attention_mask=None,
            target_tokens=None,
            position_ids=None,
            decode_attention_mask=None,
            checkpoint_activations=False,
            is_infer=False,
            sequence_output=None,
            parallel_output=True):
    """Delegate the call to the wrapped ``self.model``.

    The first six arguments are forwarded positionally in the order they
    are declared here; the remaining four are forwarded by keyword.
    Whatever ``self.model`` returns is returned unchanged.
    """
    positional = (input_tokens, token_type_ids, attention_mask,
                  target_tokens, position_ids, decode_attention_mask)
    options = dict(
        checkpoint_activations=checkpoint_activations,
        is_infer=is_infer,
        sequence_output=sequence_output,
        parallel_output=parallel_output)
    return self.model(*positional, **options)
| def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs): | def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs): | ||||
| device = torch.cuda.current_device() | |||||
| batch_size = input['input_ids'].shape[0] | |||||
| tokens = input['input_ids'].view(1, -1).contiguous().to(device) | |||||
| dec_input_ids = input['dec_input_ids'].to(device) | |||||
| attention_mask = input['attention_mask'].to(device) | |||||
| self.dist_model.eval() | |||||
| with torch.no_grad(): | |||||
| # Only supports batch_size=1 | |||||
| all_generate_tokens = [] | |||||
| generate_tokens = [] | |||||
| counter = 0 | |||||
| sequence_output = None | |||||
| vocab_size = self.config.original_vocab_size | |||||
| sep_token_idx = 102 # index of [SEP] token in BertTokenizer | |||||
| while counter < out_length: | |||||
| if counter % 128 == 0 and counter != 0: | |||||
| # Sliding window | |||||
| generate_tokens.append(sep_token_idx) | |||||
| start = (tokens == sep_token_idx).nonzero( | |||||
| as_tuple=True)[-1] | |||||
| if start + len(generate_tokens) >= 512: | |||||
| tokens = torch.cat([ | |||||
| tokens[:start], | |||||
| torch.cuda.LongTensor(generate_tokens) | |||||
| ], -1)[-512:] | |||||
| else: | |||||
| tokens[0][start:start + len(generate_tokens | |||||
| )] = torch.cuda.LongTensor( | |||||
| generate_tokens) | |||||
| attention_mask = (tokens != 0) | |||||
| dec_input_ids = input['dec_input_ids'].to(device) | |||||
| generate_tokens = [] | |||||
| sequence_output = None | |||||
| position_ids = torch.full([batch_size, 1], | |||||
| len(generate_tokens), | |||||
| dtype=torch.long, | |||||
| device=device) | |||||
| _, logits, sequence_output = self.dist_model( | |||||
| tokens, | |||||
| None, | |||||
| attention_mask, | |||||
| dec_input_ids, | |||||
| attention_mask, | |||||
| position_ids, | |||||
| is_infer=True, | |||||
| sequence_output=sequence_output, | |||||
| parallel_output=False) | |||||
| logits = logits[:, -1, :] | |||||
| logits = logits / self.model_cfg['temperature'] | |||||
| logits = self.top_k_logits( | |||||
| logits, | |||||
| top_k=self.model_cfg['top_k'], | |||||
| top_p=self.model_cfg['top_p']) | |||||
| log_probs = F.softmax(logits, dim=-1) | |||||
| prev = torch.multinomial(log_probs, num_samples=1) | |||||
| prev_token = prev[0].item() | |||||
| if prev_token >= vocab_size: | |||||
| prev_token = 100 | |||||
| prev[0] = 100 | |||||
| if prev_token == 102 and len(all_generate_tokens) > int( | |||||
| max(1, out_length) * 0.8): | |||||
| break | |||||
| if prev_token == 102: | |||||
| counter += 1 | |||||
| continue | |||||
| dec_input_ids = torch.cat([dec_input_ids, prev], dim=1) | |||||
| generate_tokens.append(prev_token) | |||||
| all_generate_tokens.append(prev_token) | |||||
| counter += 1 | |||||
| generate_context = [] | |||||
| for token in all_generate_tokens: | |||||
| if generate_context and generate_context[ | |||||
| -1] == 100 and token == 100: | |||||
| continue | |||||
| else: | |||||
| generate_context.append(token) | |||||
| return {'generate_context': generate_context} | |||||
| return self.model.generate(input, out_length, self.model_cfg, *kwargs) | |||||
| @@ -0,0 +1,225 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
class TextGenerator(object):
    """Beam-search decoder around an encoder/decoder ``model`` callable.

    ``model`` is called as ``model(tokens, types, mask, dec_ids,
    dec_position_ids, dec_attn_mask, checkpoint_activations=...,
    is_infer=..., parallel_output=..., sequence_output=...)`` and must
    return a 3-tuple whose second item holds per-position vocabulary
    scores and whose third item is the encoder output.

    Decoding hyper-parameters (beam size 5, length-penalty alpha 0.6,
    minimum length 5, maximum length 384) are fixed in the constructor and
    can be overridden on the instance afterwards.
    """

    def __init__(self,
                 model,
                 vocab,
                 symbols,
                 global_scorer=None,
                 logger=None,
                 dump_beam=''):
        self.alpha = 0.6  # length-penalty exponent

        self.logger = logger
        self.cuda = (torch.cuda.device_count() > 0)

        self.model = model
        # TODO generator
        self.vocab = vocab
        self.symbols = symbols
        # 102 is [SEP] in BertTokenizer; 101 is presumably [CLS] -- TODO
        # confirm against the vocabulary in use.
        self.start_token = 101
        self.end_token = 102

        self.global_scorer = global_scorer
        self.beam_size = 5
        self.min_length = 5
        self.max_length = 384

        self.dump_beam = dump_beam

        # for debugging
        self.beam_trace = self.dump_beam != ''
        self.beam_accum = None

        if self.beam_trace:
            self.beam_accum = {
                'predicted_ids': [],
                'beam_parent_ids': [],
                'scores': [],
                'log_probs': []
            }

    def _build_target_tokens(self, pred):
        """Convert predicted ids into a list of decoded string pieces.

        Ids are read up to (excluding) the first ``end_token``; ids not
        smaller than ``len(self.vocab)`` are dropped before decoding with
        ``self.vocab.DecodeIds``.
        """
        tokens = []
        for tok in pred:
            tok = int(tok)
            if tok == self.end_token:
                break
            tokens.append(tok)
        tokens = [t for t in tokens if t < len(self.vocab)]
        tokens = self.vocab.DecodeIds(tokens).split(' ')
        return tokens

    def tile(self, x, count, dim=0):
        """
        Tiles x on dimension dim count times.

        Every slice along ``dim`` is repeated ``count`` times back to back
        (``[a, b] -> [a, a, b, b]``), the layout beam search expects.
        """
        return x.repeat_interleave(count, dim=dim)

    def translate_batch(self, encoder_inputs, fast=False):
        """Run beam search without gradient tracking.

        Args:
            encoder_inputs: ``(tokens, types, padding_mask)`` tensors.
            fast: Unused; kept for interface compatibility.

        Returns:
            The result dict produced by ``_fast_translate_batch``.
        """
        with torch.no_grad():
            return self._fast_translate_batch(
                encoder_inputs, self.max_length, min_length=self.min_length)

    def _fast_translate_batch(self, encoder_inputs, max_length, min_length=0):
        """Beam-search decode a batch.

        Args:
            encoder_inputs: ``(tokens, types, padding_mask)`` tensors whose
                first dimension is the batch.
            max_length: Maximum number of decoding steps.
            min_length: Number of initial steps during which
                ``end_token`` may not be emitted.

        Returns:
            Dict with per-example lists under ``'predictions'`` (token id
            tensors without the start token) and ``'scores'``
            (length-normalised scores), plus ``'gold_score'`` and
            ``'batch'`` placeholders.
        """
        assert not self.dump_beam

        beam_size = self.beam_size
        tokens, types, padding_mask = encoder_inputs
        batch_size = tokens.size(0)
        device = tokens.device
        tmp_alive_seq = torch.full([batch_size, 1],
                                   self.start_token,
                                   dtype=torch.long,
                                   device=device)
        # First call only serves to obtain the encoder output once.
        prediction_scores, dec_feat_seq, sequence_output = self.model(
            tokens,
            types,
            padding_mask,
            tmp_alive_seq,
            None,
            None,
            checkpoint_activations=False,
            is_infer=True,
            parallel_output=False,
            sequence_output=None)
        src_features = sequence_output

        # Replicate encoder state once per beam.
        src_features = self.tile(src_features, beam_size, dim=0)
        attention_mask = self.tile(padding_mask, beam_size, dim=0)

        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=device)
        alive_seq = torch.full([batch_size * beam_size, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor(
                [0.0] + [float('-inf')] * (beam_size - 1),
                device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results['predictions'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['scores'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['gold_score'] = [0] * batch_size
        results['batch'] = []
        dec_attn_mask = None
        dec_position_ids = None

        for step in range(max_length):
            prediction_scores, dec_feat_seq, _ = self.model(
                tokens,
                types,
                attention_mask,
                alive_seq,
                dec_position_ids,
                dec_attn_mask,
                checkpoint_activations=False,
                is_infer=True,
                parallel_output=False,
                sequence_output=src_features)

            dec_feat_seq = dec_feat_seq[:, -1, :]
            vocab_size = dec_feat_seq.size(-1)
            # log_softmax avoids the underflow to -inf that
            # log(softmax(x)) suffers for very unlikely tokens.
            log_probs = torch.log_softmax(
                dec_feat_seq.view(-1, vocab_size), dim=-1)

            if step < min_length:
                # Forbid ending before the minimum length is reached.
                log_probs[:, self.end_token] = -1e20
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.alpha  # global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores = log_probs / length_penalty

            # Flatten beams so top-k runs over every (beam, token) pair.
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
            # Recover raw log-probs from the length-normalised scores.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size, rounding_mode='trunc')
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                # Out of budget: treat every beam as finished.
                is_finished.fill_(1)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results['scores'][b].append(score)
                        results['predictions'][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            attention_mask = attention_mask.index_select(0, select_indices)

        return results
| @@ -122,6 +122,8 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase): | |||||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | ||||
| False) | False) | ||||
| kwargs['max_length'] = sequence_length | kwargs['max_length'] = sequence_length | ||||
| self.src_length = kwargs['max_length'] | |||||
| self.tgt_length = kwargs.pop('target_max_length', kwargs['max_length']) | |||||
| model_type = None | model_type = None | ||||
| if model_dir is not None: | if model_dir is not None: | ||||
| model_type = get_model_type(model_dir) | model_type = get_model_type(model_dir) | ||||
| @@ -154,10 +156,14 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase): | |||||
| 'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None | 'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None | ||||
| output = self.nlp_tokenizer(sequence1, **kwargs) | output = self.nlp_tokenizer(sequence1, **kwargs) | ||||
| if self.mode != ModeKeys.INFERENCE: | if self.mode != ModeKeys.INFERENCE: | ||||
| if sequence2 is not None: | if sequence2 is not None: | ||||
| self.nlp_tokenizer.tokenize_kwargs[ | |||||
| 'max_length'] = self.tgt_length | |||||
| labels = self.nlp_tokenizer(sequence2)['input_ids'] | labels = self.nlp_tokenizer(sequence2)['input_ids'] | ||||
| self.nlp_tokenizer.tokenize_kwargs[ | |||||
| 'max_length'] = self.src_length | |||||
| src_input_ids = output['input_ids'] | src_input_ids = output['input_ids'] | ||||
| src_attention_mask = output['attention_mask'] | src_attention_mask = output['attention_mask'] | ||||
| else: | else: | ||||
| @@ -25,7 +25,7 @@ else: | |||||
| 'hook': ['Hook'], | 'hook': ['Hook'], | ||||
| 'iter_timer_hook': ['IterTimerHook'], | 'iter_timer_hook': ['IterTimerHook'], | ||||
| 'logger': ['TensorboardHook', 'TextLoggerHook'], | 'logger': ['TensorboardHook', 'TextLoggerHook'], | ||||
| 'lr_scheduler_hook': ['LrSchedulerHook'], | |||||
| 'lr_scheduler_hook': ['LrSchedulerHook', 'NoneLrSchedulerHook'], | |||||
| 'optimizer_hook': [ | 'optimizer_hook': [ | ||||
| 'ApexAMPOptimizerHook', 'NoneOptimizerHook', 'OptimizerHook', | 'ApexAMPOptimizerHook', 'NoneOptimizerHook', 'OptimizerHook', | ||||
| 'TorchAMPOptimizerHook' | 'TorchAMPOptimizerHook' | ||||
| @@ -104,7 +104,8 @@ class CheckpointHook(Hook): | |||||
| return | return | ||||
| if self._should_save(trainer): | if self._should_save(trainer): | ||||
| if is_master(): | |||||
| if is_master() or trainer.cfg.model.get('model_parallel_size', | |||||
| 1) != 1: | |||||
| self.logger.info( | self.logger.info( | ||||
| f'Saving checkpoint at {trainer.epoch + 1} epoch') | f'Saving checkpoint at {trainer.epoch + 1} epoch') | ||||
| self._save_checkpoint(trainer) | self._save_checkpoint(trainer) | ||||
| @@ -260,7 +261,8 @@ class CheckpointHook(Hook): | |||||
| return | return | ||||
| if self._should_save(trainer): | if self._should_save(trainer): | ||||
| if is_master(): | |||||
| if is_master() or trainer.cfg.model.get('model_parallel_size', | |||||
| 1) != 1: | |||||
| self.logger.info( | self.logger.info( | ||||
| f'Saving checkpoint at {trainer.iter + 1} iterations') | f'Saving checkpoint at {trainer.iter + 1} iterations') | ||||
| self._save_checkpoint(trainer) | self._save_checkpoint(trainer) | ||||
| @@ -0,0 +1,116 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from types import MethodType | |||||
| import deepspeed | |||||
| from megatron import mpu | |||||
| from modelscope.metainfo import Hooks | |||||
| from modelscope.trainers.hooks import (BestCkptSaverHook, CheckpointHook, | |||||
| LrSchedulerHook, NoneLrSchedulerHook, | |||||
| NoneOptimizerHook, OptimizerHook) | |||||
| from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | |||||
| from modelscope.utils.constant import LogKeys, ModelFile | |||||
| from modelscope.utils.torch_utils import is_master | |||||
| from .builder import HOOKS | |||||
| from .hook import Hook | |||||
| from .priority import Priority | |||||
@HOOKS.register_module(module_name=Hooks.DeepspeedHook)
class DeepspeedHook(Hook):
    """Hook that hands the training step and checkpointing to DeepSpeed.

    Before the run starts it wraps ``trainer.model`` in a DeepSpeed engine,
    optionally configures DeepSpeed activation checkpointing, and swaps the
    trainer's optimizer / lr-scheduler / checkpoint hooks for variants that
    cooperate with the engine.  After every training iteration it performs
    the backward pass and the parameter update through the engine.

    Args:
        deepspeed_activation_checkpointing: Whether to configure
            ``deepspeed.checkpointing`` and patch ``mpu`` to use it.
        save_zero_checkpoint: Value assigned to the engine's
            ``save_zero_checkpoint`` attribute.
        loss_key: Key of the loss inside ``trainer.train_outputs``.
    """

    PRIORITY = Priority.VERY_HIGH

    def __init__(self,
                 deepspeed_activation_checkpointing=True,
                 save_zero_checkpoint=False,
                 loss_key='loss'):
        self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing
        self.loss_key = loss_key
        self.save_zero_checkpoint = save_zero_checkpoint

    def before_run(self, trainer):
        """Initialise DeepSpeed and rewire the trainer's hooks."""
        # deepspeed init
        train_cfg = trainer.cfg.train
        # The configured path is resolved relative to the model directory.
        train_cfg.deepspeed_config = os.path.join(trainer.model_dir,
                                                  train_cfg.deepspeed_config)

        engine, _, _, _ = deepspeed.initialize(
            model=trainer.model,
            optimizer=trainer.optimizer,
            args=train_cfg,
            lr_scheduler=trainer.lr_scheduler,
            mpu=mpu,
            dist_init_required=False)
        trainer.model = engine
        trainer.model.save_zero_checkpoint = self.save_zero_checkpoint

        if self.deepspeed_activation_checkpointing:
            # Unwrap nested ``.module`` containers to reach the bare model.
            bare_model = trainer.model
            while hasattr(bare_model, 'module'):
                bare_model = bare_model.module
            deepspeed.checkpointing.configure(
                mpu,
                deepspeed_config=train_cfg.deepspeed_config,
                num_checkpoints=bare_model.config.num_hidden_layers)

            # Route mpu's checkpointing / RNG helpers through DeepSpeed.
            mpu.checkpoint = deepspeed.checkpointing.checkpoint
            mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

        def _save_regular_checkpoint(self, trainer):
            # Replacement for CheckpointHook._save_checkpoint: saves through
            # the DeepSpeed engine into a per-epoch / per-iteration folder.
            if self.by_epoch:
                cur_save_dir = os.path.join(
                    self.save_dir, f'{LogKeys.EPOCH}_{trainer.epoch + 1}')
            else:
                cur_save_dir = os.path.join(
                    self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}')
            is_final = (self.is_last_epoch(trainer)
                        and self.by_epoch) or (self.is_last_iter(trainer)
                                               and not self.by_epoch)
            if is_final:
                # The final checkpoint goes to the training output folder.
                cur_save_dir = os.path.join(self.save_dir,
                                            ModelFile.TRAIN_OUTPUT_DIR)
            trainer.model.save_checkpoint(cur_save_dir)

        def _save_best_checkpoint(self, trainer):
            # Replacement for BestCkptSaverHook._save_checkpoint: saves
            # through the engine and remembers the location.
            if self.by_epoch:
                cur_save_dir = os.path.join(
                    self.save_dir,
                    f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
                )
            else:
                cur_save_dir = os.path.join(
                    self.save_dir,
                    f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
                )
            trainer.model.save_checkpoint(cur_save_dir)
            self._best_ckpt_file = cur_save_dir

        # modify hooks
        # Each test looks at the hook originally registered at the slot,
        # so a hook matching several classes is processed by every branch.
        for slot, original in enumerate(trainer._hooks):
            # backward & step are performed by the engine instead
            if isinstance(original, OptimizerHook):
                trainer._hooks[slot] = NoneOptimizerHook()
            if isinstance(original, LrSchedulerHook):
                trainer._hooks[slot] = NoneLrSchedulerHook()

            # save checkpoint
            if isinstance(original, CheckpointHook):
                trainer._hooks[slot]._save_checkpoint = MethodType(
                    _save_regular_checkpoint, trainer._hooks[slot])

            if isinstance(original, BestCkptSaverHook):
                trainer._hooks[slot]._save_checkpoint = MethodType(
                    _save_best_checkpoint, trainer._hooks[slot])

    def after_train_iter(self, trainer):
        """Run backward and the optimizer step through the engine."""
        # The `trainer.model` here is actually a deepspeed engine object.
        engine = trainer.model
        # backward step
        engine.backward(trainer.train_outputs[self.loss_key])
        # update parameters
        engine.step()
| @@ -80,7 +80,8 @@ class TextLoggerHook(LoggerHook): | |||||
| dtype=torch.int, | dtype=torch.int, | ||||
| device=device) | device=device) | ||||
| _, world_size = get_dist_info() | _, world_size = get_dist_info() | ||||
| if world_size > 1: | |||||
| if world_size > 1 and getattr(trainer.cfg.model, 'model_parallel_size', | |||||
| 1) < world_size: | |||||
| dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) | dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) | ||||
| return mem_mb.item() | return mem_mb.item() | ||||
| @@ -0,0 +1,195 @@ | |||||
| import os | |||||
| from typing import Callable, Dict, List, Optional, Tuple, Union | |||||
| import torch | |||||
| from megatron import mpu | |||||
| from torch import nn | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.base import Model, TorchModel | |||||
| from modelscope.models.nlp.plug import DistributedPlug | |||||
| from modelscope.models.nlp.plug.backbone import BertLayerNorm | |||||
| from modelscope.models.nlp.plug.generator import TextGenerator | |||||
| from modelscope.utils.constant import ModeKeys | |||||
| from ..base import TRAINERS | |||||
| from ..nlp_trainer import NlpEpochBasedTrainer | |||||
@TRAINERS.register_module(module_name=Trainers.nlp_plug_trainer)
class PlugTrainer(NlpEpochBasedTrainer):
    """Trainer registered as ``Trainers.nlp_plug_trainer`` for the PLUG model.

    It builds a ``DistributedPlug`` model, creates a ``DeepSpeedCPUAdam``
    optimizer with an ``AnnealingLR`` scheduler, and overrides the train and
    evaluation steps for the encoder/decoder generation setup.
    """

    def build_model(self) -> Union[nn.Module, TorchModel]:
        """Instantiate ``DistributedPlug`` and return its ``model`` attribute.

        The local rank and the master address/port are read from the
        ``LOCAL_RANK``, ``MASTER_ADDR`` and ``MASTER_PORT`` environment
        variables, falling back to ``-1``, ``127.0.0.1`` and ``29500``.
        """
        rank = int(os.environ.get('LOCAL_RANK', -1))
        master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
        master_port = os.environ.get('MASTER_PORT', '29500')
        model = DistributedPlug(
            self.model_dir,
            rank,
            master_ip=master_ip,
            master_port=master_port,
            **self.cfg.model)
        return model.model

    def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
        """Wrap ``model`` in the project's own ``DistributedDataParallel``."""
        from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
        return DDP(model)

    def _get_params_for_weight_decay_optimization(self, module):
        """Split the parameters of ``module`` into decay / no-decay groups.

        Every parameter owned directly by a layer-norm sub-module, and every
        parameter named exactly ``bias``, goes to the group with
        ``weight_decay=0.0``. Other parameters go to the weight-decay group,
        except those whose name contains ``mask``, which are placed in
        neither group.

        Returns:
            A ``(weight_decay_params, no_weight_decay_params)`` tuple of
            optimizer param-group dicts.
        """
        weight_decay_params = {'params': []}
        no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
        for module_ in module.modules():
            # Only the parameters owned directly by this sub-module.
            named_params = [(name, param)
                            for name, param in module_._parameters.items()
                            if param is not None]
            if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
                no_weight_decay_params['params'].extend(
                    param for _, param in named_params)
                continue
            for name, param in named_params:
                if name == 'bias':
                    no_weight_decay_params['params'].append(param)
                elif 'mask' not in name:
                    # 'mask' also covers the 'mask_score' names.
                    weight_decay_params['params'].append(param)
        return weight_decay_params, no_weight_decay_params

    def create_optimizer_and_scheduler(self):
        """Create the optimizer and lr scheduler described by ``cfg.train``.

        When ``cfg.train.optimizer`` is present, a ``DeepSpeedCPUAdam`` is
        built over the encoder layers, the embeddings and the decoder layers.
        When ``cfg.train.lr_scheduler`` is present, an ``AnnealingLR`` is
        built on top of the optimizer. Otherwise the objects from
        ``self.optimizers`` are kept.

        Returns:
            ``(optimizer, lr_scheduler, optim_options, lr_options)``; each
            options dict is empty when its config section is absent.
        """
        optimizer, lr_scheduler = self.optimizers
        # Bound up front so the final return never hits an unbound local when
        # one of the config sections is missing.
        optim_options = {}
        lr_options = {}

        optimizer_cfg = self.cfg.train.get('optimizer', None)
        if optimizer_cfg is not None:
            optim_options = optimizer_cfg.pop('options', {})
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            model = self.model
            embeddings = model.module.module.model.bert.embeddings
            layers = model.module.module.model.bert.encoder.layer
            dec_layers = model.module.module.model.decoder.decoder
            param_groups = []
            param_groups += list(
                self._get_params_for_weight_decay_optimization(layers))
            param_groups += list(
                self._get_params_for_weight_decay_optimization(embeddings))
            param_groups += list(
                self._get_params_for_weight_decay_optimization(dec_layers))
            # Mark every parameter that carries no model-parallel flag yet.
            for param_group in param_groups:
                for param in param_group['params']:
                    if not hasattr(param, 'model_parallel'):
                        param.model_parallel = False
            optimizer = DeepSpeedCPUAdam(
                param_groups,
                lr=optimizer_cfg.lr,
                weight_decay=optimizer_cfg.weight_decay)

        lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
        if lr_scheduler_cfg is not None:
            assert optimizer is not None
            lr_options = lr_scheduler_cfg.pop('options', {})
            from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR
            num_iters = self.max_iters
            # NOTE(review): start_lr is read from optimizer_cfg, so a
            # scheduler config without an optimizer config still fails here.
            lr_scheduler = AnnealingLR(
                optimizer,
                start_lr=optimizer_cfg.lr,
                warmup_iter=lr_scheduler_cfg.warmup * num_iters,
                num_iters=num_iters,
                decay_style=lr_scheduler_cfg.decay_style,
                last_iter=-1)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        return self.optimizer, self.lr_scheduler, optim_options, lr_options

    def _get_masks_and_position_ids(self, data, eod_token):
        """Build the decoder attention mask, loss mask and position ids.

        Args:
            data: 2-D tensor of token ids (unpacked as batch, sequence).
            eod_token: Token id whose positions get a loss weight of zero.

        Returns:
            ``(attention_mask, loss_mask, position_ids)``. The attention mask
            is lower-triangular with shape ``(1, 1, seq, seq)``, the loss mask
            is a float tensor shaped like ``data``, and the position ids are
            ``0..seq-1`` expanded to the shape of ``data``.
        """
        # Unpacking also enforces that `data` is 2-D.
        _, seq_length = data.size()

        # Attention mask (lower triangular), shared across the batch.
        att_mask_batch = 1
        attention_mask = torch.tril(
            torch.ones((att_mask_batch, seq_length, seq_length),
                       device=data.device)).view(att_mask_batch, 1,
                                                 seq_length, seq_length)

        # Loss mask: every position counts except the eod tokens.
        loss_mask = torch.ones(
            data.size(), dtype=torch.float, device=data.device)
        loss_mask[data == eod_token] = 0.0

        # Position ids.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=data.device)
        position_ids = position_ids.unsqueeze(0).expand_as(data)

        return attention_mask, loss_mask, position_ids

    def train_step(self, model, inputs):
        """Run one forward pass and store the masked loss.

        ``inputs['labels']`` is shifted by one position to form the decoder
        input tokens and the prediction targets. The loss is stored in
        ``self.train_outputs['loss']`` and pushed to the log buffer; nothing
        is returned.
        """
        # NOTE(review): evaluation_step calls model.eval(), but nothing here
        # switches back to train mode - confirm it is restored elsewhere.
        self._mode = ModeKeys.TRAIN
        # format inputs
        checkpoint_activations = getattr(self.cfg.train,
                                         'checkpoint_activations', True)
        tgt_tokens = inputs['labels'][:, :-1].contiguous()
        tgt_labels = inputs['labels'][:, 1:].contiguous()
        # Token id 0 is excluded from the loss.
        tgt_attention_mask, dec_loss_mask, position_ids = \
            self._get_masks_and_position_ids(tgt_tokens, 0)
        if getattr(self.cfg.train, 'fp16', None):
            tgt_attention_mask = tgt_attention_mask.half()

        # forward step
        _, output = model(
            inputs['input_ids'],
            None,
            inputs['attention_mask'],
            tgt_tokens,
            position_ids,
            tgt_attention_mask,
            checkpoint_activations=checkpoint_activations)

        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                                  tgt_labels)
        dec_loss_mask = dec_loss_mask.view(-1)
        # Mean loss over the unmasked positions.
        loss = torch.sum(losses.view(-1) * dec_loss_mask) / dec_loss_mask.sum()

        # add model output info to log
        self.train_outputs = {'loss': loss}
        self.log_buffer.update(self.train_outputs)

    def evaluation_step(self, data):
        """Generate predictions for a batch and decode them to text.

        The decoded predictions are stored under ``result['preds']`` and the
        decoded references under ``data['tgts']`` (``data`` is mutated).

        Returns:
            The dict returned by ``TextGenerator.translate_batch`` with the
            extra ``preds`` entry.
        """
        # wrapper 1: DeepspeedEngine, wrapper 2: DDP
        model = self.model.module.module
        model.eval()

        # model: fp16 wrapper; model.module: DistributedPlug
        vocab_size = model.module.config.original_vocab_size
        batch_size = data['input_ids'].shape[0]
        beam_generator = TextGenerator(model,
                                       self.eval_preprocessor.nlp_tokenizer,
                                       None)

        with torch.no_grad():
            tokens = data['input_ids'].long()
            padding_mask = data['attention_mask'].byte()
            target_ids = data['labels'].long()
            # Drop the first target token before decoding the reference.
            target_labels = target_ids[:, 1:].contiguous()
            encoder_inputs = [tokens, None, padding_mask]
            result = beam_generator.translate_batch(encoder_inputs)
            pred_list = result['predictions']
            target_list = target_labels.cpu().numpy().tolist()
            result['preds'] = []
            data['tgts'] = []
            for i in range(batch_size):
                pred_ids = pred_list[i][0]
                # Ids beyond the original vocabulary are replaced with 100
                # (presumably the unknown-token id - TODO confirm).
                pred_ids[pred_ids > vocab_size - 1] = 100
                pred_ids = pred_ids.cpu().numpy().tolist()

                gold_string = self.eval_preprocessor.decode(
                    target_list[i], skip_special_tokens=True)
                pred_string = self.eval_preprocessor.decode(
                    pred_ids, skip_special_tokens=True)
                result['preds'].append(pred_string)
                data['tgts'].append(gold_string)
        return result
| @@ -845,7 +845,10 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| batch_size = batch_size_per_gpu | batch_size = batch_size_per_gpu | ||||
| num_workers = workers_per_gpu | num_workers = workers_per_gpu | ||||
| if dist and not isinstance(dataset, torch.utils.data.IterableDataset): | |||||
| if dist and not isinstance( | |||||
| dataset, | |||||
| torch.utils.data.IterableDataset) and self.cfg.model.get( | |||||
| 'model_parallel_size', 1) == 1: | |||||
| sampler = DistributedSampler( | sampler = DistributedSampler( | ||||
| dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | ||||
| else: | else: | ||||
| @@ -935,7 +938,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | ||||
| """ | """ | ||||
| if self._dist: | |||||
| if self._dist and self.cfg.model.get('model_parallel_size', 1) == 1: | |||||
| from modelscope.trainers.utils.inference import multi_gpu_test | from modelscope.trainers.utils.inference import multi_gpu_test | ||||
| metric_values = multi_gpu_test( | metric_values = multi_gpu_test( | ||||
| self, | self, | ||||
| @@ -0,0 +1,53 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import argparse | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import test_level | |||||
def test_trainer_with_model_and_args():
    """Fine-tune the PLUG model on the DuReader ``robust`` dataset.

    Each split is reshaped into ``src_txt`` / ``tgt_txt`` columns: the source
    is the first answer text joined to the context with ``[SEP]``, and the
    target is the question. Training output goes to a temporary work
    directory that is removed when the run ends, whether it succeeds or not.
    """

    def concat_answer_context(dataset):
        dataset['src_txt'] = dataset['answers']['text'][0] + '[SEP]' + dataset[
            'context']
        return dataset

    def to_src_tgt(split):
        # Same column pipeline for the train and the validation split.
        return split.map(concat_answer_context) \
            .rename_columns({'question': 'tgt_txt'}).remove_columns('context') \
            .remove_columns('id').remove_columns('answers')

    from datasets import load_dataset
    dataset_dict = load_dataset('luozhouyang/dureader', 'robust')

    train_dataset = to_src_tgt(dataset_dict['train'])
    eval_dataset = to_src_tgt(dataset_dict['validation'])

    # mkdtemp creates the directory and leaves its lifetime to us, instead of
    # relying on a discarded TemporaryDirectory object's finalizer.
    tmp_dir = tempfile.mkdtemp()
    try:
        model_id = 'damo/nlp_plug_text-generation_27B'

        kwargs = dict(
            model=model_id,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            work_dir=tmp_dir)

        trainer = build_trainer(
            name=Trainers.nlp_plug_trainer, default_args=kwargs)
        trainer.train()
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
if __name__ == '__main__':
    # Script entry point for running under a distributed launcher.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank')
    # Actually parse the command line; unknown launcher flags are tolerated.
    args, _ = parser.parse_known_args()
    if args.local_rank is not None:
        # PlugTrainer.build_model reads LOCAL_RANK from the environment, so
        # forward the flag unless the launcher has already exported it.
        os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    test_trainer_with_model_and_args()