| @@ -3,6 +3,8 @@ from typing import Any, Dict, Optional | |||
| from maas_lib.utils.constant import Tasks | |||
| from ...base import Model, Tensor | |||
| from ...builder import MODELS | |||
| from .model.generator import Generator | |||
| from .model.model_base import ModelBase | |||
| __all__ = ['DialogGenerationModel'] | |||
| @@ -21,7 +23,14 @@ class DialogGenerationModel(Model): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| pass | |||
| self.text_field = kwargs.pop('text_field') | |||
| self.config = kwargs.pop('config') | |||
| self.generator = Generator.create(self.config, reader=self.text_field) | |||
| self.model = ModelBase.create( | |||
| model_dir=model_dir, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| generator=self.generator) | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| @@ -0,0 +1,285 @@ | |||
| """ | |||
| IntentUnifiedTransformer | |||
| """ | |||
| import torch | |||
| from maas_lib.models.nlp.space.model.unified_transformer import \ | |||
| UnifiedTransformer | |||
| class GenUnifiedTransformer(UnifiedTransformer): | |||
| """ | |||
| Implement generation unified transformer. | |||
| """ | |||
| def __init__(self, model_dir, config, reader, generator): | |||
| super(GenUnifiedTransformer, self).__init__(model_dir, config, reader, | |||
| generator) | |||
| self.understand = config.BPETextField.understand | |||
| if self.use_gpu: | |||
| self.cuda() | |||
| return | |||
| def _forward(self, inputs, is_training, with_label): | |||
| """ Real forward process of model in different mode(train/test). """ | |||
| def cat(x, y, dim=1): | |||
| return torch.cat([x, y], dim=dim) | |||
| outputs = {} | |||
| if self.understand or self.policy: | |||
| if self.understand: | |||
| prompt_token = inputs['understand_token'] | |||
| prompt_mask = inputs['understand_mask'] | |||
| if self.policy: | |||
| prompt_token = cat(prompt_token, inputs['policy_token']) | |||
| prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
| else: | |||
| prompt_token = inputs['policy_token'] | |||
| prompt_mask = inputs['policy_mask'] | |||
| enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'][:, :-1], | |||
| tgt_mask=inputs['tgt_mask'][:, :-1], | |||
| prompt_token=prompt_token, | |||
| prompt_mask=prompt_mask, | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn'], | |||
| tgt_pos=inputs['tgt_pos'][:, :-1], | |||
| tgt_type=inputs['tgt_type'][:, :-1], | |||
| tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
| else: | |||
| enc_embed, dec_embed = self._encoder_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'][:, :-1], | |||
| tgt_mask=inputs['tgt_mask'][:, :-1], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn'], | |||
| tgt_pos=inputs['tgt_pos'][:, :-1], | |||
| tgt_type=inputs['tgt_type'][:, :-1], | |||
| tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
| outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed) | |||
| return outputs | |||
| def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
| metrics = {} | |||
| loss = 0. | |||
| label = inputs['tgt_token'][:, 1:] | |||
| token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1) | |||
| nll = self.nll_loss( | |||
| torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label) | |||
| nll = torch.sum(nll, dim=1) | |||
| token_nll = torch.sum(nll) / token_num | |||
| nll = torch.mean(nll) | |||
| metrics['nll'] = nll | |||
| metrics['token_nll'] = token_nll | |||
| metrics['token_num'] = token_num | |||
| loss = loss + (token_nll if self.token_loss else nll) | |||
| metrics['loss'] = loss | |||
| if self.gpu > 1: | |||
| return nll, token_nll, token_num | |||
| else: | |||
| return metrics | |||
| def _optimize(self, loss, do_update=False, optimizer=None): | |||
| """ Optimize loss function and update model. """ | |||
| assert optimizer is not None | |||
| if self.gradient_accumulation_steps > 1: | |||
| loss = loss / self.gradient_accumulation_steps | |||
| loss.backward() | |||
| if self.grad_clip is not None and self.grad_clip > 0: | |||
| torch.nn.utils.clip_grad_norm_( | |||
| parameters=self.parameters(), max_norm=self.grad_clip) | |||
| if do_update: | |||
| optimizer.step() | |||
| optimizer.zero_grad() | |||
| return | |||
| def _init_state(self, | |||
| src_token, | |||
| src_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None): | |||
| """ Initialize decode state. """ | |||
| state = {} | |||
| batch_size = src_token.shape[0] | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| src_embed = self.embed_layer_norm(src_embed) | |||
| mask = self._create_mask(src_mask, append_head=False) | |||
| enc_out = src_embed | |||
| cache = {} | |||
| for l, layer in enumerate(self.layers): | |||
| cache[f'layer_{l}'] = {} | |||
| enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||
| state['cache'] = cache | |||
| state['mask'] = mask[:, :1] | |||
| state['batch_size'] = batch_size | |||
| shape = [batch_size, 1, 1] | |||
| state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
| state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
| if self.use_gpu: | |||
| state['pred_mask'] = state['pred_mask'].cuda() | |||
| state['pred_pos'] = state['pred_pos'].cuda() | |||
| state['pred_type'] = state['pred_type'].cuda() | |||
| state['pred_turn'] = state['pred_turn'].cuda() | |||
| return state | |||
| def _init_prompt_state(self, | |||
| src_token, | |||
| src_mask, | |||
| prompt_token, | |||
| prompt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| prompt_pos=None, | |||
| prompt_type=None, | |||
| prompt_turn=None): | |||
| """ Initialize decode state. """ | |||
| state = {} | |||
| batch_size = src_token.shape[0] | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
| prompt_turn) | |||
| embed = torch.cat([src_embed, prompt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_out = embed | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask(prompt_mask, auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| cache = {} | |||
| for l, layer in enumerate(self.layers): | |||
| cache[f'layer_{l}'] = {} | |||
| enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||
| state['cache'] = cache | |||
| state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] | |||
| state['batch_size'] = batch_size | |||
| shape = [batch_size, 1, 1] | |||
| state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
| state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
| if self.use_gpu: | |||
| state['pred_mask'] = state['pred_mask'].cuda() | |||
| state['pred_pos'] = state['pred_pos'].cuda() | |||
| state['pred_type'] = state['pred_type'].cuda() | |||
| state['pred_turn'] = state['pred_turn'].cuda() | |||
| return state | |||
| def _decode(self, state): | |||
| """ Decoding one time stamp. """ | |||
| # shape: [batch_size, 1, seq_len] | |||
| mask = state['mask'] | |||
| # shape: [batch_size, 1, 1] | |||
| pred_token = state['pred_token'] | |||
| pred_mask = state['pred_mask'] | |||
| pred_pos = state['pred_pos'] | |||
| pred_type = state['pred_type'] | |||
| pred_turn = state['pred_turn'] | |||
| # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] | |||
| cache = state['cache'] | |||
| pred_embed = self.embedder(pred_token, pred_pos, pred_type, | |||
| pred_turn).squeeze(-2) | |||
| pred_embed = self.embed_layer_norm(pred_embed) | |||
| # shape: [batch_size, 1, seq_len + 1] | |||
| mask = torch.cat([mask, 1 - pred_mask], dim=2) | |||
| # shape: [batch_size, 1, hidden_dim] | |||
| for l, layer in enumerate(self.layers): | |||
| pred_embed = layer(pred_embed, mask, cache[f'layer_{l}']) | |||
| # shape: [batch_size, vocab_size] | |||
| pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) | |||
| pred_logits = torch.log(pred_probs) | |||
| state['mask'] = mask | |||
| return pred_logits, state | |||
| def _infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ Real inference process of model. """ | |||
| def cat(x, y, dim=1): | |||
| return torch.cat([x, y], dim=dim) | |||
| # Initial decode state. | |||
| if self.understand or self.policy: | |||
| if self.understand: | |||
| prompt_token = inputs['understand_token'] | |||
| prompt_mask = inputs['understand_mask'] | |||
| if self.policy: | |||
| prompt_token = cat(prompt_token, inputs['policy_token']) | |||
| prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
| else: | |||
| prompt_token = inputs['policy_token'] | |||
| prompt_mask = inputs['policy_mask'] | |||
| state = self._init_prompt_state( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| prompt_token=prompt_token, | |||
| prompt_mask=prompt_mask, | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| else: | |||
| state = self._init_state( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| # Generation process. | |||
| gen_results = self.generator( | |||
| step_fn=self._decode, | |||
| state=state, | |||
| start_id=start_id, | |||
| eos_id=eos_id, | |||
| max_gen_len=max_gen_len, | |||
| prev_input=prev_input) | |||
| outputs = gen_results['preds'] | |||
| return outputs | |||
| GenUnifiedTransformer.register('GenUnifiedTransformer') | |||
| @@ -0,0 +1,296 @@ | |||
| """ | |||
| Generator class. | |||
| """ | |||
| import math | |||
| import numpy as np | |||
| import torch | |||
| from .gen_unified_transformer import GenUnifiedTransformer | |||
| from .unified_transformer import UnifiedTransformer | |||
| def repeat(var, times): | |||
| if isinstance(var, list): | |||
| return [repeat(x, times) for x in var] | |||
| elif isinstance(var, dict): | |||
| return {k: repeat(v, times) for k, v in var.items()} | |||
| elif isinstance(var, torch.Tensor): | |||
| var = var.unsqueeze(1) | |||
| expand_times = [1] * len(var.shape) | |||
| expand_times[1] = times | |||
| dtype = var.dtype | |||
| var = var.float() | |||
| var = var.repeat(*expand_times) | |||
| shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) | |||
| var = var.reshape(*shape) | |||
| var = torch.tensor(var, dtype=dtype) | |||
| return var | |||
| else: | |||
| return var | |||
| def gather(var, idx): | |||
| if isinstance(var, list): | |||
| return [gather(x, idx) for x in var] | |||
| elif isinstance(var, dict): | |||
| return {k: gather(v, idx) for k, v in var.items()} | |||
| elif isinstance(var, torch.Tensor): | |||
| out = var.index_select(dim=0, index=idx) | |||
| return out | |||
| else: | |||
| return var | |||
| class Generator(object): | |||
| """ Genrator class. """ | |||
| _registry = dict() | |||
| @classmethod | |||
| def register(cls, name): | |||
| Generator._registry[name] = cls | |||
| return | |||
| @staticmethod | |||
| def by_name(name): | |||
| return Generator._registry[name] | |||
| @staticmethod | |||
| def create(config, *args, **kwargs): | |||
| """ Create generator. """ | |||
| generator_cls = Generator.by_name(config.Generator.generator) | |||
| return generator_cls(config, *args, **kwargs) | |||
| def __init__(self, config, reader): | |||
| self.vocab_size = reader.vocab_size | |||
| self.bos_id = reader.bos_id | |||
| self.eos_id = reader.eos_id | |||
| self.unk_id = reader.unk_id | |||
| self.pad_id = reader.pad_id | |||
| self.min_gen_len = config.Generator.min_gen_len | |||
| self.max_gen_len = config.Generator.max_gen_len | |||
| self.use_gpu = config.use_gpu | |||
| assert 1 <= self.min_gen_len <= self.max_gen_len | |||
| return | |||
| def __call__(self, step_fn, state): | |||
| """ | |||
| Running generation. | |||
| @param : step_fn : decoding one step | |||
| @type : function | |||
| @param : state : initial state | |||
| @type : dict | |||
| """ | |||
| raise NotImplementedError | |||
| class BeamSearch(Generator): | |||
| """ BeamSearch generator. """ | |||
| def __init__(self, config, reader): | |||
| super().__init__(config, reader) | |||
| self.beam_size = config.Generator.beam_size | |||
| self.length_average = config.Generator.length_average | |||
| self.length_penalty = config.Generator.length_penalty | |||
| self.ignore_unk = config.Generator.ignore_unk | |||
| return | |||
| def __call__(self, | |||
| step_fn, | |||
| state, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ | |||
| Running beam search. | |||
| @param : step_fn : decoding one step | |||
| @type : function | |||
| @param : state : initial state | |||
| @type : dict | |||
| """ | |||
| if prev_input is not None: | |||
| if isinstance(prev_input, list): | |||
| length = max(list(map(lambda x: len(x), prev_input))) | |||
| prev_input_numpy = np.full((len(prev_input), length), | |||
| self.pad_id) | |||
| for i, x in enumerate(prev_input): | |||
| prev_input_numpy[i, :len(x)] = x | |||
| prev_input_tensor = torch.from_numpy(prev_input_numpy) | |||
| if self.use_gpu: | |||
| prev_input_tensor = prev_input_tensor.cuda() | |||
| for i in range(length): | |||
| state['pred_token'] = prev_input_tensor[:, i].unsqueeze( | |||
| -1).unsqueeze(-1) | |||
| if i != 0: | |||
| state['pred_mask'] = torch.not_equal( | |||
| state['pred_token'], self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + state[ | |||
| 'pred_mask'].int() | |||
| _, state = step_fn(state) | |||
| else: | |||
| assert isinstance(prev_input, torch.Tensor) | |||
| for i, input in enumerate(prev_input): | |||
| state['pred_token'] = input.expand(1, 1, 1) | |||
| if i != 0: | |||
| state['pred_mask'] = torch.not_equal( | |||
| state['pred_token'], self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| _, state = step_fn(state) | |||
| batch_size = state['batch_size'] | |||
| beam_size = self.beam_size | |||
| # shape: [batch_size, 1] | |||
| pos_index = torch.arange( | |||
| 0, batch_size, 1, dtype=torch.int64) * beam_size | |||
| pos_index = pos_index.unsqueeze(1) | |||
| # shape: [batch_size, beam_size, 1] | |||
| if start_id is None: | |||
| start_id = self.bos_id | |||
| if eos_id is None: | |||
| eos_id = self.eos_id | |||
| predictions = torch.ones([batch_size, beam_size, 1], | |||
| dtype=torch.int64) * start_id | |||
| if self.use_gpu: | |||
| pos_index = pos_index.cuda() | |||
| predictions = predictions.cuda() | |||
| # initial input (start_id) | |||
| state['pred_token'] = predictions[:, :1] | |||
| if prev_input is not None: | |||
| state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
| self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| # shape: [batch_size, vocab_size] | |||
| scores, state = step_fn(state) | |||
| unk_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
| unk_penalty[self.unk_id] = -1e10 | |||
| unk_penalty = torch.from_numpy(unk_penalty) | |||
| eos_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
| eos_penalty[eos_id] = -1e10 | |||
| eos_penalty = torch.from_numpy(eos_penalty) | |||
| scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') | |||
| scores_after_end[ | |||
| self.pad_id] = 0 # 希望<eos>之后只生成<pad>,故使词表中log(p(<pad>))最高(0) | |||
| scores_after_end = torch.from_numpy(scores_after_end) | |||
| if self.use_gpu: | |||
| unk_penalty = unk_penalty.cuda() | |||
| eos_penalty = eos_penalty.cuda() | |||
| scores_after_end = scores_after_end.cuda() | |||
| if self.ignore_unk: | |||
| scores = scores + unk_penalty | |||
| scores = scores + eos_penalty | |||
| # shape: [batch_size, beam_size] | |||
| sequence_scores, preds = torch.topk(scores, self.beam_size) | |||
| predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
| state = repeat(state, beam_size) | |||
| parent_idx_list = [] | |||
| pred_list = [] | |||
| if max_gen_len is None: | |||
| max_gen_len = self.max_gen_len | |||
| for step in range(2, max_gen_len + 1): | |||
| pre_ids = predictions[:, :, -1:] | |||
| state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1) | |||
| state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
| self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| scores, state = step_fn(state) | |||
| # Generate next | |||
| # scores shape: [batch_size * beam_size, vocab_size] | |||
| if self.ignore_unk: | |||
| scores = scores + unk_penalty | |||
| if step <= self.min_gen_len: | |||
| scores = scores + eos_penalty | |||
| # scores shape: [batch_size, beam_size, vocab_size] | |||
| scores = scores.reshape(batch_size, beam_size, self.vocab_size) | |||
| # previous token is [PAD] or [EOS] | |||
| pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
| (1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
| scores = scores * (1 - pre_eos_mask) + \ | |||
| pre_eos_mask.repeat(1, 1, self.vocab_size) * scores_after_end | |||
| if self.length_average: | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 - | |||
| 1 / step) | |||
| sequence_scores = sequence_scores.unsqueeze(2) * scaled_value | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step) | |||
| scores = scores * scaled_value | |||
| elif self.length_penalty >= 0.0: | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
| (math.pow((4 + step) / (5 + step), self.length_penalty)) | |||
| sequence_scores = scaled_value * sequence_scores | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
| (math.pow(1 / (5 + step), self.length_penalty)) | |||
| scores = scores * scaled_value | |||
| scores = scores + sequence_scores.unsqueeze(-1) | |||
| scores = scores.reshape(batch_size, beam_size * self.vocab_size) | |||
| topk_scores, topk_indices = torch.topk(scores, beam_size) | |||
| # topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape) | |||
| # 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商 | |||
| parent_idx = topk_indices.floor_divide(self.vocab_size) | |||
| # 对vocab_size取余 | |||
| preds = topk_indices % self.vocab_size | |||
| # Gather state / sequence_scores | |||
| parent_idx = parent_idx + pos_index | |||
| parent_idx = parent_idx.reshape(batch_size * beam_size) | |||
| state = gather(state, parent_idx) | |||
| sequence_scores = topk_scores | |||
| predictions = predictions.reshape(batch_size * beam_size, step) | |||
| predictions = gather(predictions, parent_idx) | |||
| predictions = predictions.reshape(batch_size, beam_size, step) | |||
| predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
| # 希望生成的整个句子已完结,所以要求最后一个token为<eos>或者<pad>(跟在<eos>之后),否则惩罚 | |||
| pre_ids = predictions[:, :, -1] | |||
| pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
| (1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
| sequence_scores = sequence_scores * pre_eos_mask + ( | |||
| 1 - pre_eos_mask) * (-1e10) | |||
| # 先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴) | |||
| indices = torch.argsort(sequence_scores, dim=1) | |||
| indices = indices + pos_index | |||
| indices = indices.reshape(-1) | |||
| sequence_scores = sequence_scores.reshape(batch_size * beam_size) | |||
| predictions = predictions.reshape(batch_size * beam_size, -1) | |||
| sequence_scores = gather(sequence_scores, indices) | |||
| predictions = gather(predictions, indices) | |||
| sequence_scores = sequence_scores.reshape(batch_size, beam_size) | |||
| predictions = predictions.reshape(batch_size, beam_size, -1) | |||
| results = { | |||
| 'preds': predictions[:, -1], | |||
| 'scores': sequence_scores[:, -1] | |||
| } | |||
| return results | |||
| BeamSearch.register('BeamSearch') | |||
| @@ -0,0 +1,99 @@ | |||
| """ | |||
| Model base | |||
| """ | |||
| import os | |||
| import torch.nn as nn | |||
| class ModelBase(nn.Module): | |||
| """ | |||
| Basic model wrapper for static graph and dygrpah. | |||
| """ | |||
| _registry = dict() | |||
| @classmethod | |||
| def register(cls, name): | |||
| ModelBase._registry[name] = cls | |||
| return | |||
| @staticmethod | |||
| def by_name(name): | |||
| return ModelBase._registry[name] | |||
| @staticmethod | |||
| def create(model_dir, config, *args, **kwargs): | |||
| model_cls = ModelBase.by_name(config.Model.model) | |||
| return model_cls(model_dir, config, *args, **kwargs) | |||
| def __init__(self, model_dir, config): | |||
| super(ModelBase, self).__init__() | |||
| self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin') | |||
| self.abandon_label = config.Dataset.abandon_label | |||
| self.use_gpu = config.use_gpu | |||
| self.gpu = config.Trainer.gpu | |||
| return | |||
| def _create_parameters(self): | |||
| """ Create model's paramters. """ | |||
| raise NotImplementedError | |||
| def _forward(self, inputs, is_training, with_label): | |||
| """ NO LABEL: Real forward process of model in different mode(train/test). """ | |||
| raise NotImplementedError | |||
| def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
| """ NO LABEL: Calculate loss function by using inputs and outputs. """ | |||
| raise NotImplementedError | |||
| def _optimize(self, loss, optimizer, lr_scheduler): | |||
| """ Optimize loss function and update model. """ | |||
| raise NotImplementedError | |||
| def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): | |||
| """ Real inference process of model. """ | |||
| raise NotImplementedError | |||
| def forward(self, | |||
| inputs, | |||
| is_training=False, | |||
| with_label=False, | |||
| data_file=None): | |||
| """ | |||
| Forward process, include real forward, collect metrices and optimize(optional) | |||
| @params : inputs : input data | |||
| @type : dict of numpy.ndarray/int/float/... | |||
| """ | |||
| if is_training: | |||
| self.train() | |||
| else: | |||
| self.eval() | |||
| with_label = False if self.abandon_label else with_label | |||
| outputs = self._forward(inputs, is_training, with_label=with_label) | |||
| metrics = self._collect_metrics( | |||
| inputs, outputs, with_label=with_label, data_file=data_file) | |||
| return metrics | |||
| def infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ | |||
| Inference process. | |||
| @params : inputs : input data | |||
| @type : dict of numpy.ndarray/int/float/... | |||
| """ | |||
| self.eval() | |||
| results = self._infer( | |||
| inputs, | |||
| start_id=start_id, | |||
| eos_id=eos_id, | |||
| max_gen_len=max_gen_len, | |||
| prev_input=prev_input) | |||
| return results | |||
| @@ -0,0 +1,322 @@ | |||
| """ | |||
| UnifiedTransformer | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from maas_lib.models.nlp.space.model.model_base import ModelBase | |||
| from maas_lib.models.nlp.space.modules.embedder import Embedder | |||
| from maas_lib.models.nlp.space.modules.transformer_block import \ | |||
| TransformerBlock | |||
| class UnifiedTransformer(ModelBase): | |||
| """ | |||
| Implement unified transformer. | |||
| """ | |||
| def __init__(self, model_dir, config, reader, generator, dtype='float32'): | |||
| super(UnifiedTransformer, self).__init__(model_dir, config) | |||
| self.reader = reader | |||
| self.generator = generator | |||
| self.policy = config.BPETextField.policy | |||
| self.generation = config.BPETextField.generation | |||
| self.num_token_embeddings = config.Model.num_token_embeddings | |||
| self.num_pos_embeddings = config.Model.num_pos_embeddings | |||
| self.num_type_embeddings = config.Model.num_type_embeddings | |||
| self.num_turn_embeddings = config.Model.num_turn_embeddings | |||
| self.temperature = config.Model.temperature | |||
| self.hidden_dim = config.Model.hidden_dim | |||
| self.num_heads = config.Model.num_heads | |||
| self.num_layers = config.Model.num_layers | |||
| self.padding_idx = config.Model.padding_idx | |||
| self.dropout = config.Model.dropout | |||
| self.embed_dropout = config.Model.embed_dropout | |||
| self.attn_dropout = config.Model.attn_dropout | |||
| self.ff_dropout = config.Model.ff_dropout | |||
| self.mlm_ratio = config.Model.mlm_ratio | |||
| self.mmd_ratio = config.Model.mmd_ratio | |||
| self.pos_trainable = config.Model.pos_trainable | |||
| self.label_smooth = config.Model.label_smooth | |||
| self.initializer_range = config.Model.initializer_range | |||
| self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps | |||
| self.token_loss = config.Trainer.token_loss | |||
| self.learning_method = config.Dataset.learning_method | |||
| self.with_contrastive = config.Dataset.with_contrastive | |||
| self.with_query_bow = config.BPETextField.with_query_bow | |||
| self.with_resp_bow = config.BPETextField.with_resp_bow | |||
| self.with_pool = config.Model.with_pool | |||
| self.with_mlm = config.Dataset.with_mlm | |||
| self._dtype = dtype | |||
| self.embedder = Embedder( | |||
| self.hidden_dim, | |||
| self.num_token_embeddings, | |||
| self.num_pos_embeddings, | |||
| self.num_type_embeddings, | |||
| self.num_turn_embeddings, | |||
| padding_idx=self.padding_idx, | |||
| dropout=self.embed_dropout, | |||
| pos_trainable=self.pos_trainable) | |||
| self.embed_layer_norm = nn.LayerNorm( | |||
| normalized_shape=self.hidden_dim, | |||
| eps=1e-12, | |||
| elementwise_affine=True) | |||
| self.layers = nn.ModuleList([ | |||
| TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, | |||
| self.attn_dropout, self.ff_dropout) | |||
| for _ in range(config.Model.num_layers) | |||
| ]) | |||
| if self.with_mlm: | |||
| self.mlm_transform = nn.Sequential( | |||
| nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), | |||
| nn.LayerNorm( | |||
| normalized_shape=self.hidden_dim, | |||
| eps=1e-12, | |||
| elementwise_affine=True)) | |||
| self.mlm_bias = nn.Parameter( | |||
| torch.zeros(self.num_token_embeddings)) | |||
| self.pooler = nn.Sequential( | |||
| nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) | |||
| if self.with_query_bow or self.with_resp_bow: | |||
| self.bow_predictor = nn.Linear( | |||
| self.hidden_dim, self.num_token_embeddings, bias=False) | |||
| self.sigmoid = nn.Sigmoid() | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| self.bce_loss = nn.BCELoss(reduction='none') | |||
| self.nll_loss = nn.NLLLoss( | |||
| ignore_index=self.padding_idx, reduction='none') | |||
| self._create_parameters() | |||
| self.max_grad_norm = config.Model.max_grad_norm | |||
| if self.max_grad_norm is not None: | |||
| self.grad_clip = self.max_grad_norm | |||
| else: | |||
| self.grad_clip = None | |||
| self.weight_decay = config.Model.weight_decay | |||
| if self.use_gpu: | |||
| self.cuda() | |||
| return | |||
| def _create_parameters(self): | |||
| """ Create model's paramters. """ | |||
| sequence_mask = np.tri( | |||
| self.num_pos_embeddings, | |||
| self.num_pos_embeddings, | |||
| dtype=self._dtype) | |||
| self.sequence_mask = torch.tensor(sequence_mask) | |||
| return | |||
| def _create_mask(self, | |||
| input_mask, | |||
| append_head=False, | |||
| auto_regressive=False): | |||
| """ | |||
| Create attention mask. | |||
| 创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||
| mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) | |||
| 注: | |||
| 1. 一个句子中的非<pad>词看整个句子,该句中只有<pad>词才被mask | |||
| 2. 一个句子中的<pad>词看整个句子,该句的所有词都应该被mask | |||
| @param : input_mask | |||
| @type : Variable(shape: [batch_size, max_seq_len]) | |||
| @param : auto_regressive | |||
| @type : bool | |||
| """ | |||
| seq_len = input_mask.shape[1] | |||
| input_mask = input_mask.float() | |||
| mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) | |||
| mask2 = mask1.permute(0, 2, 1) | |||
| mask = mask1 * mask2 | |||
| if append_head: | |||
| # 拼接上句首位置([M]/z)的mask | |||
| mask = torch.cat([mask[:, :1, :], mask], dim=1) | |||
| mask = torch.cat([mask[:, :, :1], mask], dim=2) | |||
| seq_len += 1 | |||
| if auto_regressive: | |||
| # 将tgt端的<pad> mask和自回归attention mask融合 | |||
| seq_mask = self.sequence_mask[:seq_len, :seq_len] | |||
| seq_mask = seq_mask.to(mask.device) | |||
| mask = mask * seq_mask | |||
| mask = 1 - mask | |||
| return mask | |||
| def _join_mask(self, mask1, mask2): | |||
| """ | |||
| Merge source attention mask and target attention mask. | |||
| 合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb | |||
| @param : mask1 : source attention mask | |||
| @type : Variable(shape: [batch_size, max_src_len, max_src_len]) | |||
| @param : mask1 : target attention mask | |||
| @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) | |||
| """ | |||
| batch_size = mask1.shape[0] | |||
| seq_len1 = mask1.shape[1] | |||
| seq_len2 = mask2.shape[1] | |||
| seq_len = seq_len1 + seq_len2 | |||
| mask_lu = mask1 | |||
| mask_ru = torch.ones(batch_size, seq_len1, seq_len2) | |||
| if self.use_gpu: | |||
| mask_ru = mask_ru.cuda() | |||
| mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) | |||
| mask4 = mask1[:, :1].repeat(1, seq_len2, 1) | |||
| mask_lb = mask3 + mask4 - mask3 * mask4 | |||
| mask_rb = mask2 | |||
| mask_u = torch.cat([mask_lu, mask_ru], dim=2) | |||
| mask_b = torch.cat([mask_lb, mask_rb], dim=2) | |||
| mask = torch.cat([mask_u, mask_b], dim=1) | |||
| return mask | |||
| def _mlm_head(self, mlm_embed): | |||
| mlm_embed = self.mlm_transform(mlm_embed) | |||
| mlm_logits = torch.matmul( | |||
| mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias | |||
| mlm_probs = self.softmax(mlm_logits) | |||
| return mlm_probs | |||
| def _dec_head(self, dec_embed): | |||
| dec_logits = torch.matmul(dec_embed, | |||
| self.embedder.token_embedding.weight.T) | |||
| dec_probs = self.softmax(dec_logits) | |||
| return dec_probs | |||
| def _refactor_feature(self, features): | |||
| features = self.pooler(features) if self.with_pool else features | |||
| batch_size = features.size(0) // 2 | |||
| features = torch.cat([ | |||
| features[:batch_size].unsqueeze(1), | |||
| features[batch_size:].unsqueeze(1) | |||
| ], | |||
| dim=1) | |||
| features = F.normalize(features, dim=-1, p=2) | |||
| return features | |||
| def _encoder_network(self, | |||
| input_token, | |||
| input_mask, | |||
| input_pos=None, | |||
| input_type=None, | |||
| input_turn=None): | |||
| embed = self.embedder(input_token, input_pos, input_type, input_turn) | |||
| embed = self.embed_layer_norm(embed) | |||
| mask = self._create_mask(input_mask, auto_regressive=False) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| return embed | |||
| def _encoder_decoder_network(self, | |||
| src_token, | |||
| src_mask, | |||
| tgt_token, | |||
| tgt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| tgt_pos=None, | |||
| tgt_type=None, | |||
| tgt_turn=None): | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
| embed = torch.cat([src_embed, tgt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask(tgt_mask, auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| tgt_len = tgt_token.shape[1] | |||
| enc_embed = embed[:, :-tgt_len] | |||
| dec_embed = embed[:, -tgt_len:] | |||
| return enc_embed, dec_embed | |||
| def _encoder_prompt_decoder_network(self, | |||
| src_token, | |||
| src_mask, | |||
| tgt_token, | |||
| tgt_mask, | |||
| prompt_token, | |||
| prompt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| tgt_pos=None, | |||
| tgt_type=None, | |||
| tgt_turn=None, | |||
| prompt_pos=None, | |||
| prompt_type=None, | |||
| prompt_turn=None): | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
| prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
| prompt_turn) | |||
| embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask( | |||
| torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| src_len = src_token.shape[1] | |||
| tgt_len = tgt_token.shape[1] | |||
| enc_embed = embed[:, :src_len] | |||
| dec_embed = embed[:, -tgt_len:] | |||
| prompt_embed = embed[:, src_len:-tgt_len] | |||
| return enc_embed, dec_embed, prompt_embed | |||
| def _optimize(self, loss, optimizer=None, lr_scheduler=None): | |||
| """ Optimize loss function and update model. """ | |||
| assert optimizer is not None | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| if self.grad_clip is not None and self.grad_clip > 0: | |||
| torch.nn.utils.clip_grad_norm_( | |||
| parameters=self.parameters(), max_norm=self.grad_clip) | |||
| optimizer.step() | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.step() | |||
| return | |||
| def _infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ Real inference process of model. """ | |||
| results = {} | |||
| return results | |||
| UnifiedTransformer.register('UnifiedTransformer') | |||
| @@ -0,0 +1,67 @@ | |||
| """ | |||
| Embedder class. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| class Embedder(nn.Module): | |||
| """ | |||
| Composite embedding layer. | |||
| """ | |||
| def __init__(self, | |||
| hidden_dim, | |||
| num_token_embeddings, | |||
| num_pos_embeddings, | |||
| num_type_embeddings, | |||
| num_turn_embeddings, | |||
| padding_idx=None, | |||
| dropout=0.1, | |||
| pos_trainable=False): | |||
| super(Embedder, self).__init__() | |||
| self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) | |||
| self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) | |||
| self.pos_embedding.weight.requires_grad = pos_trainable | |||
| self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) | |||
| self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| # follow the default xavier_uniform initializer in paddle version | |||
| # otherwise, there are bugs for dec_probs computation in weight typing setting | |||
| # default norm initializer in nn.Embedding in pytorch, which samples larger values | |||
| nn.init.xavier_uniform_(self.token_embedding.weight) | |||
| nn.init.xavier_uniform_(self.pos_embedding.weight) | |||
| nn.init.xavier_uniform_(self.type_embedding.weight) | |||
| nn.init.xavier_uniform_(self.turn_embedding.weight) | |||
| return | |||
| def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): | |||
| embed = self.token_embedding(token_inp) | |||
| if pos_inp is not None: | |||
| embed += self.pos_embedding(pos_inp) | |||
| if type_inp is not None: | |||
| embed += self.type_embedding(type_inp) | |||
| if turn_inp is not None: | |||
| embed += self.turn_embedding(turn_inp) | |||
| embed = self.dropout_layer(embed) | |||
| return embed | |||
| def main(): | |||
| import numpy as np | |||
| model = Embedder(10, 20, 20, 20, 20) | |||
| token_inp = torch.tensor( | |||
| np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| out = model(token_inp, pos_inp, type_inp, turn_inp) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,43 @@ | |||
| """ | |||
| FeedForward class. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| class FeedForward(nn.Module): | |||
| """ | |||
| Positional feed forward layer. | |||
| """ | |||
| def __init__(self, hidden_dim, inner_dim, dropout): | |||
| super(FeedForward, self).__init__() | |||
| self.hidden_dim = hidden_dim | |||
| self.inner_dim = inner_dim | |||
| self.linear_hidden = nn.Sequential( | |||
| nn.Linear(hidden_dim, inner_dim), nn.GELU()) | |||
| self.linear_out = nn.Linear(inner_dim, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| return | |||
| def forward(self, x): | |||
| out = self.linear_hidden(x) | |||
| out = self.dropout_layer(out) | |||
| out = self.linear_out(out) | |||
| return out | |||
| def main(): | |||
| import numpy as np | |||
| model = FeedForward(10, 20, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| out = model(inp) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,64 @@ | |||
| """ | |||
| Helpful functions. | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| def unsqueeze(input, dims): | |||
| """ Implement multi-dimension unsqueeze function. """ | |||
| if isinstance(dims, (list, tuple)): | |||
| dims = [ | |||
| dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims | |||
| ] | |||
| dims = sorted(dims, reverse=True) | |||
| shape = list(input.shape) | |||
| for dim in dims: | |||
| shape.insert(dim, 1) | |||
| return torch.reshape(input, shape) | |||
| elif isinstance(dims, int): | |||
| return input.unsqueeze(dims) | |||
| else: | |||
| raise ValueError('Warning: type(dims) must in (list, tuple, int)!') | |||
| def gumbel_softmax(input, tau=1, eps=1e-10): | |||
| """ Basic implement of gumbel_softmax. """ | |||
| U = torch.tensor(np.random.rand(*input.shape)) | |||
| gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) | |||
| y = input + gumbel | |||
| return F.softmax(y / tau) | |||
| def equal(x, y, dtype=None): | |||
| """ Implement equal in dygraph mode. (paddle) """ | |||
| if dtype is None: | |||
| dtype = 'float32' | |||
| if isinstance(x, torch.Tensor): | |||
| x = x.numpy() | |||
| if isinstance(y, torch.Tensor): | |||
| y = y.numpy() | |||
| out = np.equal(x, y).astype(dtype) | |||
| return torch.tensor(out) | |||
| def not_equal(x, y, dtype=None): | |||
| """ Implement not_equal in dygraph mode. (paddle) """ | |||
| return 1 - equal(x, y, dtype) | |||
| if __name__ == '__main__': | |||
| a = torch.tensor([[1, 1], [3, 4]]) | |||
| b = torch.tensor([[1, 1], [3, 4]]) | |||
| c = torch.equal(a, a) | |||
| c1 = equal(a, 3) | |||
| d = 1 - torch.not_equal(a, 3).float() | |||
| print(c) | |||
| print(c1) | |||
| print(d) | |||
| e = F.gumbel_softmax(a) | |||
| f = a.unsqueeze(a) | |||
| g = unsqueeze(a, dims=[0, 0, 1]) | |||
| print(g, g.shape) | |||
| @@ -0,0 +1,109 @@ | |||
| """ | |||
| MultiheadAttention class. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| class MultiheadAttention(nn.Module): | |||
| """ | |||
| Multi head attention layer. | |||
| """ | |||
| def __init__(self, hidden_dim, num_heads, dropout): | |||
| assert hidden_dim % num_heads == 0 | |||
| super(MultiheadAttention, self).__init__() | |||
| self.hidden_dim = hidden_dim | |||
| self.num_heads = num_heads | |||
| self.head_dim = hidden_dim // num_heads | |||
| self.scale = self.head_dim**-0.5 | |||
| self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) | |||
| self.linear_out = nn.Linear(hidden_dim, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| return | |||
| def _split_heads(self, x, is_key=False): | |||
| x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) | |||
| x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) | |||
| return x | |||
| def _merge_heads(self, x): | |||
| x = x.permute(0, 2, 1, 3) | |||
| x = x.reshape(x.size(0), x.size(1), self.hidden_dim) | |||
| return x | |||
| def _attn(self, query, key, value, mask): | |||
| # shape: [batch_size, num_head, seq_len, seq_len] | |||
| scores = torch.matmul(query, key) | |||
| scores = scores * self.scale | |||
| if mask is not None: | |||
| mask = mask.unsqueeze(1) | |||
| mask = mask.repeat(1, self.num_heads, 1, 1) | |||
| scores.masked_fill_( | |||
| mask.bool(), | |||
| float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) | |||
| attn = self.softmax(scores) | |||
| attn = self.dropout_layer(attn) | |||
| if mask is not None: | |||
| ''' | |||
| mask: [batch size, num_heads, seq_len, seq_len] | |||
| mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中<pad>位看的行 | |||
| 导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 | |||
| >>> F.softmax([-1e10, -100, -100]) | |||
| >>> [0.00, 0.50, 0.50] | |||
| >>> F.softmax([-1e10, -1e10, -1e10]) | |||
| >>> [0.33, 0.33, 0.33] | |||
| ==> [0.00, 0.00, 0.00] | |||
| ''' | |||
| attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn | |||
| out = torch.matmul(attn, value) | |||
| return out | |||
| def forward(self, inp, mask=None, cache=None): | |||
| """ Forward process of self attention. """ | |||
| # shape: [batch_size, seq_len, 3 * hidden_dim] | |||
| qkv = self.linear_qkv(inp) | |||
| query, key, value = torch.split(qkv, self.hidden_dim, dim=2) | |||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||
| query = self._split_heads(query) | |||
| # shape: [batch_size, num_head, head_dim, seq_len] | |||
| key = self._split_heads(key, is_key=True) | |||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||
| value = self._split_heads(value) | |||
| if cache is not None: | |||
| if 'key' in cache and 'value' in cache: | |||
| key = torch.cat([cache['key'], key], dim=3) | |||
| value = torch.cat([cache['value'], value], dim=2) | |||
| cache['key'] = key | |||
| cache['value'] = value | |||
| out = self._attn(query, key, value, mask) | |||
| out = self._merge_heads(out) | |||
| out = self.linear_out(out) | |||
| return out | |||
| def main(): | |||
| import numpy as np | |||
| model = MultiheadAttention(10, 2, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
| mask = torch.tensor(mask) | |||
| out = model(inp, mask=mask, cache=None) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,73 @@ | |||
| """ | |||
| TransformerBlock class. | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| from maas_lib.models.nlp.space.modules.feedforward import FeedForward | |||
| from maas_lib.models.nlp.space.modules.multihead_attention import \ | |||
| MultiheadAttention | |||
| class TransformerBlock(nn.Module): | |||
| """ | |||
| Transformer block module. | |||
| """ | |||
| def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, | |||
| ff_dropout): | |||
| super(TransformerBlock, self).__init__() | |||
| self.attn = MultiheadAttention( | |||
| hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) | |||
| self.attn_norm = nn.LayerNorm( | |||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
| self.ff = FeedForward( | |||
| hidden_dim=hidden_dim, | |||
| inner_dim=4 * hidden_dim, | |||
| dropout=ff_dropout) | |||
| self.ff_norm = nn.LayerNorm( | |||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| return | |||
| def forward(self, inp, mask=None, cache=None): | |||
| """ | |||
| Forward process on one transformer layer. | |||
| @param : x | |||
| @type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
| @param : memory | |||
| @type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
| @param : mask | |||
| @param : cache | |||
| """ | |||
| attn_out = self.attn(inp, mask, cache) | |||
| attn_out = self.dropout_layer(attn_out) | |||
| attn_out = self.attn_norm(attn_out + inp) | |||
| ff_out = self.ff(attn_out) | |||
| ff_out = self.dropout_layer(ff_out) | |||
| ff_out = self.ff_norm(ff_out + attn_out) | |||
| return ff_out | |||
| def main(): | |||
| import numpy as np | |||
| model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
| mask = torch.tensor(mask) | |||
| out = model(inp, mask=mask, cache=None) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -4,5 +4,5 @@ from .base import Preprocessor | |||
| from .builder import PREPROCESSORS, build_preprocessor | |||
| from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .nlp import * # noqa F403 | |||
| from .space.dialog_generation_preprcessor import * # noqa F403 | |||
| from .nlp.nlp import * # noqa F403 | |||
| from .nlp.space.dialog_generation_preprcessor import * # noqa F403 | |||
| @@ -7,8 +7,8 @@ from transformers import AutoTokenizer | |||
| from maas_lib.utils.constant import Fields, InputFields | |||
| from maas_lib.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| __all__ = [ | |||
| 'Tokenize', | |||
| @@ -5,10 +5,11 @@ import uuid | |||
| from typing import Any, Dict, Union | |||
| from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import Fields, InputFields | |||
| from maas_lib.utils.type_assert import type_assert | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| from ...base import Preprocessor | |||
| from ...builder import PREPROCESSORS | |||
| __all__ = ['DialogGenerationPreprocessor'] | |||
| @@ -25,10 +26,10 @@ class DialogGenerationPreprocessor(Preprocessor): | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.text_field = MultiWOZBPETextField(model_dir=self.model_dir) | |||
| pass | |||
| self.config = Config.from_file( | |||
| os.path.join(self.model_dir, 'configuration.json')) | |||
| self.text_field = MultiWOZBPETextField( | |||
| self.model_dir, config=self.config) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| @@ -4,37 +4,11 @@ import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| from maas_lib.fileio import File | |||
| from tests.case.nlp.dialog_generation_case import test_case | |||
| from maas_lib.models.nlp import DialogGenerationModel | |||
| from maas_lib.pipelines import DialogGenerationPipeline, pipeline | |||
| from maas_lib.preprocessors import DialogGenerationPreprocessor | |||
| from maas_lib.utils.constant import Tasks | |||
| dialog_case = [{ | |||
| 'user': | |||
| 'am looking for a place to to stay that has cheap price range it should be in a type of hotel', | |||
| 'sys': | |||
| 'okay , do you have a specific area you want to stay in ?' | |||
| }, { | |||
| 'user': | |||
| 'no , i just need to make sure it is cheap . oh , and i need parking', | |||
| 'sys': | |||
| 'i found 1 cheap hotel for you that include -s parking . do you like me to book it ?' | |||
| }, { | |||
| 'user': | |||
| 'yes , please . 6 people 3 nights starting on tuesday .', | |||
| 'sys': | |||
| "i am sorry but i was n't able to book that for you for tuesday . is there another day you would like " | |||
| 'to stay or perhaps a shorter stay ? ' | |||
| }, { | |||
| 'user': | |||
| 'how about only 2 nights .', | |||
| 'sys': | |||
| 'booking was successful . reference number is : 7gawk763 . anything else i can do for you ?', | |||
| }, { | |||
| 'user': 'no , that will be all . goodbye .', | |||
| 'sys': 'thank you for using our services .' | |||
| }] | |||
| def merge(info, result): | |||
| @@ -47,21 +21,23 @@ class DialogGenerationTest(unittest.TestCase): | |||
| modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||
| preprocessor = DialogGenerationPreprocessor() | |||
| preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) | |||
| model = DialogGenerationModel( | |||
| model_dir=modeldir, preprocessor.tokenizer) | |||
| pipeline = DialogGenerationPipeline(model, preprocessor) | |||
| model_dir=modeldir, | |||
| text_field=preprocessor.text_field, | |||
| config=preprocessor.config) | |||
| # pipeline = DialogGenerationPipeline(model, preprocessor) | |||
| history_dialog = {} | |||
| for step in range(0, len(dialog_case)): | |||
| user_question = dialog_case[step]['user'] | |||
| for step, item in enumerate(test_case['sng0073']['log']): | |||
| user_question = item['user'] | |||
| print('user: {}'.format(user_question)) | |||
| history_dialog_info = merge(history_dialog_info, | |||
| result) if step > 0 else {} | |||
| result = pipeline(user_question, history=history_dialog_info) | |||
| print('sys : {}'.format(result['pred_answer'])) | |||
| # history_dialog_info = merge(history_dialog_info, | |||
| # result) if step > 0 else {} | |||
| # result = pipeline(user_question, history=history_dialog_info) | |||
| # | |||
| # print('sys : {}'.format(result['pred_answer'])) | |||
| if __name__ == '__main__': | |||