| @@ -3,6 +3,8 @@ from typing import Any, Dict, Optional | |||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| from .model.generator import Generator | |||||
| from .model.model_base import ModelBase | |||||
| __all__ = ['DialogGenerationModel'] | __all__ = ['DialogGenerationModel'] | ||||
| @@ -21,7 +23,14 @@ class DialogGenerationModel(Model): | |||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| pass | |||||
| self.text_field = kwargs.pop('text_field') | |||||
| self.config = kwargs.pop('config') | |||||
| self.generator = Generator.create(self.config, reader=self.text_field) | |||||
| self.model = ModelBase.create( | |||||
| model_dir=model_dir, | |||||
| config=self.config, | |||||
| reader=self.text_field, | |||||
| generator=self.generator) | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| """return the result by the model | """return the result by the model | ||||
| @@ -0,0 +1,285 @@ | |||||
| """ | |||||
| GenUnifiedTransformer | |||||
| """ | |||||
| import torch | |||||
| from maas_lib.models.nlp.space.model.unified_transformer import \ | |||||
| UnifiedTransformer | |||||
class GenUnifiedTransformer(UnifiedTransformer):
    """
    Implement generation unified transformer.

    Adds a teacher-forced training path (``_forward``, ``_collect_metrics``,
    ``_optimize``) and an incremental decoding path (``_init_state``,
    ``_init_prompt_state``, ``_decode``, ``_infer``) on top of
    ``UnifiedTransformer``.
    """

    def __init__(self, model_dir, config, reader, generator):
        super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
                                                    generator)
        # Whether an "understand" prompt segment is fed before the response.
        self.understand = config.BPETextField.understand

        if self.use_gpu:
            self.cuda()
        return

    def _build_prompt(self, inputs):
        """ Assemble the prompt tokens/mask: understand segment, then policy.

        Must only be called when ``self.understand or self.policy`` holds.

        @param : inputs : batch dict holding the ``understand_*`` and/or
                          ``policy_*`` entries
        @type : dict
        @return : (prompt_token, prompt_mask)
        """
        if self.understand:
            prompt_token = inputs['understand_token']
            prompt_mask = inputs['understand_mask']
            if self.policy:
                prompt_token = torch.cat(
                    [prompt_token, inputs['policy_token']], dim=1)
                prompt_mask = torch.cat(
                    [prompt_mask, inputs['policy_mask']], dim=1)
        else:
            prompt_token = inputs['policy_token']
            prompt_mask = inputs['policy_mask']
        return prompt_token, prompt_mask

    def _encode_with_cache(self, embed, mask):
        """ Run every layer over ``embed`` and return the per-layer caches. """
        cache = {}
        hidden = embed
        for idx, layer in enumerate(self.layers):
            cache[f'layer_{idx}'] = {}
            hidden = layer(hidden, mask, cache[f'layer_{idx}'])
        return cache

    @staticmethod
    def _new_decode_state(cache, mask, batch_size, device):
        """ Build the state dict consumed by ``_decode`` and the generator.

        The ``pred_*`` tensors are allocated directly on ``device`` (the
        device of the source tokens) so they always live next to the model
        inputs instead of relying on a global "use GPU" flag.
        """
        shape = [batch_size, 1, 1]
        return {
            'cache': cache,
            'mask': mask,
            'batch_size': batch_size,
            'pred_mask': torch.ones(shape, dtype=torch.float32, device=device),
            'pred_pos': torch.zeros(shape, dtype=torch.int64, device=device),
            'pred_type': torch.zeros(shape, dtype=torch.int64, device=device),
            'pred_turn': torch.zeros(shape, dtype=torch.int64, device=device),
        }

    def _forward(self, inputs, is_training, with_label):
        """ Real forward process of model in different mode(train/test).

        The target side is fed without its last position (``[:, :-1]``); the
        matching labels in ``_collect_metrics`` drop the first position.

        @param : inputs : batch dict with ``src_*`` / ``tgt_*`` entries and,
                          when prompts are enabled, ``understand_*`` /
                          ``policy_*`` entries
        @type : dict
        @param : is_training : not read by this implementation
        @param : with_label : not read by this implementation
        @return : dict with key ``dec_probs`` (output of ``_dec_head``)
        """
        network_inputs = dict(
            src_token=inputs['src_token'],
            src_mask=inputs['src_mask'],
            tgt_token=inputs['tgt_token'][:, :-1],
            tgt_mask=inputs['tgt_mask'][:, :-1],
            src_pos=inputs['src_pos'],
            src_type=inputs['src_type'],
            src_turn=inputs['src_turn'],
            tgt_pos=inputs['tgt_pos'][:, :-1],
            tgt_type=inputs['tgt_type'][:, :-1],
            tgt_turn=inputs['tgt_turn'][:, :-1])

        if self.understand or self.policy:
            prompt_token, prompt_mask = self._build_prompt(inputs)
            # Only the decoder embedding is needed for the generation head.
            _, dec_embed, _ = self._encoder_prompt_decoder_network(
                prompt_token=prompt_token,
                prompt_mask=prompt_mask,
                **network_inputs)
        else:
            _, dec_embed = self._encoder_decoder_network(**network_inputs)

        outputs = {}
        outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed)
        return outputs

    def _collect_metrics(self, inputs, outputs, with_label, data_file):
        """ Compute the generation NLL metrics and the training loss.

        @param : inputs : batch dict; reads ``tgt_token`` and ``tgt_mask``
        @param : outputs : result of ``_forward``; reads ``dec_probs``
        @param : with_label : not read by this implementation
        @param : data_file : not read by this implementation
        @return : ``(nll, token_nll, token_num)`` when ``self.gpu > 1``
                  (presumably so a multi-GPU wrapper can gather them -- TODO
                  confirm), otherwise a dict with ``nll``, ``token_nll``,
                  ``token_num`` and ``loss``
        """
        metrics = {}
        loss = 0.

        label = inputs['tgt_token'][:, 1:]
        # One position per sequence is dropped by the input/label shift.
        token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1)
        # NLLLoss expects the class axis second, hence the permute; the
        # epsilon keeps log() finite for zero probabilities.
        nll = self.nll_loss(
            torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label)
        nll = torch.sum(nll, dim=1)
        token_nll = torch.sum(nll) / token_num
        nll = torch.mean(nll)
        metrics['nll'] = nll
        metrics['token_nll'] = token_nll
        metrics['token_num'] = token_num
        loss = loss + (token_nll if self.token_loss else nll)
        metrics['loss'] = loss

        if self.gpu > 1:
            return nll, token_nll, token_num
        else:
            return metrics

    def _optimize(self, loss, do_update=False, optimizer=None):
        """ Optimize loss function and update model.

        @param : loss : scalar loss to back-propagate
        @param : do_update : when True, step the optimizer and clear grads;
                             otherwise gradients are only accumulated
        @param : optimizer : optimizer to step (required)
        """
        assert optimizer is not None

        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps

        loss.backward()

        if self.grad_clip is not None and self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                parameters=self.parameters(), max_norm=self.grad_clip)

        if do_update:
            optimizer.step()
            optimizer.zero_grad()

        return

    def _init_state(self,
                    src_token,
                    src_mask,
                    src_pos=None,
                    src_type=None,
                    src_turn=None):
        """ Initialize decode state from the source context only.

        @return : state dict with ``cache``, ``mask``, ``batch_size`` and the
                  ``pred_mask`` / ``pred_pos`` / ``pred_type`` / ``pred_turn``
                  tensors
        """
        batch_size = src_token.shape[0]

        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
        src_embed = self.embed_layer_norm(src_embed)

        mask = self._create_mask(src_mask, append_head=False)
        cache = self._encode_with_cache(src_embed, mask)

        # Keep a single row of the source mask as the starting decode mask.
        return self._new_decode_state(cache, mask[:, :1], batch_size,
                                      src_token.device)

    def _init_prompt_state(self,
                           src_token,
                           src_mask,
                           prompt_token,
                           prompt_mask,
                           src_pos=None,
                           src_type=None,
                           src_turn=None,
                           prompt_pos=None,
                           prompt_type=None,
                           prompt_turn=None):
        """ Initialize decode state from the source context plus a prompt.

        The source is masked bidirectionally and the prompt auto-regressively;
        both masks are merged with ``_join_mask``.

        @return : state dict with the same keys as ``_init_state``
        """
        batch_size = src_token.shape[0]

        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
        prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
                                     prompt_turn)
        embed = torch.cat([src_embed, prompt_embed], dim=1)
        embed = self.embed_layer_norm(embed)

        enc_mask = self._create_mask(src_mask, auto_regressive=False)
        dec_mask = self._create_mask(prompt_mask, auto_regressive=True)
        mask = self._join_mask(enc_mask, dec_mask)

        cache = self._encode_with_cache(embed, mask)

        # Decoding continues after the prompt, so the starting decode mask is
        # the last row of the joint mask (not the first row as in
        # ``_init_state``).
        return self._new_decode_state(cache, mask[:, -1:], batch_size,
                                      src_token.device)

    def _decode(self, state):
        """ Decoding one time stamp.

        @param : state : decode state; ``pred_token`` must have been set by
                         the generator
        @type : dict
        @return : ``(pred_logits, state)`` where ``pred_logits`` holds the log
                  of the predicted token probabilities and ``state['mask']``
                  has grown by one position
        """

        # shape: [batch_size, 1, seq_len]
        mask = state['mask']

        # shape: [batch_size, 1, 1]
        pred_token = state['pred_token']
        pred_mask = state['pred_mask']
        pred_pos = state['pred_pos']
        pred_type = state['pred_type']
        pred_turn = state['pred_turn']

        # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim]
        cache = state['cache']

        pred_embed = self.embedder(pred_token, pred_pos, pred_type,
                                   pred_turn).squeeze(-2)
        pred_embed = self.embed_layer_norm(pred_embed)

        # shape: [batch_size, 1, seq_len + 1]
        mask = torch.cat([mask, 1 - pred_mask], dim=2)

        # shape: [batch_size, 1, hidden_dim]
        for idx, layer in enumerate(self.layers):
            pred_embed = layer(pred_embed, mask, cache[f'layer_{idx}'])

        # shape: [batch_size, vocab_size]
        pred_probs = self._dec_head(dec_embed=pred_embed[:, 0])
        # A probability that underflowed to 0 would give log() == -inf, which
        # becomes NaN as soon as the generator multiplies it by a 0 mask.
        # Clamping to the smallest normal float keeps every score finite and
        # leaves all probabilities at or above that value untouched.
        smallest = torch.finfo(pred_probs.dtype).tiny
        pred_logits = torch.log(pred_probs.clamp_min(smallest))

        state['mask'] = mask
        return pred_logits, state

    def _infer(self,
               inputs,
               start_id=None,
               eos_id=None,
               max_gen_len=None,
               prev_input=None):
        """ Real inference process of model.

        @param : inputs : batch dict with ``src_*`` entries and, when prompts
                          are enabled, ``understand_*`` / ``policy_*`` entries
        @param : start_id, eos_id, max_gen_len, prev_input : forwarded
                 unchanged to ``self.generator``
        @return : the ``preds`` entry of the generator result
        """
        # Initial decode state.
        if self.understand or self.policy:
            prompt_token, prompt_mask = self._build_prompt(inputs)
            state = self._init_prompt_state(
                src_token=inputs['src_token'],
                src_mask=inputs['src_mask'],
                prompt_token=prompt_token,
                prompt_mask=prompt_mask,
                src_pos=inputs['src_pos'],
                src_type=inputs['src_type'],
                src_turn=inputs['src_turn'])
        else:
            state = self._init_state(
                src_token=inputs['src_token'],
                src_mask=inputs['src_mask'],
                src_pos=inputs['src_pos'],
                src_type=inputs['src_type'],
                src_turn=inputs['src_turn'])

        # Generation process.
        gen_results = self.generator(
            step_fn=self._decode,
            state=state,
            start_id=start_id,
            eos_id=eos_id,
            max_gen_len=max_gen_len,
            prev_input=prev_input)

        outputs = gen_results['preds']
        return outputs


GenUnifiedTransformer.register('GenUnifiedTransformer')
| @@ -0,0 +1,296 @@ | |||||
| """ | |||||
| Generator class. | |||||
| """ | |||||
| import math | |||||
| import numpy as np | |||||
| import torch | |||||
| from .gen_unified_transformer import GenUnifiedTransformer | |||||
| from .unified_transformer import UnifiedTransformer | |||||
def repeat(var, times):
    """ Repeat every row of every tensor found in ``var`` ``times`` times.

    Lists and dicts are traversed recursively; tensors are expanded along
    dim 0 so that each row appears ``times`` times consecutively
    (``[a, b] -> [a, a, b, b]`` for ``times == 2``); anything else is
    returned untouched.

    The result keeps the input dtype and device and is detached from the
    autograd graph. No intermediate float cast is performed, so large
    integer ids and float64 values survive exactly.

    @param : var : tensor, or (nested) list/dict containing tensors
    @param : times : number of copies of each row
    @type : int
    @return : structure of the same layout with the tensors repeated
    """
    if isinstance(var, list):
        return [repeat(x, times) for x in var]
    if isinstance(var, dict):
        return {k: repeat(v, times) for k, v in var.items()}
    if isinstance(var, torch.Tensor):
        return var.detach().repeat_interleave(times, dim=0)
    return var
def gather(var, idx):
    """ Select rows ``idx`` (along dim 0) from every tensor inside ``var``.

    Dicts and lists are walked recursively and rebuilt with the same layout;
    values that are neither tensors nor containers are passed through
    unchanged.

    @param : var : tensor, or (nested) list/dict containing tensors
    @param : idx : 1-D integer tensor of row indices
    @return : structure of the same layout with the tensors re-indexed
    """
    if isinstance(var, torch.Tensor):
        return torch.index_select(var, 0, idx)
    if isinstance(var, dict):
        return {key: gather(value, idx) for key, value in var.items()}
    if isinstance(var, list):
        return [gather(item, idx) for item in var]
    return var
class Generator(object):
    """ Generator class.

    Base class and name registry for decoding strategies. Subclasses
    register themselves with ``SubClass.register(name)`` and are
    instantiated through ``Generator.create(config, ...)``.
    """

    # Maps a registered name to its generator class (shared by subclasses).
    _registry = dict()

    @classmethod
    def register(cls, name):
        """ Register ``cls`` under ``name`` in the shared registry. """
        Generator._registry[name] = cls
        return

    @staticmethod
    def by_name(name):
        """ Return the generator class registered under ``name``.

        @raise : KeyError : if nothing is registered under ``name``; the
                            message lists the registered names
        """
        try:
            return Generator._registry[name]
        except KeyError:
            known = ', '.join(sorted(map(str, Generator._registry)))
            raise KeyError(
                f'Unknown generator {name!r}; registered generators: '
                f'[{known}]') from None

    @staticmethod
    def create(config, *args, **kwargs):
        """ Create generator.

        The class is looked up by ``config.Generator.generator`` and called
        with ``config`` followed by the remaining arguments.
        """
        generator_cls = Generator.by_name(config.Generator.generator)
        return generator_cls(config, *args, **kwargs)

    def __init__(self, config, reader):
        """
        @param : config : needs ``Generator.min_gen_len``,
                          ``Generator.max_gen_len`` and ``use_gpu``
        @param : reader : needs ``vocab_size`` and the ``bos_id`` /
                          ``eos_id`` / ``unk_id`` / ``pad_id`` token ids
        """
        self.vocab_size = reader.vocab_size
        self.bos_id = reader.bos_id
        self.eos_id = reader.eos_id
        self.unk_id = reader.unk_id
        self.pad_id = reader.pad_id
        self.min_gen_len = config.Generator.min_gen_len
        self.max_gen_len = config.Generator.max_gen_len
        self.use_gpu = config.use_gpu
        assert 1 <= self.min_gen_len <= self.max_gen_len
        return

    def __call__(self,
                 step_fn,
                 state,
                 start_id=None,
                 eos_id=None,
                 max_gen_len=None,
                 prev_input=None):
        """
        Running generation.

        The optional arguments mirror the keywords the model passes when it
        invokes its generator, so every subclass shares one call contract.

        @param : step_fn : decoding one step
        @type : function

        @param : state : initial state
        @type : dict

        @param : start_id : first token id (optional)
        @param : eos_id : end-of-sequence token id (optional)
        @param : max_gen_len : generation length limit (optional)
        @param : prev_input : already known prefix tokens (optional)

        @raise : NotImplementedError : always; subclasses must override
        """
        raise NotImplementedError
class BeamSearch(Generator):
    """ BeamSearch generator. """

    def __init__(self, config, reader):
        """
        @param : config : additionally needs ``Generator.beam_size``,
                          ``Generator.length_average``,
                          ``Generator.length_penalty`` and
                          ``Generator.ignore_unk``
        @param : reader : see ``Generator.__init__``
        """
        super().__init__(config, reader)
        self.beam_size = config.Generator.beam_size
        self.length_average = config.Generator.length_average
        self.length_penalty = config.Generator.length_penalty
        self.ignore_unk = config.Generator.ignore_unk
        return

    def _finished_mask(self, token_ids, eos_id):
        """ Boolean mask of positions whose token is ``eos_id`` or <pad>. """
        return torch.eq(token_ids, eos_id) | torch.eq(token_ids, self.pad_id)

    def _consume_prefix(self, step_fn, state, prev_input):
        """ Feed an already known prefix through ``step_fn`` token by token.

        ``prev_input`` is either a list of token-id sequences (one per batch
        element, right-padded here with <pad>) or a tensor that is iterated
        along its first dimension. The scores of these steps are discarded;
        only the updated state is returned.
        """
        if isinstance(prev_input, list):
            length = max(len(x) for x in prev_input)
            prefix = np.full((len(prev_input), length),
                             self.pad_id,
                             dtype=np.int64)
            for i, x in enumerate(prev_input):
                prefix[i, :len(x)] = x
            prefix = torch.from_numpy(prefix)
            if self.use_gpu:
                prefix = prefix.cuda()

            for i in range(length):
                state['pred_token'] = prefix[:, i].unsqueeze(-1).unsqueeze(-1)
                if i != 0:
                    state['pred_mask'] = torch.not_equal(
                        state['pred_token'], self.pad_id).float()
                    # Padded positions do not advance the position id.
                    state['pred_pos'] = state['pred_pos'] + state[
                        'pred_mask'].int()
                _, state = step_fn(state)
        else:
            assert isinstance(prev_input, torch.Tensor)
            for i, token in enumerate(prev_input):
                state['pred_token'] = token.expand(1, 1, 1)
                if i != 0:
                    state['pred_mask'] = torch.not_equal(
                        state['pred_token'], self.pad_id).float()
                    state['pred_pos'] = state['pred_pos'] + 1
                _, state = step_fn(state)
        return state

    def __call__(self,
                 step_fn,
                 state,
                 start_id=None,
                 eos_id=None,
                 max_gen_len=None,
                 prev_input=None):
        """
        Running beam search.

        @param : step_fn : decoding one step; called as ``step_fn(state)`` and
                           expected to return ``(scores, state)`` with
                           per-token log scores
        @type : function

        @param : state : initial state; must provide ``batch_size`` and the
                         ``pred_mask`` / ``pred_pos`` entries
        @type : dict

        @param : start_id : first token id, defaults to ``self.bos_id``
        @param : eos_id : end token id, defaults to ``self.eos_id``
        @param : max_gen_len : length limit, defaults to ``self.max_gen_len``
        @param : prev_input : optional known prefix (list of id sequences or
                              tensor) consumed before the search starts

        @return : dict with ``preds`` (token ids of the best beam per batch
                  element, including the start token) and ``scores`` (its
                  sequence score)
        """
        if prev_input is not None:
            state = self._consume_prefix(step_fn, state, prev_input)

        batch_size = state['batch_size']
        beam_size = self.beam_size

        # Offset of each batch element inside the flattened
        # [batch_size * beam_size] axis. shape: [batch_size, 1]
        pos_index = torch.arange(
            0, batch_size, 1, dtype=torch.int64) * beam_size
        pos_index = pos_index.unsqueeze(1)

        if start_id is None:
            start_id = self.bos_id
        if eos_id is None:
            eos_id = self.eos_id

        # shape: [batch_size, beam_size, 1]
        predictions = torch.ones([batch_size, beam_size, 1],
                                 dtype=torch.int64) * start_id

        if self.use_gpu:
            pos_index = pos_index.cuda()
            predictions = predictions.cuda()

        # initial input (start_id)
        state['pred_token'] = predictions[:, :1]
        if prev_input is not None:
            state['pred_mask'] = torch.not_equal(state['pred_token'],
                                                 self.pad_id).float()
            state['pred_pos'] = state['pred_pos'] + 1

        # shape: [batch_size, vocab_size]
        scores, state = step_fn(state)

        unk_penalty = np.zeros(self.vocab_size, dtype='float32')
        unk_penalty[self.unk_id] = -1e10
        unk_penalty = torch.from_numpy(unk_penalty)

        eos_penalty = np.zeros(self.vocab_size, dtype='float32')
        eos_penalty[eos_id] = -1e10
        eos_penalty = torch.from_numpy(eos_penalty)

        # After <eos> only <pad> should be generated, so <pad> gets the
        # highest possible log-probability (0) and everything else -1e10.
        scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
        scores_after_end[self.pad_id] = 0
        scores_after_end = torch.from_numpy(scores_after_end)

        if self.use_gpu:
            unk_penalty = unk_penalty.cuda()
            eos_penalty = eos_penalty.cuda()
            scores_after_end = scores_after_end.cuda()

        if self.ignore_unk:
            scores = scores + unk_penalty
        # The first generated token is never allowed to be <eos>.
        scores = scores + eos_penalty

        # shape: [batch_size, beam_size]
        sequence_scores, preds = torch.topk(scores, self.beam_size)

        predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
        state = repeat(state, beam_size)

        if max_gen_len is None:
            max_gen_len = self.max_gen_len
        for step in range(2, max_gen_len + 1):
            pre_ids = predictions[:, :, -1:]
            state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1)
            state['pred_mask'] = torch.not_equal(state['pred_token'],
                                                 self.pad_id).float()
            state['pred_pos'] = state['pred_pos'] + 1
            scores, state = step_fn(state)

            # Generate next
            # scores shape: [batch_size * beam_size, vocab_size]
            if self.ignore_unk:
                scores = scores + unk_penalty

            if step <= self.min_gen_len:
                scores = scores + eos_penalty

            # scores shape: [batch_size, beam_size, vocab_size]
            scores = scores.reshape(batch_size, beam_size, self.vocab_size)

            # previous token is [PAD] or [EOS]; shape [batch_size, beam_size, 1]
            finished = self._finished_mask(pre_ids, eos_id)
            # Select instead of multiplying by a 0/1 mask: a -inf score times
            # 0 would turn into NaN.
            scores = torch.where(
                finished,
                scores_after_end.to(scores.dtype).expand_as(scores), scores)

            # Float masks for the length normalisation; the 2-D variant keeps
            # ``sequence_scores`` at [batch_size, beam_size] so the addition
            # below broadcasts over the vocabulary axis only.
            finished_3d = finished.to(scores.dtype)
            finished_2d = finished_3d.squeeze(-1)

            if self.length_average:
                sequence_scores = sequence_scores * (
                    finished_2d + (1 - finished_2d) * (1 - 1 / step))
                scores = scores * (
                    finished_3d + (1 - finished_3d) * (1 / step))
            elif self.length_penalty >= 0.0:
                sequence_scores = sequence_scores * (
                    finished_2d + (1 - finished_2d) * math.pow(
                        (4 + step) / (5 + step), self.length_penalty))
                scores = scores * (
                    finished_3d + (1 - finished_3d) * math.pow(
                        1 / (5 + step), self.length_penalty))

            scores = scores + sequence_scores.unsqueeze(-1)
            scores = scores.reshape(batch_size, beam_size * self.vocab_size)

            topk_scores, topk_indices = torch.topk(scores, beam_size)
            # topk_indices index the flattened [beam_size * vocab_size] axis.
            # The quotient by vocab_size is the beam the previous token
            # belongs to ...
            parent_idx = topk_indices.floor_divide(self.vocab_size)
            # ... and the remainder is the newly generated token id.
            preds = topk_indices % self.vocab_size

            # Gather state / sequence_scores
            parent_idx = parent_idx + pos_index
            parent_idx = parent_idx.reshape(batch_size * beam_size)
            state = gather(state, parent_idx)
            sequence_scores = topk_scores

            predictions = predictions.reshape(batch_size * beam_size, step)
            predictions = gather(predictions, parent_idx)
            predictions = predictions.reshape(batch_size, beam_size, step)
            predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)

        # A complete sentence must end with <eos> or with <pad> (which can
        # only follow <eos>); unfinished beams are pushed to -1e10.
        pre_ids = predictions[:, :, -1]
        finished = self._finished_mask(pre_ids, eos_id)
        sequence_scores = torch.where(
            finished, sequence_scores,
            torch.full_like(sequence_scores, -1e10))

        # Ascending sort along the beam axis, applied to both predictions and
        # sequence_scores; the best beam ends up last.
        indices = torch.argsort(sequence_scores, dim=1)
        indices = indices + pos_index
        indices = indices.reshape(-1)
        sequence_scores = sequence_scores.reshape(batch_size * beam_size)
        predictions = predictions.reshape(batch_size * beam_size, -1)
        sequence_scores = gather(sequence_scores, indices)
        predictions = gather(predictions, indices)
        sequence_scores = sequence_scores.reshape(batch_size, beam_size)
        predictions = predictions.reshape(batch_size, beam_size, -1)

        results = {
            'preds': predictions[:, -1],
            'scores': sequence_scores[:, -1]
        }
        return results


BeamSearch.register('BeamSearch')
| @@ -0,0 +1,99 @@ | |||||
| """ | |||||
| Model base | |||||
| """ | |||||
| import os | |||||
| import torch.nn as nn | |||||
class ModelBase(nn.Module):
    """
    Common base for registered models.

    Keeps a name -> class registry, stores a few config-derived attributes
    and defines the ``forward`` / ``infer`` entry points in terms of hooks
    (``_forward``, ``_collect_metrics``, ``_infer``, ...) that subclasses
    have to implement.
    """

    # Registry shared by every subclass: registered name -> model class.
    _registry = dict()

    @classmethod
    def register(cls, name):
        """ Store ``cls`` in the shared registry under ``name``. """
        ModelBase._registry[name] = cls

    @staticmethod
    def by_name(name):
        """ Look up a registered model class (KeyError if unknown). """
        return ModelBase._registry[name]

    @staticmethod
    def create(model_dir, config, *args, **kwargs):
        """ Instantiate the class registered as ``config.Model.model``. """
        model_type = ModelBase.by_name(config.Model.model)
        return model_type(model_dir, config, *args, **kwargs)

    def __init__(self, model_dir, config):
        super().__init__()
        # Path of the checkpoint file inside ``model_dir``.
        self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin')
        # When truthy, ``forward`` always runs with ``with_label=False``.
        self.abandon_label = config.Dataset.abandon_label
        self.use_gpu = config.use_gpu
        self.gpu = config.Trainer.gpu

    def _create_parameters(self):
        """ Hook: create the model's parameters. """
        raise NotImplementedError

    def _forward(self, inputs, is_training, with_label):
        """ Hook: the actual forward computation (train/test). """
        raise NotImplementedError

    def _collect_metrics(self, inputs, outputs, with_label, data_file):
        """ Hook: compute losses/metrics from inputs and outputs. """
        raise NotImplementedError

    def _optimize(self, loss, optimizer, lr_scheduler):
        """ Hook: back-propagate the loss and update the model. """
        raise NotImplementedError

    def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input):
        """ Hook: the actual inference computation. """
        raise NotImplementedError

    def forward(self,
                inputs,
                is_training=False,
                with_label=False,
                data_file=None):
        """
        Switch to train/eval mode, run ``_forward`` and return the result of
        ``_collect_metrics``.

        @params : inputs : input data
        @type : dict of numpy.ndarray/int/float/...
        @params : is_training : selects ``train()`` (truthy) or ``eval()``
        @params : with_label : forced to False when ``abandon_label`` is set
        @params : data_file : passed through to ``_collect_metrics``
        """
        if not is_training:
            self.eval()
        else:
            self.train()

        if self.abandon_label:
            with_label = False

        outputs = self._forward(inputs, is_training, with_label=with_label)
        return self._collect_metrics(
            inputs, outputs, with_label=with_label, data_file=data_file)

    def infer(self,
              inputs,
              start_id=None,
              eos_id=None,
              max_gen_len=None,
              prev_input=None):
        """
        Switch to eval mode and return the result of ``_infer``.

        @params : inputs : input data
        @type : dict of numpy.ndarray/int/float/...
        """
        self.eval()
        return self._infer(
            inputs,
            start_id=start_id,
            eos_id=eos_id,
            max_gen_len=max_gen_len,
            prev_input=prev_input)
| @@ -0,0 +1,322 @@ | |||||
| """ | |||||
| UnifiedTransformer | |||||
| """ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from maas_lib.models.nlp.space.model.model_base import ModelBase | |||||
| from maas_lib.models.nlp.space.modules.embedder import Embedder | |||||
| from maas_lib.models.nlp.space.modules.transformer_block import \ | |||||
| TransformerBlock | |||||
| class UnifiedTransformer(ModelBase): | |||||
| """ | |||||
| Implement unified transformer. | |||||
| """ | |||||
def __init__(self, model_dir, config, reader, generator, dtype='float32'):
    """ Read the hyper-parameters from ``config`` and build the sub-modules.

    @param : model_dir : forwarded to ``ModelBase.__init__``
    @param : config : configuration object; the ``BPETextField``, ``Model``,
                      ``Trainer`` and ``Dataset`` sections are read below
    @param : reader : stored as ``self.reader``
    @param : generator : stored as ``self.generator``
    @param : dtype : numpy dtype name used for the auto-regressive mask
                     built in ``_create_parameters``
    """
    super(UnifiedTransformer, self).__init__(model_dir, config)
    self.reader = reader
    self.generator = generator
    # Task switches.
    self.policy = config.BPETextField.policy
    self.generation = config.BPETextField.generation
    # Embedding table sizes.
    self.num_token_embeddings = config.Model.num_token_embeddings
    self.num_pos_embeddings = config.Model.num_pos_embeddings
    self.num_type_embeddings = config.Model.num_type_embeddings
    self.num_turn_embeddings = config.Model.num_turn_embeddings
    self.temperature = config.Model.temperature
    # Transformer geometry.
    self.hidden_dim = config.Model.hidden_dim
    self.num_heads = config.Model.num_heads
    self.num_layers = config.Model.num_layers
    self.padding_idx = config.Model.padding_idx
    # Dropout rates.
    self.dropout = config.Model.dropout
    self.embed_dropout = config.Model.embed_dropout
    self.attn_dropout = config.Model.attn_dropout
    self.ff_dropout = config.Model.ff_dropout
    self.mlm_ratio = config.Model.mlm_ratio
    self.mmd_ratio = config.Model.mmd_ratio
    self.pos_trainable = config.Model.pos_trainable
    self.label_smooth = config.Model.label_smooth
    self.initializer_range = config.Model.initializer_range
    self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
    self.token_loss = config.Trainer.token_loss
    self.learning_method = config.Dataset.learning_method
    self.with_contrastive = config.Dataset.with_contrastive
    self.with_query_bow = config.BPETextField.with_query_bow
    self.with_resp_bow = config.BPETextField.with_resp_bow
    self.with_pool = config.Model.with_pool
    self.with_mlm = config.Dataset.with_mlm
    self._dtype = dtype

    # NOTE(review): sub-modules are created in a fixed order below; changing
    # it changes the parameter registration order and presumably the order in
    # which random initial weights are drawn -- keep as is.
    self.embedder = Embedder(
        self.hidden_dim,
        self.num_token_embeddings,
        self.num_pos_embeddings,
        self.num_type_embeddings,
        self.num_turn_embeddings,
        padding_idx=self.padding_idx,
        dropout=self.embed_dropout,
        pos_trainable=self.pos_trainable)
    self.embed_layer_norm = nn.LayerNorm(
        normalized_shape=self.hidden_dim,
        eps=1e-12,
        elementwise_affine=True)

    self.layers = nn.ModuleList([
        TransformerBlock(self.hidden_dim, self.num_heads, self.dropout,
                         self.attn_dropout, self.ff_dropout)
        for _ in range(config.Model.num_layers)
    ])

    # Masked-LM head parts, used by ``_mlm_head``.
    if self.with_mlm:
        self.mlm_transform = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(),
            nn.LayerNorm(
                normalized_shape=self.hidden_dim,
                eps=1e-12,
                elementwise_affine=True))
        self.mlm_bias = nn.Parameter(
            torch.zeros(self.num_token_embeddings))

    # Applied in ``_refactor_feature`` when ``with_pool`` is set.
    self.pooler = nn.Sequential(
        nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh())

    if self.with_query_bow or self.with_resp_bow:
        self.bow_predictor = nn.Linear(
            self.hidden_dim, self.num_token_embeddings, bias=False)

    self.sigmoid = nn.Sigmoid()
    self.softmax = nn.Softmax(dim=-1)
    self.bce_loss = nn.BCELoss(reduction='none')
    # Per-token NLL; positions whose label equals ``padding_idx`` are ignored.
    self.nll_loss = nn.NLLLoss(
        ignore_index=self.padding_idx, reduction='none')
    # Builds ``self.sequence_mask`` (see ``_create_parameters``).
    self._create_parameters()

    # ``grad_clip`` mirrors ``max_grad_norm`` (None when it is not set).
    self.max_grad_norm = config.Model.max_grad_norm
    if self.max_grad_norm is not None:
        self.grad_clip = self.max_grad_norm
    else:
        self.grad_clip = None
    self.weight_decay = config.Model.weight_decay

    if self.use_gpu:
        self.cuda()

    return
| def _create_parameters(self): | |||||
| """ Create model's paramters. """ | |||||
| sequence_mask = np.tri( | |||||
| self.num_pos_embeddings, | |||||
| self.num_pos_embeddings, | |||||
| dtype=self._dtype) | |||||
| self.sequence_mask = torch.tensor(sequence_mask) | |||||
| return | |||||
| def _create_mask(self, | |||||
| input_mask, | |||||
| append_head=False, | |||||
| auto_regressive=False): | |||||
| """ | |||||
| Create attention mask. | |||||
| 创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||||
| mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) | |||||
| 注: | |||||
| 1. 一个句子中的非<pad>词看整个句子,该句中只有<pad>词才被mask | |||||
| 2. 一个句子中的<pad>词看整个句子,该句的所有词都应该被mask | |||||
| @param : input_mask | |||||
| @type : Variable(shape: [batch_size, max_seq_len]) | |||||
| @param : auto_regressive | |||||
| @type : bool | |||||
| """ | |||||
| seq_len = input_mask.shape[1] | |||||
| input_mask = input_mask.float() | |||||
| mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) | |||||
| mask2 = mask1.permute(0, 2, 1) | |||||
| mask = mask1 * mask2 | |||||
| if append_head: | |||||
| # 拼接上句首位置([M]/z)的mask | |||||
| mask = torch.cat([mask[:, :1, :], mask], dim=1) | |||||
| mask = torch.cat([mask[:, :, :1], mask], dim=2) | |||||
| seq_len += 1 | |||||
| if auto_regressive: | |||||
| # 将tgt端的<pad> mask和自回归attention mask融合 | |||||
| seq_mask = self.sequence_mask[:seq_len, :seq_len] | |||||
| seq_mask = seq_mask.to(mask.device) | |||||
| mask = mask * seq_mask | |||||
| mask = 1 - mask | |||||
| return mask | |||||
def _join_mask(self, mask1, mask2):
    """
    Merge the source attention mask and the target attention mask.

    The joint mask is assembled from four quadrants:
    left-upper (lu), right-upper (ru), left-bottom (lb), right-bottom (rb).
    Entries equal to 1 are combined as "blocked" (the lb quadrant is the
    element-wise OR of its two inputs, the ru quadrant is all ones).

    @param : mask1 : source attention mask
    @type : Variable(shape: [batch_size, max_src_len, max_src_len])

    @param : mask2 : target attention mask
    @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len])

    @return : joint mask
    @type : Variable(shape: [batch_size, max_src_len + max_tgt_len,
                             max_src_len + max_tgt_len])
    """
    batch_size = mask1.shape[0]
    seq_len1 = mask1.shape[1]
    seq_len2 = mask2.shape[1]

    mask_lu = mask1
    # Source rows are fully blocked from target columns. Allocate directly
    # with mask1's dtype and device so the concatenation below never mixes
    # devices, whatever `self.use_gpu` says.
    mask_ru = mask1.new_ones(batch_size, seq_len1, seq_len2)
    # First column of the target mask, broadcast over all source columns.
    mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1)
    # First row of the source mask, broadcast over all target rows.
    mask4 = mask1[:, :1].repeat(1, seq_len2, 1)
    # a + b - a * b is the element-wise OR of two 0/1 masks.
    mask_lb = mask3 + mask4 - mask3 * mask4
    mask_rb = mask2

    mask_u = torch.cat([mask_lu, mask_ru], dim=2)
    mask_b = torch.cat([mask_lb, mask_rb], dim=2)
    mask = torch.cat([mask_u, mask_b], dim=1)
    return mask
def _mlm_head(self, mlm_embed):
    """
    Turn hidden states into token probabilities for the MLM objective.

    The input goes through `self.mlm_transform`, is projected onto the
    transposed token-embedding matrix, shifted by `self.mlm_bias`, and then
    passed through `self.softmax`.
    """
    hidden = self.mlm_transform(mlm_embed)
    vocab_proj = self.embedder.token_embedding.weight.T
    logits = torch.matmul(hidden, vocab_proj) + self.mlm_bias
    return self.softmax(logits)
def _dec_head(self, dec_embed):
    """
    Turn decoder hidden states into token probabilities.

    The hidden states are projected onto the transposed token-embedding
    matrix (no bias term) and passed through `self.softmax`.
    """
    vocab_proj = self.embedder.token_embedding.weight.T
    logits = torch.matmul(dec_embed, vocab_proj)
    return self.softmax(logits)
def _refactor_feature(self, features):
    """
    Pair up the two halves of a batch and L2-normalise the result.

    The input is optionally passed through `self.pooler` (when
    `self.with_pool` is truthy), split along dim 0 into a front and a back
    half, stacked on a new dim 1, and normalised along the last dim.
    """
    if self.with_pool:
        features = self.pooler(features)
    half = features.size(0) // 2
    front, back = features[:half], features[half:]
    # Stacking on dim 1 equals concatenating the two halves after unsqueeze(1).
    paired = torch.stack([front, back], dim=1)
    return F.normalize(paired, p=2, dim=-1)
def _encoder_network(self,
                     input_token,
                     input_mask,
                     input_pos=None,
                     input_type=None,
                     input_turn=None):
    """
    Run the layer stack over a single input sequence.

    The tokens are embedded and layer-normalised, a mask is built from
    `input_mask` with `auto_regressive=False`, and every layer in
    `self.layers` is applied in order without a cache.
    """
    hidden = self.embed_layer_norm(
        self.embedder(input_token, input_pos, input_type, input_turn))
    attn_mask = self._create_mask(input_mask, auto_regressive=False)
    for layer in self.layers:
        hidden = layer(hidden, attn_mask, None)
    return hidden
def _encoder_decoder_network(self,
                             src_token,
                             src_mask,
                             tgt_token,
                             tgt_mask,
                             src_pos=None,
                             src_type=None,
                             src_turn=None,
                             tgt_pos=None,
                             tgt_type=None,
                             tgt_turn=None):
    """
    Run the layer stack over the concatenation of source and target.

    Source and target are embedded separately, concatenated along the
    sequence dim and layer-normalised. The source mask is built with
    `auto_regressive=False`, the target mask with `auto_regressive=True`,
    and the two are merged with `self._join_mask`.

    @return : (enc_embed, dec_embed) -- the source part and the target part
              of the final hidden states, split along dim 1.
    """
    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
    tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
    embed = torch.cat([src_embed, tgt_embed], dim=1)
    embed = self.embed_layer_norm(embed)

    enc_mask = self._create_mask(src_mask, auto_regressive=False)
    dec_mask = self._create_mask(tgt_mask, auto_regressive=True)
    mask = self._join_mask(enc_mask, dec_mask)

    for layer in self.layers:
        embed = layer(embed, mask, None)

    # Split at the source length instead of using `-tgt_len`: with an empty
    # target, `[:-0]` would be empty and `[-0:]` would be the whole sequence.
    src_len = src_token.shape[1]
    enc_embed = embed[:, :src_len]
    dec_embed = embed[:, src_len:]
    return enc_embed, dec_embed
def _encoder_prompt_decoder_network(self,
                                    src_token,
                                    src_mask,
                                    tgt_token,
                                    tgt_mask,
                                    prompt_token,
                                    prompt_mask,
                                    src_pos=None,
                                    src_type=None,
                                    src_turn=None,
                                    tgt_pos=None,
                                    tgt_type=None,
                                    tgt_turn=None,
                                    prompt_pos=None,
                                    prompt_type=None,
                                    prompt_turn=None):
    """
    Run the layer stack over source + prompt + target.

    The three segments are embedded separately and concatenated in the
    order source, prompt, target. The source mask is built with
    `auto_regressive=False`. Prompt and target masks are concatenated and
    built as one mask with `auto_regressive=True`. The two are then merged
    with `self._join_mask`.

    @return : (enc_embed, dec_embed, prompt_embed) -- the source, target
              and prompt parts of the final hidden states (note the order).
    """
    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
    tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
    prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
                                 prompt_turn)

    embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1)
    embed = self.embed_layer_norm(embed)

    enc_mask = self._create_mask(src_mask, auto_regressive=False)
    dec_mask = self._create_mask(
        torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True)
    mask = self._join_mask(enc_mask, dec_mask)

    for layer in self.layers:
        embed = layer(embed, mask, None)

    # Use explicit forward offsets instead of `-tgt_len`: with an empty
    # target, `[-0:]` would be the whole sequence and `[src_len:-0]` empty.
    src_len = src_token.shape[1]
    prompt_end = src_len + prompt_token.shape[1]
    enc_embed = embed[:, :src_len]
    prompt_embed = embed[:, src_len:prompt_end]
    dec_embed = embed[:, prompt_end:]
    return enc_embed, dec_embed, prompt_embed
def _optimize(self, loss, optimizer=None, lr_scheduler=None):
    """
    Run one optimisation step for `loss`.

    Gradients are zeroed, back-propagated, optionally clipped to
    `self.grad_clip` (when it is set and positive), and applied with
    `optimizer.step()`. The learning-rate scheduler, if given, is stepped
    afterwards. An optimizer is required.
    """
    assert optimizer is not None
    optimizer.zero_grad()
    loss.backward()

    max_norm = self.grad_clip
    if max_norm is not None and max_norm > 0:
        torch.nn.utils.clip_grad_norm_(
            parameters=self.parameters(), max_norm=max_norm)

    optimizer.step()
    if lr_scheduler is None:
        return
    lr_scheduler.step()
def _infer(self,
           inputs,
           start_id=None,
           eos_id=None,
           max_gen_len=None,
           prev_input=None):
    """
    Real inference process of model.

    This implementation is a placeholder: it ignores all arguments and
    returns a new, empty result dict.
    """
    return dict()
# Register the class under the name 'UnifiedTransformer' via its `register`
# hook (presumably so it can be looked up by name when a model is created
# -- TODO confirm against ModelBase).
| UnifiedTransformer.register('UnifiedTransformer') | |||||
| @@ -0,0 +1,67 @@ | |||||
| """ | |||||
| Embedder class. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
class Embedder(nn.Module):
    """
    Composite embedding layer.

    Sums the token embedding with the optional position, type and turn
    embeddings and applies dropout to the result.
    """

    def __init__(self,
                 hidden_dim,
                 num_token_embeddings,
                 num_pos_embeddings,
                 num_type_embeddings,
                 num_turn_embeddings,
                 padding_idx=None,
                 dropout=0.1,
                 pos_trainable=False):
        """
        Build the four embedding tables and the dropout layer.

        NOTE: `padding_idx` is accepted but is not used by this layer.
        `pos_trainable` controls whether the position table gets gradients.
        """
        super().__init__()

        self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim)
        self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim)
        self.pos_embedding.weight.requires_grad = pos_trainable
        self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim)
        self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim)
        self.dropout_layer = nn.Dropout(p=dropout)

        # Follow the default xavier_uniform initializer of the Paddle version.
        # The default normal initializer of nn.Embedding in PyTorch samples
        # larger values, which leads to bugs in the dec_probs computation
        # when the output projection is tied to the embedding weights.
        tables = (self.token_embedding, self.pos_embedding,
                  self.type_embedding, self.turn_embedding)
        for table in tables:
            nn.init.xavier_uniform_(table.weight)

    def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None):
        """Embed tokens, add every optional embedding given, apply dropout."""
        embed = self.token_embedding(token_inp)
        optional_terms = (
            (self.pos_embedding, pos_inp),
            (self.type_embedding, type_inp),
            (self.turn_embedding, turn_inp),
        )
        for table, indices in optional_terms:
            if indices is not None:
                embed += table(indices)
        return self.dropout_layer(embed)
def main():
    """Smoke-run the Embedder on random ids and print the output."""
    import numpy as np

    def random_ids():
        # 10x10 int64 ids drawn from [0, 19)
        return torch.tensor(
            np.random.randint(0, 19, [10, 10]).astype('int64'))

    model = Embedder(10, 20, 20, 20, 20)
    token_inp = random_ids()
    pos_inp = random_ids()
    type_inp = random_ids()
    turn_inp = random_ids()
    print(model(token_inp, pos_inp, type_inp, turn_inp))


if __name__ == '__main__':
    main()
| @@ -0,0 +1,43 @@ | |||||
| """ | |||||
| FeedForward class. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
class FeedForward(nn.Module):
    """
    Positional feed forward layer.

    Linear (hidden_dim -> inner_dim) followed by GELU, then dropout, then
    Linear (inner_dim -> hidden_dim).
    """

    def __init__(self, hidden_dim, inner_dim, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.inner_dim = inner_dim

        expand = nn.Linear(hidden_dim, inner_dim)
        self.linear_hidden = nn.Sequential(expand, nn.GELU())
        self.linear_out = nn.Linear(inner_dim, hidden_dim)
        self.dropout_layer = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply expand + GELU, dropout, and the output projection."""
        activated = self.dropout_layer(self.linear_hidden(x))
        return self.linear_out(activated)
def main():
    """Smoke-run the FeedForward layer on random input and print the output."""
    import numpy as np

    model = FeedForward(10, 20, 0.5)
    sample = torch.tensor(np.random.rand(2, 3, 10).astype('float32'))
    print(model(sample))


if __name__ == '__main__':
    main()
| @@ -0,0 +1,64 @@ | |||||
| """ | |||||
| Helpful functions. | |||||
| """ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
def unsqueeze(input, dims):
    """
    Implement multi-dimension unsqueeze function.

    `dims` may be a single int (delegates to `Tensor.unsqueeze`) or a
    list/tuple of ints. Negative entries are resolved against the input
    rank plus one, and size-1 axes are inserted from the highest position
    to the lowest. Any other type raises ValueError.
    """
    if isinstance(dims, int):
        return input.unsqueeze(dims)
    if not isinstance(dims, (list, tuple)):
        raise ValueError('Warning: type(dims) must in (list, tuple, int)!')

    rank = len(input.shape)
    positions = sorted(
        (d if d >= 0 else d + rank + 1 for d in dims), reverse=True)
    new_shape = list(input.shape)
    for pos in positions:
        new_shape.insert(pos, 1)
    return torch.reshape(input, new_shape)
def gumbel_softmax(input, tau=1, eps=1e-10, dim=-1):
    """
    Basic implement of gumbel_softmax.

    Adds Gumbel noise to `input`, divides by the temperature `tau` and
    applies a softmax.

    @param : input : unnormalised scores (non-floating tensors are cast to
             the default float dtype first)
    @param : tau : temperature the noisy scores are divided by
    @param : eps : small constant guarding both logarithms
    @param : dim : axis the softmax is taken over (default: last axis)
    @return : tensor with the shape, floating dtype and device of `input`
    """
    if not torch.is_floating_point(input):
        input = input.to(torch.get_default_dtype())
    # Sample with the input's dtype and device. The earlier numpy-based
    # float64 CPU sample upcast the result and failed for GPU inputs.
    uniform = torch.rand_like(input)
    gumbel = 0.0 - torch.log(eps - torch.log(uniform + eps))
    noisy = input + gumbel
    # An explicit dim avoids the deprecated implicit-axis choice of
    # F.softmax, which picks dim 0 for 3-D inputs.
    return F.softmax(noisy / tau, dim=dim)
def equal(x, y, dtype=None):
    """
    Implement equal in dygraph mode. (paddle)

    Compares `x` and `y` element-wise (numpy broadcasting rules) and returns
    the result as a CPU tensor.

    @param : x, y : tensors, arrays or scalars to compare
    @param : dtype : numpy dtype of the result (default 'float32')
    @return : tensor holding 1 where the operands are equal and 0 elsewhere
    """
    if dtype is None:
        dtype = 'float32'

    def to_numpy(value):
        # detach() + cpu() make the conversion work for tensors that
        # require grad or live on another device. Plain `.numpy()` raises
        # for both.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    out = np.equal(to_numpy(x), to_numpy(y)).astype(dtype)
    return torch.tensor(out)
def not_equal(x, y, dtype=None):
    """
    Implement not_equal in dygraph mode. (paddle)

    Returns one minus the result of `equal(x, y, dtype)`.
    """
    matches = equal(x, y, dtype)
    return 1 - matches
if __name__ == '__main__':
    # Ad-hoc smoke checks for the helpers in this module.
    a = torch.tensor([[1, 1], [3, 4]])
    b = torch.tensor([[1, 1], [3, 4]])
    # Compare two distinct tensors with identical content.
    c = torch.equal(a, b)
    c1 = equal(a, 3)
    d = 1 - torch.not_equal(a, 3).float()
    print(c)
    print(c1)
    print(d)
    # F.gumbel_softmax samples noise in the dtype of its input, so it needs
    # floating-point logits. An integer tensor raises.
    e = F.gumbel_softmax(a.float())
    # Tensor.unsqueeze takes an integer dim, not a tensor.
    f = a.unsqueeze(0)
    g = unsqueeze(a, dims=[0, 0, 1])
    print(g, g.shape)
| @@ -0,0 +1,109 @@ | |||||
| """ | |||||
| MultiheadAttention class. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
class MultiheadAttention(nn.Module):
    """
    Multi head attention layer.

    Self-attention with a fused query/key/value projection, an optional
    mask (non-zero entries are blocked) and an optional key/value cache.
    """

    def __init__(self, hidden_dim, num_heads, dropout):
        assert hidden_dim % num_heads == 0
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x, is_key=False):
        """
        Reshape [batch, seq, hidden] into per-head form.

        Keys come back as [batch, heads, head_dim, seq] (already transposed
        for the score matmul). Everything else comes back as
        [batch, heads, seq, head_dim].
        """
        batch, length = x.size(0), x.size(1)
        x = x.reshape(batch, length, self.num_heads, self.head_dim)
        if is_key:
            return x.permute(0, 2, 3, 1)
        return x.permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """Inverse of `_split_heads` for non-key tensors."""
        x = x.permute(0, 2, 1, 3)
        return x.reshape(x.size(0), x.size(1), self.hidden_dim)

    def _attn(self, query, key, value, mask):
        """Scaled dot-product attention with optional masking."""
        # shape: [batch_size, num_head, seq_len, seq_len]
        scores = torch.matmul(query, key) * self.scale

        blocked = None
        if mask is not None:
            blocked = mask.unsqueeze(1).repeat(1, self.num_heads, 1,
                                               1).bool()
            # scores = (1 - mask) * scores + mask * (-inf)
            scores.masked_fill_(blocked, float('-inf'))

        attn = self.dropout_layer(self.softmax(scores))

        if blocked is not None:
            # blocked: [batch_size, num_heads, seq_len, seq_len]
            # Some rows of the last two dims can be entirely blocked; these
            # are the rows belonging to <pad> positions. The softmax over
            # such a row does not yield zeros, so the blocked entries are
            # reset to 0 here, i.e. attn = (1 - mask) * attn.
            attn.masked_fill_(blocked, 0.)

        return torch.matmul(attn, value)

    def forward(self, inp, mask=None, cache=None):
        """
        Forward process of self attention.

        When `cache` is given, cached 'key'/'value' entries are prepended
        along the sequence axis and the cache is updated in place with the
        extended tensors.
        """
        # shape: [batch_size, seq_len, 3 * hidden_dim]
        projected = self.linear_qkv(inp)
        query, key, value = torch.split(projected, self.hidden_dim, dim=2)

        # shape: [batch_size, num_head, seq_len, head_dim]
        query = self._split_heads(query)
        # shape: [batch_size, num_head, head_dim, seq_len]
        key = self._split_heads(key, is_key=True)
        # shape: [batch_size, num_head, seq_len, head_dim]
        value = self._split_heads(value)

        if cache is not None:
            if 'key' in cache and 'value' in cache:
                key = torch.cat([cache['key'], key], dim=3)
                value = torch.cat([cache['value'], value], dim=2)
            cache['key'] = key
            cache['value'] = value

        merged = self._merge_heads(self._attn(query, key, value, mask))
        return self.linear_out(merged)
def main():
    """Smoke-run MultiheadAttention with a random mask and print the output."""
    import numpy as np

    model = MultiheadAttention(10, 2, 0.5)
    sample = torch.tensor(np.random.rand(2, 3, 10).astype('float32'))
    random_mask = torch.tensor(
        (np.random.rand(2, 3, 3) > 0.5).astype('float32'))
    print(model(sample, mask=random_mask, cache=None))


if __name__ == '__main__':
    main()
| @@ -0,0 +1,73 @@ | |||||
| """ | |||||
| TransformerBlock class. | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from maas_lib.models.nlp.space.modules.feedforward import FeedForward | |||||
| from maas_lib.models.nlp.space.modules.multihead_attention import \ | |||||
| MultiheadAttention | |||||
class TransformerBlock(nn.Module):
    """
    Transformer block module.

    Self-attention and a feed-forward layer. Each is followed by dropout,
    a residual connection and layer normalisation.
    """

    def __init__(self, hidden_dim, num_heads, dropout, attn_dropout,
                 ff_dropout):
        super().__init__()

        def make_norm():
            return nn.LayerNorm(
                normalized_shape=hidden_dim,
                eps=1e-12,
                elementwise_affine=True)

        self.attn = MultiheadAttention(
            hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout)
        self.attn_norm = make_norm()
        self.ff = FeedForward(
            hidden_dim=hidden_dim,
            inner_dim=4 * hidden_dim,
            dropout=ff_dropout)
        self.ff_norm = make_norm()
        self.dropout_layer = nn.Dropout(p=dropout)

    def forward(self, inp, mask=None, cache=None):
        """
        Forward process on one transformer layer.

        @param : inp
        @type : Variable(shape: [batch_size, seq_len, hidden_size])

        @param : mask : forwarded unchanged to the attention layer

        @param : cache : forwarded unchanged to the attention layer
        """
        attended = self.dropout_layer(self.attn(inp, mask, cache))
        attended = self.attn_norm(attended + inp)

        transformed = self.dropout_layer(self.ff(attended))
        return self.ff_norm(transformed + attended)
def main():
    """Smoke-run TransformerBlock with a random mask and print the output."""
    import numpy as np

    model = TransformerBlock(10, 2, 0.5, 0.5, 0.5)
    sample = torch.tensor(np.random.rand(2, 3, 10).astype('float32'))
    random_mask = torch.tensor(
        (np.random.rand(2, 3, 3) > 0.5).astype('float32'))
    print(model(sample, mask=random_mask, cache=None))


if __name__ == '__main__':
    main()
| @@ -4,5 +4,5 @@ from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS, build_preprocessor | from .builder import PREPROCESSORS, build_preprocessor | ||||
| from .common import Compose | from .common import Compose | ||||
| from .image import LoadImage, load_image | from .image import LoadImage, load_image | ||||
| from .nlp import * # noqa F403 | |||||
| from .space.dialog_generation_preprcessor import * # noqa F403 | |||||
| from .nlp.nlp import * # noqa F403 | |||||
| from .nlp.space.dialog_generation_preprcessor import * # noqa F403 | |||||
| @@ -7,8 +7,8 @@ from transformers import AutoTokenizer | |||||
| from maas_lib.utils.constant import Fields, InputFields | from maas_lib.utils.constant import Fields, InputFields | ||||
| from maas_lib.utils.type_assert import type_assert | from maas_lib.utils.type_assert import type_assert | ||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | |||||
| from ..base import Preprocessor | |||||
| from ..builder import PREPROCESSORS | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Tokenize', | 'Tokenize', | ||||
| @@ -5,10 +5,11 @@ import uuid | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField | from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField | ||||
| from maas_lib.utils.config import Config | |||||
| from maas_lib.utils.constant import Fields, InputFields | from maas_lib.utils.constant import Fields, InputFields | ||||
| from maas_lib.utils.type_assert import type_assert | from maas_lib.utils.type_assert import type_assert | ||||
| from ..base import Preprocessor | |||||
| from ..builder import PREPROCESSORS | |||||
| from ...base import Preprocessor | |||||
| from ...builder import PREPROCESSORS | |||||
| __all__ = ['DialogGenerationPreprocessor'] | __all__ = ['DialogGenerationPreprocessor'] | ||||
| @@ -25,10 +26,10 @@ class DialogGenerationPreprocessor(Preprocessor): | |||||
| super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
| self.model_dir: str = model_dir | self.model_dir: str = model_dir | ||||
| self.text_field = MultiWOZBPETextField(model_dir=self.model_dir) | |||||
| pass | |||||
| self.config = Config.from_file( | |||||
| os.path.join(self.model_dir, 'configuration.json')) | |||||
| self.text_field = MultiWOZBPETextField( | |||||
| self.model_dir, config=self.config) | |||||
| @type_assert(object, str) | @type_assert(object, str) | ||||
| def __call__(self, data: str) -> Dict[str, Any]: | def __call__(self, data: str) -> Dict[str, Any]: | ||||
| @@ -4,37 +4,11 @@ import os.path as osp | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from maas_lib.fileio import File | |||||
| from tests.case.nlp.dialog_generation_case import test_case | |||||
| from maas_lib.models.nlp import DialogGenerationModel | from maas_lib.models.nlp import DialogGenerationModel | ||||
| from maas_lib.pipelines import DialogGenerationPipeline, pipeline | from maas_lib.pipelines import DialogGenerationPipeline, pipeline | ||||
| from maas_lib.preprocessors import DialogGenerationPreprocessor | from maas_lib.preprocessors import DialogGenerationPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | |||||
| dialog_case = [{ | |||||
| 'user': | |||||
| 'am looking for a place to to stay that has cheap price range it should be in a type of hotel', | |||||
| 'sys': | |||||
| 'okay , do you have a specific area you want to stay in ?' | |||||
| }, { | |||||
| 'user': | |||||
| 'no , i just need to make sure it is cheap . oh , and i need parking', | |||||
| 'sys': | |||||
| 'i found 1 cheap hotel for you that include -s parking . do you like me to book it ?' | |||||
| }, { | |||||
| 'user': | |||||
| 'yes , please . 6 people 3 nights starting on tuesday .', | |||||
| 'sys': | |||||
| "i am sorry but i was n't able to book that for you for tuesday . is there another day you would like " | |||||
| 'to stay or perhaps a shorter stay ? ' | |||||
| }, { | |||||
| 'user': | |||||
| 'how about only 2 nights .', | |||||
| 'sys': | |||||
| 'booking was successful . reference number is : 7gawk763 . anything else i can do for you ?', | |||||
| }, { | |||||
| 'user': 'no , that will be all . goodbye .', | |||||
| 'sys': 'thank you for using our services .' | |||||
| }] | |||||
| def merge(info, result): | def merge(info, result): | ||||
| @@ -47,21 +21,23 @@ class DialogGenerationTest(unittest.TestCase): | |||||
| modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | ||||
| preprocessor = DialogGenerationPreprocessor() | |||||
| preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) | |||||
| model = DialogGenerationModel( | model = DialogGenerationModel( | ||||
| model_dir=modeldir, preprocessor.tokenizer) | |||||
| pipeline = DialogGenerationPipeline(model, preprocessor) | |||||
| model_dir=modeldir, | |||||
| text_field=preprocessor.text_field, | |||||
| config=preprocessor.config) | |||||
| # pipeline = DialogGenerationPipeline(model, preprocessor) | |||||
| history_dialog = {} | history_dialog = {} | ||||
| for step in range(0, len(dialog_case)): | |||||
| user_question = dialog_case[step]['user'] | |||||
| for step, item in enumerate(test_case['sng0073']['log']): | |||||
| user_question = item['user'] | |||||
| print('user: {}'.format(user_question)) | print('user: {}'.format(user_question)) | ||||
| history_dialog_info = merge(history_dialog_info, | |||||
| result) if step > 0 else {} | |||||
| result = pipeline(user_question, history=history_dialog_info) | |||||
| print('sys : {}'.format(result['pred_answer'])) | |||||
| # history_dialog_info = merge(history_dialog_info, | |||||
| # result) if step > 0 else {} | |||||
| # result = pipeline(user_question, history=history_dialog_info) | |||||
| # | |||||
| # print('sys : {}'.format(result['pred_answer'])) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||