From a1600a65a0e1428a1b79138dc4b205185cd54688 Mon Sep 17 00:00:00 2001
From: ly119399
Date: Wed, 8 Jun 2022 15:17:21 +0800
Subject: [PATCH] add model

---
 .../nlp/space/dialog_generation_model.py      |  11 +-
 .../nlp/space/model}/__init__.py              |   0
 .../space/model/gen_unified_transformer.py    | 285 ++++++++++++++++
 maas_lib/models/nlp/space/model/generator.py  | 296 ++++++++++++++++
 maas_lib/models/nlp/space/model/model_base.py |  99 ++++++
 .../nlp/space/model/unified_transformer.py    | 322 ++++++++++++++++++
 maas_lib/models/nlp/space/modules/__init__.py |   0
 maas_lib/models/nlp/space/modules/embedder.py |  67 ++++
 .../models/nlp/space/modules/feedforward.py   |  43 +++
 .../models/nlp/space/modules/functions.py     |  64 ++++
 .../nlp/space/modules/multihead_attention.py  | 109 ++++++
 .../nlp/space/modules/transformer_block.py    |  73 ++++
 maas_lib/preprocessors/__init__.py            |   4 +-
 maas_lib/preprocessors/nlp/__init__.py        |   0
 maas_lib/preprocessors/{ => nlp}/nlp.py       |   4 +-
 maas_lib/preprocessors/nlp/space/__init__.py  |   0
 .../space/dialog_generation_preprcessor.py    |  13 +-
 tests/pipelines/nlp/test_dialog_generation.py |  52 +--
 18 files changed, 1393 insertions(+), 49 deletions(-)
 rename maas_lib/{preprocessors/space => models/nlp/space/model}/__init__.py (100%)
 create mode 100644 maas_lib/models/nlp/space/model/gen_unified_transformer.py
 create mode 100644 maas_lib/models/nlp/space/model/generator.py
 create mode 100644 maas_lib/models/nlp/space/model/model_base.py
 create mode 100644 maas_lib/models/nlp/space/model/unified_transformer.py
 create mode 100644 maas_lib/models/nlp/space/modules/__init__.py
 create mode 100644 maas_lib/models/nlp/space/modules/embedder.py
 create mode 100644 maas_lib/models/nlp/space/modules/feedforward.py
 create mode 100644 maas_lib/models/nlp/space/modules/functions.py
 create mode 100644 maas_lib/models/nlp/space/modules/multihead_attention.py
 create mode 100644 maas_lib/models/nlp/space/modules/transformer_block.py
 create mode 100644 maas_lib/preprocessors/nlp/__init__.py
 rename maas_lib/preprocessors/{ => nlp}/nlp.py (97%)
 create mode 100644 maas_lib/preprocessors/nlp/space/__init__.py
 rename maas_lib/preprocessors/{ => nlp}/space/dialog_generation_preprcessor.py (78%)

diff --git a/maas_lib/models/nlp/space/dialog_generation_model.py b/maas_lib/models/nlp/space/dialog_generation_model.py
index 72a99705..a5d286a4 100644
--- a/maas_lib/models/nlp/space/dialog_generation_model.py
+++ b/maas_lib/models/nlp/space/dialog_generation_model.py
@@ -3,6 +3,8 @@ from typing import Any, Dict, Optional
 from maas_lib.utils.constant import Tasks
 from ...base import Model, Tensor
 from ...builder import MODELS
+from .model.generator import Generator
+from .model.model_base import ModelBase
 
 __all__ = ['DialogGenerationModel']
 
@@ -21,7 +23,14 @@ class DialogGenerationModel(Model):
 
         super().__init__(model_dir, *args, **kwargs)
         self.model_dir = model_dir
-        pass
+        self.text_field = kwargs.pop('text_field')
+        self.config = kwargs.pop('config')
+        self.generator = Generator.create(self.config, reader=self.text_field)
+        self.model = ModelBase.create(
+            model_dir=model_dir,
+            config=self.config,
+            reader=self.text_field,
+            generator=self.generator)
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
diff --git a/maas_lib/preprocessors/space/__init__.py b/maas_lib/models/nlp/space/model/__init__.py
similarity index 100%
rename from maas_lib/preprocessors/space/__init__.py
rename to maas_lib/models/nlp/space/model/__init__.py
diff --git a/maas_lib/models/nlp/space/model/gen_unified_transformer.py b/maas_lib/models/nlp/space/model/gen_unified_transformer.py
new file mode 100644
index 00000000..2ea68bd1
--- /dev/null
+++ b/maas_lib/models/nlp/space/model/gen_unified_transformer.py
@@ -0,0 +1,285 @@
+"""
+GenUnifiedTransformer
+"""
+import torch
+
+from maas_lib.models.nlp.space.model.unified_transformer import \
+    UnifiedTransformer
+
+
+class GenUnifiedTransformer(UnifiedTransformer):
+    """
+    Implement generation unified transformer.
+    """
+
+    def __init__(self, model_dir, config, reader, generator):
+        super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
+                                                    generator)
+        self.understand = config.BPETextField.understand
+
+        if self.use_gpu:
+            self.cuda()
+        return
+
+    def _forward(self, inputs, is_training, with_label):
+        """ Real forward process of the model in different modes (train/test). """
+
+        def cat(x, y, dim=1):
+            return torch.cat([x, y], dim=dim)
+
+        outputs = {}
+
+        if self.understand or self.policy:
+            if self.understand:
+                prompt_token = inputs['understand_token']
+                prompt_mask = inputs['understand_mask']
+                if self.policy:
+                    prompt_token = cat(prompt_token, inputs['policy_token'])
+                    prompt_mask = cat(prompt_mask, inputs['policy_mask'])
+            else:
+                prompt_token = inputs['policy_token']
+                prompt_mask = inputs['policy_mask']
+
+            enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                tgt_token=inputs['tgt_token'][:, :-1],
+                tgt_mask=inputs['tgt_mask'][:, :-1],
+                prompt_token=prompt_token,
+                prompt_mask=prompt_mask,
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'],
+                tgt_pos=inputs['tgt_pos'][:, :-1],
+                tgt_type=inputs['tgt_type'][:, :-1],
+                tgt_turn=inputs['tgt_turn'][:, :-1])
+        else:
+            enc_embed, dec_embed = self._encoder_decoder_network(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                tgt_token=inputs['tgt_token'][:, :-1],
+                tgt_mask=inputs['tgt_mask'][:, :-1],
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'],
+                tgt_pos=inputs['tgt_pos'][:, :-1],
+                tgt_type=inputs['tgt_type'][:, :-1],
+                tgt_turn=inputs['tgt_turn'][:, :-1])
+
+        outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed)
+        return outputs
+
+    def _collect_metrics(self, inputs, outputs, with_label, data_file):
+
+        metrics = {}
+        loss = 0.
+
+        label = inputs['tgt_token'][:, 1:]
+        token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1)
+        nll = self.nll_loss(
+            torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label)
+        nll = torch.sum(nll, dim=1)
+        token_nll = torch.sum(nll) / token_num
+        nll = torch.mean(nll)
+        metrics['nll'] = nll
+        metrics['token_nll'] = token_nll
+        metrics['token_num'] = token_num
+        loss = loss + (token_nll if self.token_loss else nll)
+
+        metrics['loss'] = loss
+        if self.gpu > 1:
+            return nll, token_nll, token_num
+        else:
+            return metrics
+
+    def _optimize(self, loss, do_update=False, optimizer=None):
+        """ Optimize loss function and update model.
+        """
+        assert optimizer is not None
+
+        if self.gradient_accumulation_steps > 1:
+            loss = loss / self.gradient_accumulation_steps
+
+        loss.backward()
+
+        if self.grad_clip is not None and self.grad_clip > 0:
+            torch.nn.utils.clip_grad_norm_(
+                parameters=self.parameters(), max_norm=self.grad_clip)
+
+        if do_update:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        return
+
+    def _init_state(self,
+                    src_token,
+                    src_mask,
+                    src_pos=None,
+                    src_type=None,
+                    src_turn=None):
+        """ Initialize decode state. """
+        state = {}
+        batch_size = src_token.shape[0]
+
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        src_embed = self.embed_layer_norm(src_embed)
+
+        mask = self._create_mask(src_mask, append_head=False)
+
+        enc_out = src_embed
+
+        cache = {}
+        for l, layer in enumerate(self.layers):
+            cache[f'layer_{l}'] = {}
+            enc_out = layer(enc_out, mask, cache[f'layer_{l}'])
+
+        state['cache'] = cache
+        state['mask'] = mask[:, :1]
+        state['batch_size'] = batch_size
+        shape = [batch_size, 1, 1]
+        state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
+        state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
+        if self.use_gpu:
+            state['pred_mask'] = state['pred_mask'].cuda()
+            state['pred_pos'] = state['pred_pos'].cuda()
+            state['pred_type'] = state['pred_type'].cuda()
+            state['pred_turn'] = state['pred_turn'].cuda()
+
+        return state
+
+    def _init_prompt_state(self,
+                           src_token,
+                           src_mask,
+                           prompt_token,
+                           prompt_mask,
+                           src_pos=None,
+                           src_type=None,
+                           src_turn=None,
+                           prompt_pos=None,
+                           prompt_type=None,
+                           prompt_turn=None):
+        """ Initialize decode state. """
+        state = {}
+        batch_size = src_token.shape[0]
+
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
+                                     prompt_turn)
+        embed = torch.cat([src_embed, prompt_embed], dim=1)
+        embed = self.embed_layer_norm(embed)
+        enc_out = embed
+
+        enc_mask = self._create_mask(src_mask, auto_regressive=False)
+        dec_mask = self._create_mask(prompt_mask, auto_regressive=True)
+        mask = self._join_mask(enc_mask, dec_mask)
+
+        cache = {}
+        for l, layer in enumerate(self.layers):
+            cache[f'layer_{l}'] = {}
+            enc_out = layer(enc_out, mask, cache[f'layer_{l}'])
+
+        state['cache'] = cache
+        state['mask'] = mask[:, -1:]  # state["mask"] = mask[:, :1]
+        state['batch_size'] = batch_size
+        shape = [batch_size, 1, 1]
+        state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
+        state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
+        if self.use_gpu:
+            state['pred_mask'] = state['pred_mask'].cuda()
+            state['pred_pos'] = state['pred_pos'].cuda()
+            state['pred_type'] = state['pred_type'].cuda()
+            state['pred_turn'] = state['pred_turn'].cuda()
+
+        return state
+
+    def _decode(self, state):
+        """ Decoding one time step.
""" + + # shape: [batch_size, 1, seq_len] + mask = state['mask'] + + # shape: [batch_size, 1, 1] + pred_token = state['pred_token'] + pred_mask = state['pred_mask'] + pred_pos = state['pred_pos'] + pred_type = state['pred_type'] + pred_turn = state['pred_turn'] + + # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] + cache = state['cache'] + + pred_embed = self.embedder(pred_token, pred_pos, pred_type, + pred_turn).squeeze(-2) + pred_embed = self.embed_layer_norm(pred_embed) + + # shape: [batch_size, 1, seq_len + 1] + mask = torch.cat([mask, 1 - pred_mask], dim=2) + + # shape: [batch_size, 1, hidden_dim] + for l, layer in enumerate(self.layers): + pred_embed = layer(pred_embed, mask, cache[f'layer_{l}']) + + # shape: [batch_size, vocab_size] + pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) + pred_logits = torch.log(pred_probs) + + state['mask'] = mask + return pred_logits, state + + def _infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ Real inference process of model. """ + + def cat(x, y, dim=1): + return torch.cat([x, y], dim=dim) + + # Initial decode state. + if self.understand or self.policy: + if self.understand: + prompt_token = inputs['understand_token'] + prompt_mask = inputs['understand_mask'] + if self.policy: + prompt_token = cat(prompt_token, inputs['policy_token']) + prompt_mask = cat(prompt_mask, inputs['policy_mask']) + else: + prompt_token = inputs['policy_token'] + prompt_mask = inputs['policy_mask'] + + state = self._init_prompt_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + prompt_token=prompt_token, + prompt_mask=prompt_mask, + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + else: + state = self._init_state( + src_token=inputs['src_token'], + src_mask=inputs['src_mask'], + src_pos=inputs['src_pos'], + src_type=inputs['src_type'], + src_turn=inputs['src_turn']) + + # Generation process. + gen_results = self.generator( + step_fn=self._decode, + state=state, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + + outputs = gen_results['preds'] + return outputs + + +GenUnifiedTransformer.register('GenUnifiedTransformer') diff --git a/maas_lib/models/nlp/space/model/generator.py b/maas_lib/models/nlp/space/model/generator.py new file mode 100644 index 00000000..2567102f --- /dev/null +++ b/maas_lib/models/nlp/space/model/generator.py @@ -0,0 +1,296 @@ +""" +Generator class. 
+""" + +import math + +import numpy as np +import torch + +from .gen_unified_transformer import GenUnifiedTransformer +from .unified_transformer import UnifiedTransformer + + +def repeat(var, times): + if isinstance(var, list): + return [repeat(x, times) for x in var] + elif isinstance(var, dict): + return {k: repeat(v, times) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + var = var.unsqueeze(1) + expand_times = [1] * len(var.shape) + expand_times[1] = times + dtype = var.dtype + var = var.float() + var = var.repeat(*expand_times) + shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) + var = var.reshape(*shape) + var = torch.tensor(var, dtype=dtype) + return var + else: + return var + + +def gather(var, idx): + if isinstance(var, list): + return [gather(x, idx) for x in var] + elif isinstance(var, dict): + return {k: gather(v, idx) for k, v in var.items()} + elif isinstance(var, torch.Tensor): + out = var.index_select(dim=0, index=idx) + return out + else: + return var + + +class Generator(object): + """ Genrator class. """ + + _registry = dict() + + @classmethod + def register(cls, name): + Generator._registry[name] = cls + return + + @staticmethod + def by_name(name): + return Generator._registry[name] + + @staticmethod + def create(config, *args, **kwargs): + """ Create generator. """ + generator_cls = Generator.by_name(config.Generator.generator) + return generator_cls(config, *args, **kwargs) + + def __init__(self, config, reader): + self.vocab_size = reader.vocab_size + self.bos_id = reader.bos_id + self.eos_id = reader.eos_id + self.unk_id = reader.unk_id + self.pad_id = reader.pad_id + self.min_gen_len = config.Generator.min_gen_len + self.max_gen_len = config.Generator.max_gen_len + self.use_gpu = config.use_gpu + assert 1 <= self.min_gen_len <= self.max_gen_len + return + + def __call__(self, step_fn, state): + """ + Running generation. + + @param : step_fn : decoding one step + @type : function + + @param : state : initial state + @type : dict + """ + raise NotImplementedError + + +class BeamSearch(Generator): + """ BeamSearch generator. """ + + def __init__(self, config, reader): + super().__init__(config, reader) + self.beam_size = config.Generator.beam_size + self.length_average = config.Generator.length_average + self.length_penalty = config.Generator.length_penalty + self.ignore_unk = config.Generator.ignore_unk + return + + def __call__(self, + step_fn, + state, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ + Running beam search. 
+
+        @param : step_fn : decoding one step
+        @type : function
+
+        @param : state : initial state
+        @type : dict
+        """
+        if prev_input is not None:
+
+            if isinstance(prev_input, list):
+                length = max(list(map(lambda x: len(x), prev_input)))
+                prev_input_numpy = np.full((len(prev_input), length),
+                                           self.pad_id)
+                for i, x in enumerate(prev_input):
+                    prev_input_numpy[i, :len(x)] = x
+                prev_input_tensor = torch.from_numpy(prev_input_numpy)
+                if self.use_gpu:
+                    prev_input_tensor = prev_input_tensor.cuda()
+
+                for i in range(length):
+                    state['pred_token'] = prev_input_tensor[:, i].unsqueeze(
+                        -1).unsqueeze(-1)
+                    if i != 0:
+                        state['pred_mask'] = torch.not_equal(
+                            state['pred_token'], self.pad_id).float()
+                        state['pred_pos'] = state['pred_pos'] + state[
+                            'pred_mask'].int()
+                    _, state = step_fn(state)
+            else:
+                assert isinstance(prev_input, torch.Tensor)
+                for i, input in enumerate(prev_input):
+                    state['pred_token'] = input.expand(1, 1, 1)
+                    if i != 0:
+                        state['pred_mask'] = torch.not_equal(
+                            state['pred_token'], self.pad_id).float()
+                        state['pred_pos'] = state['pred_pos'] + 1
+                    _, state = step_fn(state)
+
+        batch_size = state['batch_size']
+        beam_size = self.beam_size
+
+        # shape: [batch_size, 1]
+        pos_index = torch.arange(
+            0, batch_size, 1, dtype=torch.int64) * beam_size
+        pos_index = pos_index.unsqueeze(1)
+
+        # shape: [batch_size, beam_size, 1]
+        if start_id is None:
+            start_id = self.bos_id
+        if eos_id is None:
+            eos_id = self.eos_id
+        predictions = torch.ones([batch_size, beam_size, 1],
+                                 dtype=torch.int64) * start_id
+
+        if self.use_gpu:
+            pos_index = pos_index.cuda()
+            predictions = predictions.cuda()
+
+        # initial input (start_id)
+        state['pred_token'] = predictions[:, :1]
+        if prev_input is not None:
+            state['pred_mask'] = torch.not_equal(state['pred_token'],
+                                                 self.pad_id).float()
+            state['pred_pos'] = state['pred_pos'] + 1
+
+        # shape: [batch_size, vocab_size]
+        scores, state = step_fn(state)
+
+        unk_penalty = np.zeros(self.vocab_size, dtype='float32')
+        unk_penalty[self.unk_id] = -1e10
+        unk_penalty = torch.from_numpy(unk_penalty)
+
+        eos_penalty = np.zeros(self.vocab_size, dtype='float32')
+        eos_penalty[eos_id] = -1e10
+        eos_penalty = torch.from_numpy(eos_penalty)
+
+        scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
+        scores_after_end[
+            self.pad_id] = 0  # only pad may follow a finished sequence, so give log(p(pad)) the highest value (0) in the vocab
+        scores_after_end = torch.from_numpy(scores_after_end)
+
+        if self.use_gpu:
+            unk_penalty = unk_penalty.cuda()
+            eos_penalty = eos_penalty.cuda()
+            scores_after_end = scores_after_end.cuda()
+
+        if self.ignore_unk:
+            scores = scores + unk_penalty
+        scores = scores + eos_penalty
+
+        # shape: [batch_size, beam_size]
+        sequence_scores, preds = torch.topk(scores, self.beam_size)
+
+        predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
+        state = repeat(state, beam_size)
+
+        parent_idx_list = []
+        pred_list = []
+
+        if max_gen_len is None:
+            max_gen_len = self.max_gen_len
+        for step in range(2, max_gen_len + 1):
+            pre_ids = predictions[:, :, -1:]
+            state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1)
+            state['pred_mask'] = torch.not_equal(state['pred_token'],
+                                                 self.pad_id).float()
+            state['pred_pos'] = state['pred_pos'] + 1
+            scores, state = step_fn(state)
+
+            # Generate next
+            # scores shape: [batch_size * beam_size, vocab_size]
+            if self.ignore_unk:
+                scores = scores + unk_penalty
+
+            if step <= self.min_gen_len:
+                scores = scores + eos_penalty
+
+            # scores shape: [batch_size, beam_size, vocab_size]
+            scores = scores.reshape(batch_size, beam_size, self.vocab_size)
+
+            # previous token is [PAD] or [EOS]
+            pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
+                (1 - torch.not_equal(pre_ids, self.pad_id).float())
+
+            scores = scores * (1 - pre_eos_mask) + \
+                pre_eos_mask.repeat(1, 1, self.vocab_size) * scores_after_end
+            if self.length_average:
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 -
+                                                                    1 / step)
+                sequence_scores = sequence_scores.unsqueeze(2) * scaled_value
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
+                scores = scores * scaled_value
+            elif self.length_penalty >= 0.0:
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
+                    (math.pow((4 + step) / (5 + step), self.length_penalty))
+                sequence_scores = scaled_value * sequence_scores
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
+                    (math.pow(1 / (5 + step), self.length_penalty))
+                scores = scores * scaled_value
+            scores = scores + sequence_scores.unsqueeze(-1)
+            scores = scores.reshape(batch_size, beam_size * self.vocab_size)
+
+            topk_scores, topk_indices = torch.topk(scores, beam_size)
+            # topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped)
+            # integer-divide by vocab_size to find which beam each candidate's previous token came from
+            parent_idx = topk_indices.floor_divide(self.vocab_size)
+            # take the remainder modulo vocab_size to get the new token id
+            preds = topk_indices % self.vocab_size
+
+            # Gather state / sequence_scores
+            parent_idx = parent_idx + pos_index
+            parent_idx = parent_idx.reshape(batch_size * beam_size)
+            state = gather(state, parent_idx)
+            sequence_scores = topk_scores
+
+            predictions = predictions.reshape(batch_size * beam_size, step)
+            predictions = gather(predictions, parent_idx)
+            predictions = predictions.reshape(batch_size, beam_size, step)
+            predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
+
+        # the generated sentence should be complete: its last token must be eos or pad (which follows eos), otherwise penalize it
+        pre_ids = predictions[:, :, -1]
+        pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
+            (1 - torch.not_equal(pre_ids, self.pad_id).float())
+        sequence_scores = sequence_scores * pre_eos_mask + (
+            1 - pre_eos_mask) * (-1e10)
+
+        # get ascending sort indices first, so predictions and sequence_scores can be reordered along the beam-size axis
+        indices = torch.argsort(sequence_scores, dim=1)
+        indices = indices + pos_index
+        indices = indices.reshape(-1)
+        sequence_scores = sequence_scores.reshape(batch_size * beam_size)
+        predictions = predictions.reshape(batch_size * beam_size, -1)
+        sequence_scores = gather(sequence_scores, indices)
+        predictions = gather(predictions, indices)
+        sequence_scores = sequence_scores.reshape(batch_size, beam_size)
+        predictions = predictions.reshape(batch_size, beam_size, -1)
+
+        results = {
+            'preds': predictions[:, -1],
+            'scores': sequence_scores[:, -1]
+        }
+        return results
+
+
+BeamSearch.register('BeamSearch')
diff --git a/maas_lib/models/nlp/space/model/model_base.py b/maas_lib/models/nlp/space/model/model_base.py
new file mode 100644
index 00000000..cdd355a5
--- /dev/null
+++ b/maas_lib/models/nlp/space/model/model_base.py
@@ -0,0 +1,99 @@
+"""
+Model base
+"""
+import os
+
+import torch.nn as nn
+
+
+class ModelBase(nn.Module):
+    """
+    Basic model wrapper for static graph and dygraph.
+ """ + _registry = dict() + + @classmethod + def register(cls, name): + ModelBase._registry[name] = cls + return + + @staticmethod + def by_name(name): + return ModelBase._registry[name] + + @staticmethod + def create(model_dir, config, *args, **kwargs): + model_cls = ModelBase.by_name(config.Model.model) + return model_cls(model_dir, config, *args, **kwargs) + + def __init__(self, model_dir, config): + super(ModelBase, self).__init__() + self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin') + self.abandon_label = config.Dataset.abandon_label + self.use_gpu = config.use_gpu + self.gpu = config.Trainer.gpu + return + + def _create_parameters(self): + """ Create model's paramters. """ + raise NotImplementedError + + def _forward(self, inputs, is_training, with_label): + """ NO LABEL: Real forward process of model in different mode(train/test). """ + raise NotImplementedError + + def _collect_metrics(self, inputs, outputs, with_label, data_file): + """ NO LABEL: Calculate loss function by using inputs and outputs. """ + raise NotImplementedError + + def _optimize(self, loss, optimizer, lr_scheduler): + """ Optimize loss function and update model. """ + raise NotImplementedError + + def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): + """ Real inference process of model. """ + raise NotImplementedError + + def forward(self, + inputs, + is_training=False, + with_label=False, + data_file=None): + """ + Forward process, include real forward, collect metrices and optimize(optional) + + @params : inputs : input data + @type : dict of numpy.ndarray/int/float/... + """ + if is_training: + self.train() + else: + self.eval() + + with_label = False if self.abandon_label else with_label + outputs = self._forward(inputs, is_training, with_label=with_label) + metrics = self._collect_metrics( + inputs, outputs, with_label=with_label, data_file=data_file) + + return metrics + + def infer(self, + inputs, + start_id=None, + eos_id=None, + max_gen_len=None, + prev_input=None): + """ + Inference process. + + @params : inputs : input data + @type : dict of numpy.ndarray/int/float/... + """ + self.eval() + results = self._infer( + inputs, + start_id=start_id, + eos_id=eos_id, + max_gen_len=max_gen_len, + prev_input=prev_input) + return results diff --git a/maas_lib/models/nlp/space/model/unified_transformer.py b/maas_lib/models/nlp/space/model/unified_transformer.py new file mode 100644 index 00000000..53e03c69 --- /dev/null +++ b/maas_lib/models/nlp/space/model/unified_transformer.py @@ -0,0 +1,322 @@ +""" +UnifiedTransformer +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from maas_lib.models.nlp.space.model.model_base import ModelBase +from maas_lib.models.nlp.space.modules.embedder import Embedder +from maas_lib.models.nlp.space.modules.transformer_block import \ + TransformerBlock + + +class UnifiedTransformer(ModelBase): + """ + Implement unified transformer. 
+ """ + + def __init__(self, model_dir, config, reader, generator, dtype='float32'): + super(UnifiedTransformer, self).__init__(model_dir, config) + self.reader = reader + self.generator = generator + self.policy = config.BPETextField.policy + self.generation = config.BPETextField.generation + self.num_token_embeddings = config.Model.num_token_embeddings + self.num_pos_embeddings = config.Model.num_pos_embeddings + self.num_type_embeddings = config.Model.num_type_embeddings + self.num_turn_embeddings = config.Model.num_turn_embeddings + self.temperature = config.Model.temperature + self.hidden_dim = config.Model.hidden_dim + self.num_heads = config.Model.num_heads + self.num_layers = config.Model.num_layers + self.padding_idx = config.Model.padding_idx + self.dropout = config.Model.dropout + self.embed_dropout = config.Model.embed_dropout + self.attn_dropout = config.Model.attn_dropout + self.ff_dropout = config.Model.ff_dropout + self.mlm_ratio = config.Model.mlm_ratio + self.mmd_ratio = config.Model.mmd_ratio + self.pos_trainable = config.Model.pos_trainable + self.label_smooth = config.Model.label_smooth + self.initializer_range = config.Model.initializer_range + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.token_loss = config.Trainer.token_loss + self.learning_method = config.Dataset.learning_method + self.with_contrastive = config.Dataset.with_contrastive + self.with_query_bow = config.BPETextField.with_query_bow + self.with_resp_bow = config.BPETextField.with_resp_bow + self.with_pool = config.Model.with_pool + self.with_mlm = config.Dataset.with_mlm + self._dtype = dtype + + self.embedder = Embedder( + self.hidden_dim, + self.num_token_embeddings, + self.num_pos_embeddings, + self.num_type_embeddings, + self.num_turn_embeddings, + padding_idx=self.padding_idx, + dropout=self.embed_dropout, + pos_trainable=self.pos_trainable) + self.embed_layer_norm = nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True) + + self.layers = nn.ModuleList([ + TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, + self.attn_dropout, self.ff_dropout) + for _ in range(config.Model.num_layers) + ]) + + if self.with_mlm: + self.mlm_transform = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), + nn.LayerNorm( + normalized_shape=self.hidden_dim, + eps=1e-12, + elementwise_affine=True)) + self.mlm_bias = nn.Parameter( + torch.zeros(self.num_token_embeddings)) + + self.pooler = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) + + if self.with_query_bow or self.with_resp_bow: + self.bow_predictor = nn.Linear( + self.hidden_dim, self.num_token_embeddings, bias=False) + + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(dim=-1) + self.bce_loss = nn.BCELoss(reduction='none') + self.nll_loss = nn.NLLLoss( + ignore_index=self.padding_idx, reduction='none') + self._create_parameters() + + self.max_grad_norm = config.Model.max_grad_norm + if self.max_grad_norm is not None: + self.grad_clip = self.max_grad_norm + else: + self.grad_clip = None + self.weight_decay = config.Model.weight_decay + + if self.use_gpu: + self.cuda() + + return + + def _create_parameters(self): + """ Create model's paramters. """ + sequence_mask = np.tri( + self.num_pos_embeddings, + self.num_pos_embeddings, + dtype=self._dtype) + self.sequence_mask = torch.tensor(sequence_mask) + return + + def _create_mask(self, + input_mask, + append_head=False, + auto_regressive=False): + """ + Create attention mask. 
+        Convert a sequence-form mask to matrix form: [batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len]
+        Besides the (auto-regressive) attention mask, the pad mask also has to be handled (auto-regressive and bi-directional).
+        Note:
+        1. A non-pad token looking at the whole sentence: only the pad tokens of that sentence are masked.
+        2. A pad token looking at the whole sentence: every token of that sentence should be masked.
+
+        @param : input_mask
+        @type : Variable(shape: [batch_size, max_seq_len])
+
+        @param : auto_regressive
+        @type : bool
+        """
+        seq_len = input_mask.shape[1]
+
+        input_mask = input_mask.float()
+        mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len)
+        mask2 = mask1.permute(0, 2, 1)
+        mask = mask1 * mask2
+
+        if append_head:
+            # append the mask for the sentence-head position ([M]/z)
+            mask = torch.cat([mask[:, :1, :], mask], dim=1)
+            mask = torch.cat([mask[:, :, :1], mask], dim=2)
+            seq_len += 1
+
+        if auto_regressive:
+            # fuse the target-side pad mask with the auto-regressive attention mask
+            seq_mask = self.sequence_mask[:seq_len, :seq_len]
+            seq_mask = seq_mask.to(mask.device)
+            mask = mask * seq_mask
+
+        mask = 1 - mask
+        return mask
+
+    def _join_mask(self, mask1, mask2):
+        """
+        Merge source attention mask and target attention mask.
+        The merged mask matrix consists of four blocks: upper-left (lu), upper-right (ru), lower-left (lb) and lower-right (rb).
+
+        @param : mask1 : source attention mask
+        @type : Variable(shape: [batch_size, max_src_len, max_src_len])
+
+        @param : mask2 : target attention mask
+        @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len])
+        """
+        batch_size = mask1.shape[0]
+        seq_len1 = mask1.shape[1]
+        seq_len2 = mask2.shape[1]
+        seq_len = seq_len1 + seq_len2
+
+        mask_lu = mask1
+        mask_ru = torch.ones(batch_size, seq_len1, seq_len2)
+        if self.use_gpu:
+            mask_ru = mask_ru.cuda()
+        mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1)
+        mask4 = mask1[:, :1].repeat(1, seq_len2, 1)
+        mask_lb = mask3 + mask4 - mask3 * mask4
+        mask_rb = mask2
+        mask_u = torch.cat([mask_lu, mask_ru], dim=2)
+        mask_b = torch.cat([mask_lb, mask_rb], dim=2)
+        mask = torch.cat([mask_u, mask_b], dim=1)
+        return mask
+
+    def _mlm_head(self, mlm_embed):
+        mlm_embed = self.mlm_transform(mlm_embed)
+        mlm_logits = torch.matmul(
+            mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias
+        mlm_probs = self.softmax(mlm_logits)
+        return mlm_probs
+
+    def _dec_head(self, dec_embed):
+        dec_logits = torch.matmul(dec_embed,
+                                  self.embedder.token_embedding.weight.T)
+        dec_probs = self.softmax(dec_logits)
+        return dec_probs
+
+    def _refactor_feature(self, features):
+        features = self.pooler(features) if self.with_pool else features
+        batch_size = features.size(0) // 2
+        features = torch.cat([
+            features[:batch_size].unsqueeze(1),
+            features[batch_size:].unsqueeze(1)
+        ],
+                             dim=1)
+        features = F.normalize(features, dim=-1, p=2)
+        return features
+
+    def _encoder_network(self,
+                         input_token,
+                         input_mask,
+                         input_pos=None,
+                         input_type=None,
+                         input_turn=None):
+        embed = self.embedder(input_token, input_pos, input_type, input_turn)
+        embed = self.embed_layer_norm(embed)
+        mask = self._create_mask(input_mask, auto_regressive=False)
+
+        for layer in self.layers:
+            embed = layer(embed, mask, None)
+
+        return embed
+
+    def _encoder_decoder_network(self,
+                                 src_token,
+                                 src_mask,
+                                 tgt_token,
+                                 tgt_mask,
+                                 src_pos=None,
+                                 src_type=None,
+                                 src_turn=None,
+                                 tgt_pos=None,
+                                 tgt_type=None,
+                                 tgt_turn=None):
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
+        embed = torch.cat([src_embed, tgt_embed], dim=1)
+        embed = self.embed_layer_norm(embed)
+
+        enc_mask = self._create_mask(src_mask, auto_regressive=False)
+        dec_mask = self._create_mask(tgt_mask, auto_regressive=True)
+        mask = self._join_mask(enc_mask, dec_mask)
+
+        for layer in self.layers:
+            embed = layer(embed, mask, None)
+
+        tgt_len = tgt_token.shape[1]
+        enc_embed = embed[:, :-tgt_len]
+        dec_embed = embed[:, -tgt_len:]
+
+        return enc_embed, dec_embed
+
+    def _encoder_prompt_decoder_network(self,
+                                        src_token,
+                                        src_mask,
+                                        tgt_token,
+                                        tgt_mask,
+                                        prompt_token,
+                                        prompt_mask,
+                                        src_pos=None,
+                                        src_type=None,
+                                        src_turn=None,
+                                        tgt_pos=None,
+                                        tgt_type=None,
+                                        tgt_turn=None,
+                                        prompt_pos=None,
+                                        prompt_type=None,
+                                        prompt_turn=None):
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
+        prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
+                                     prompt_turn)
+
+        embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1)
+        embed = self.embed_layer_norm(embed)
+
+        enc_mask = self._create_mask(src_mask, auto_regressive=False)
+        dec_mask = self._create_mask(
+            torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True)
+        mask = self._join_mask(enc_mask, dec_mask)
+
+        for layer in self.layers:
+            embed = layer(embed, mask, None)
+
+        src_len = src_token.shape[1]
+        tgt_len = tgt_token.shape[1]
+        enc_embed = embed[:, :src_len]
+        dec_embed = embed[:, -tgt_len:]
+        prompt_embed = embed[:, src_len:-tgt_len]
+
+        return enc_embed, dec_embed, prompt_embed
+
+    def _optimize(self, loss, optimizer=None, lr_scheduler=None):
+        """ Optimize loss function and update model. """
+        assert optimizer is not None
+        optimizer.zero_grad()
+        loss.backward()
+
+        if self.grad_clip is not None and self.grad_clip > 0:
+            torch.nn.utils.clip_grad_norm_(
+                parameters=self.parameters(), max_norm=self.grad_clip)
+        optimizer.step()
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+        return
+
+    def _infer(self,
+               inputs,
+               start_id=None,
+               eos_id=None,
+               max_gen_len=None,
+               prev_input=None):
+        """ Real inference process of model. """
+        results = {}
+        return results
+
+
+UnifiedTransformer.register('UnifiedTransformer')
diff --git a/maas_lib/models/nlp/space/modules/__init__.py b/maas_lib/models/nlp/space/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/models/nlp/space/modules/embedder.py b/maas_lib/models/nlp/space/modules/embedder.py
new file mode 100644
index 00000000..4fb592ef
--- /dev/null
+++ b/maas_lib/models/nlp/space/modules/embedder.py
@@ -0,0 +1,67 @@
+"""
+Embedder class.
+"""
+
+import torch
+import torch.nn as nn
+
+
+class Embedder(nn.Module):
+    """
+    Composite embedding layer.
+ """ + + def __init__(self, + hidden_dim, + num_token_embeddings, + num_pos_embeddings, + num_type_embeddings, + num_turn_embeddings, + padding_idx=None, + dropout=0.1, + pos_trainable=False): + super(Embedder, self).__init__() + + self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) + self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) + self.pos_embedding.weight.requires_grad = pos_trainable + self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) + self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + + # follow the default xavier_uniform initializer in paddle version + # otherwise, there are bugs for dec_probs computation in weight typing setting + # default norm initializer in nn.Embedding in pytorch, which samples larger values + nn.init.xavier_uniform_(self.token_embedding.weight) + nn.init.xavier_uniform_(self.pos_embedding.weight) + nn.init.xavier_uniform_(self.type_embedding.weight) + nn.init.xavier_uniform_(self.turn_embedding.weight) + return + + def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): + embed = self.token_embedding(token_inp) + if pos_inp is not None: + embed += self.pos_embedding(pos_inp) + if type_inp is not None: + embed += self.type_embedding(type_inp) + if turn_inp is not None: + embed += self.turn_embedding(turn_inp) + embed = self.dropout_layer(embed) + return embed + + +def main(): + import numpy as np + + model = Embedder(10, 20, 20, 20, 20) + token_inp = torch.tensor( + np.random.randint(0, 19, [10, 10]).astype('int64')) + pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) + out = model(token_inp, pos_inp, type_inp, turn_inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/maas_lib/models/nlp/space/modules/feedforward.py b/maas_lib/models/nlp/space/modules/feedforward.py new file mode 100644 index 00000000..e9a5f4c7 --- /dev/null +++ b/maas_lib/models/nlp/space/modules/feedforward.py @@ -0,0 +1,43 @@ +""" +FeedForward class. +""" + +import torch +import torch.nn as nn + + +class FeedForward(nn.Module): + """ + Positional feed forward layer. + """ + + def __init__(self, hidden_dim, inner_dim, dropout): + super(FeedForward, self).__init__() + + self.hidden_dim = hidden_dim + self.inner_dim = inner_dim + self.linear_hidden = nn.Sequential( + nn.Linear(hidden_dim, inner_dim), nn.GELU()) + self.linear_out = nn.Linear(inner_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + return + + def forward(self, x): + out = self.linear_hidden(x) + out = self.dropout_layer(out) + out = self.linear_out(out) + return out + + +def main(): + import numpy as np + + model = FeedForward(10, 20, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + out = model(inp) + print(out) + + +if __name__ == '__main__': + main() diff --git a/maas_lib/models/nlp/space/modules/functions.py b/maas_lib/models/nlp/space/modules/functions.py new file mode 100644 index 00000000..45c02e21 --- /dev/null +++ b/maas_lib/models/nlp/space/modules/functions.py @@ -0,0 +1,64 @@ +""" +Helpful functions. +""" + +import numpy as np +import torch +import torch.nn.functional as F + + +def unsqueeze(input, dims): + """ Implement multi-dimension unsqueeze function. 
""" + if isinstance(dims, (list, tuple)): + dims = [ + dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims + ] + dims = sorted(dims, reverse=True) + shape = list(input.shape) + for dim in dims: + shape.insert(dim, 1) + return torch.reshape(input, shape) + elif isinstance(dims, int): + return input.unsqueeze(dims) + else: + raise ValueError('Warning: type(dims) must in (list, tuple, int)!') + + +def gumbel_softmax(input, tau=1, eps=1e-10): + """ Basic implement of gumbel_softmax. """ + U = torch.tensor(np.random.rand(*input.shape)) + gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) + y = input + gumbel + return F.softmax(y / tau) + + +def equal(x, y, dtype=None): + """ Implement equal in dygraph mode. (paddle) """ + if dtype is None: + dtype = 'float32' + if isinstance(x, torch.Tensor): + x = x.numpy() + if isinstance(y, torch.Tensor): + y = y.numpy() + out = np.equal(x, y).astype(dtype) + return torch.tensor(out) + + +def not_equal(x, y, dtype=None): + """ Implement not_equal in dygraph mode. (paddle) """ + return 1 - equal(x, y, dtype) + + +if __name__ == '__main__': + a = torch.tensor([[1, 1], [3, 4]]) + b = torch.tensor([[1, 1], [3, 4]]) + c = torch.equal(a, a) + c1 = equal(a, 3) + d = 1 - torch.not_equal(a, 3).float() + print(c) + print(c1) + print(d) + e = F.gumbel_softmax(a) + f = a.unsqueeze(a) + g = unsqueeze(a, dims=[0, 0, 1]) + print(g, g.shape) diff --git a/maas_lib/models/nlp/space/modules/multihead_attention.py b/maas_lib/models/nlp/space/modules/multihead_attention.py new file mode 100644 index 00000000..209eab5e --- /dev/null +++ b/maas_lib/models/nlp/space/modules/multihead_attention.py @@ -0,0 +1,109 @@ +""" +MultiheadAttention class. +""" + +import torch +import torch.nn as nn + + +class MultiheadAttention(nn.Module): + """ + Multi head attention layer. + """ + + def __init__(self, hidden_dim, num_heads, dropout): + assert hidden_dim % num_heads == 0 + super(MultiheadAttention, self).__init__() + + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.scale = self.head_dim**-0.5 + self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) + self.linear_out = nn.Linear(hidden_dim, hidden_dim) + self.dropout_layer = nn.Dropout(p=dropout) + self.softmax = nn.Softmax(dim=-1) + return + + def _split_heads(self, x, is_key=False): + x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) + x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) + return x + + def _merge_heads(self, x): + x = x.permute(0, 2, 1, 3) + x = x.reshape(x.size(0), x.size(1), self.hidden_dim) + return x + + def _attn(self, query, key, value, mask): + # shape: [batch_size, num_head, seq_len, seq_len] + scores = torch.matmul(query, key) + scores = scores * self.scale + + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.repeat(1, self.num_heads, 1, 1) + scores.masked_fill_( + mask.bool(), + float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) + + attn = self.softmax(scores) + attn = self.dropout_layer(attn) + + if mask is not None: + ''' + mask: [batch size, num_heads, seq_len, seq_len] + mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中位看的行 + 导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 + + >>> F.softmax([-1e10, -100, -100]) + >>> [0.00, 0.50, 0.50] + >>> F.softmax([-1e10, -1e10, -1e10]) + >>> [0.33, 0.33, 0.33] + ==> [0.00, 0.00, 0.00] + ''' + attn.masked_fill_(mask.bool(), 0.) 
+            # attn = (1 - mask) * attn
+
+        out = torch.matmul(attn, value)
+        return out
+
+    def forward(self, inp, mask=None, cache=None):
+        """ Forward process of self attention. """
+        # shape: [batch_size, seq_len, 3 * hidden_dim]
+        qkv = self.linear_qkv(inp)
+        query, key, value = torch.split(qkv, self.hidden_dim, dim=2)
+
+        # shape: [batch_size, num_head, seq_len, head_dim]
+        query = self._split_heads(query)
+        # shape: [batch_size, num_head, head_dim, seq_len]
+        key = self._split_heads(key, is_key=True)
+        # shape: [batch_size, num_head, seq_len, head_dim]
+        value = self._split_heads(value)
+
+        if cache is not None:
+            if 'key' in cache and 'value' in cache:
+                key = torch.cat([cache['key'], key], dim=3)
+                value = torch.cat([cache['value'], value], dim=2)
+            cache['key'] = key
+            cache['value'] = value
+
+        out = self._attn(query, key, value, mask)
+        out = self._merge_heads(out)
+        out = self.linear_out(out)
+        return out
+
+
+def main():
+    import numpy as np
+
+    model = MultiheadAttention(10, 2, 0.5)
+    inp = np.random.rand(2, 3, 10).astype('float32')
+    inp = torch.tensor(inp)
+    mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32')
+    mask = torch.tensor(mask)
+    out = model(inp, mask=mask, cache=None)
+    print(out)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/maas_lib/models/nlp/space/modules/transformer_block.py b/maas_lib/models/nlp/space/modules/transformer_block.py
new file mode 100644
index 00000000..daa7d723
--- /dev/null
+++ b/maas_lib/models/nlp/space/modules/transformer_block.py
@@ -0,0 +1,73 @@
+"""
+TransformerBlock class.
+"""
+
+import torch
+import torch.nn as nn
+
+from maas_lib.models.nlp.space.modules.feedforward import FeedForward
+from maas_lib.models.nlp.space.modules.multihead_attention import \
+    MultiheadAttention
+
+
+class TransformerBlock(nn.Module):
+    """
+    Transformer block module.
+    """
+
+    def __init__(self, hidden_dim, num_heads, dropout, attn_dropout,
+                 ff_dropout):
+        super(TransformerBlock, self).__init__()
+
+        self.attn = MultiheadAttention(
+            hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout)
+        self.attn_norm = nn.LayerNorm(
+            normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
+        self.ff = FeedForward(
+            hidden_dim=hidden_dim,
+            inner_dim=4 * hidden_dim,
+            dropout=ff_dropout)
+        self.ff_norm = nn.LayerNorm(
+            normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
+        self.dropout_layer = nn.Dropout(p=dropout)
+        return
+
+    def forward(self, inp, mask=None, cache=None):
+        """
+        Forward process on one transformer layer.
+ + @param : x + @type : Variable(shape: [batch_size, seq_len, hidden_size]) + + @param : memory + @type : Variable(shape: [batch_size, seq_len, hidden_size]) + + @param : mask + + @param : cache + """ + attn_out = self.attn(inp, mask, cache) + attn_out = self.dropout_layer(attn_out) + attn_out = self.attn_norm(attn_out + inp) + + ff_out = self.ff(attn_out) + ff_out = self.dropout_layer(ff_out) + ff_out = self.ff_norm(ff_out + attn_out) + + return ff_out + + +def main(): + import numpy as np + + model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) + inp = np.random.rand(2, 3, 10).astype('float32') + inp = torch.tensor(inp) + mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') + mask = torch.tensor(mask) + out = model(inp, mask=mask, cache=None) + print(out) + + +if __name__ == '__main__': + main() diff --git a/maas_lib/preprocessors/__init__.py b/maas_lib/preprocessors/__init__.py index b1dc0fa2..9ed0d181 100644 --- a/maas_lib/preprocessors/__init__.py +++ b/maas_lib/preprocessors/__init__.py @@ -4,5 +4,5 @@ from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image -from .nlp import * # noqa F403 -from .space.dialog_generation_preprcessor import * # noqa F403 +from .nlp.nlp import * # noqa F403 +from .nlp.space.dialog_generation_preprcessor import * # noqa F403 diff --git a/maas_lib/preprocessors/nlp/__init__.py b/maas_lib/preprocessors/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/maas_lib/preprocessors/nlp.py b/maas_lib/preprocessors/nlp/nlp.py similarity index 97% rename from maas_lib/preprocessors/nlp.py rename to maas_lib/preprocessors/nlp/nlp.py index 0a03328a..ea496883 100644 --- a/maas_lib/preprocessors/nlp.py +++ b/maas_lib/preprocessors/nlp/nlp.py @@ -7,8 +7,8 @@ from transformers import AutoTokenizer from maas_lib.utils.constant import Fields, InputFields from maas_lib.utils.type_assert import type_assert -from .base import Preprocessor -from .builder import PREPROCESSORS +from ..base import Preprocessor +from ..builder import PREPROCESSORS __all__ = [ 'Tokenize', diff --git a/maas_lib/preprocessors/nlp/space/__init__.py b/maas_lib/preprocessors/nlp/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/maas_lib/preprocessors/space/dialog_generation_preprcessor.py b/maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py similarity index 78% rename from maas_lib/preprocessors/space/dialog_generation_preprcessor.py rename to maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py index 846c5872..f47eed7e 100644 --- a/maas_lib/preprocessors/space/dialog_generation_preprcessor.py +++ b/maas_lib/preprocessors/nlp/space/dialog_generation_preprcessor.py @@ -5,10 +5,11 @@ import uuid from typing import Any, Dict, Union from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField +from maas_lib.utils.config import Config from maas_lib.utils.constant import Fields, InputFields from maas_lib.utils.type_assert import type_assert -from ..base import Preprocessor -from ..builder import PREPROCESSORS +from ...base import Preprocessor +from ...builder import PREPROCESSORS __all__ = ['DialogGenerationPreprocessor'] @@ -25,10 +26,10 @@ class DialogGenerationPreprocessor(Preprocessor): super().__init__(*args, **kwargs) self.model_dir: str = model_dir - - self.text_field = MultiWOZBPETextField(model_dir=self.model_dir) - - pass + self.config = Config.from_file( + os.path.join(self.model_dir, 'configuration.json')) + 
+        self.text_field = MultiWOZBPETextField(
+            self.model_dir, config=self.config)
 
     @type_assert(object, str)
     def __call__(self, data: str) -> Dict[str, Any]:
diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py
index b3186de6..7b42059a 100644
--- a/tests/pipelines/nlp/test_dialog_generation.py
+++ b/tests/pipelines/nlp/test_dialog_generation.py
@@ -4,37 +4,11 @@
 import os.path as osp
 import tempfile
 import unittest
 
-from maas_lib.fileio import File
+from tests.case.nlp.dialog_generation_case import test_case
+
 from maas_lib.models.nlp import DialogGenerationModel
 from maas_lib.pipelines import DialogGenerationPipeline, pipeline
 from maas_lib.preprocessors import DialogGenerationPreprocessor
-from maas_lib.utils.constant import Tasks
-
-dialog_case = [{
-    'user':
-    'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
-    'sys':
-    'okay , do you have a specific area you want to stay in ?'
-}, {
-    'user':
-    'no , i just need to make sure it is cheap . oh , and i need parking',
-    'sys':
-    'i found 1 cheap hotel for you that include -s parking . do you like me to book it ?'
-}, {
-    'user':
-    'yes , please . 6 people 3 nights starting on tuesday .',
-    'sys':
-    "i am sorry but i was n't able to book that for you for tuesday . is there another day you would like "
-    'to stay or perhaps a shorter stay ? '
-}, {
-    'user':
-    'how about only 2 nights .',
-    'sys':
-    'booking was successful . reference number is : 7gawk763 . anything else i can do for you ?',
-}, {
-    'user': 'no , that will be all . goodbye .',
-    'sys': 'thank you for using our services .'
-}]
 
 
 def merge(info, result):
@@ -47,21 +21,23 @@ class DialogGenerationTest(unittest.TestCase):
 
         modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
 
-        preprocessor = DialogGenerationPreprocessor()
+        preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
         model = DialogGenerationModel(
-            model_dir=modeldir, preprocessor.tokenizer)
-        pipeline = DialogGenerationPipeline(model, preprocessor)
+            model_dir=modeldir,
+            text_field=preprocessor.text_field,
+            config=preprocessor.config)
+        # pipeline = DialogGenerationPipeline(model, preprocessor)
 
         history_dialog = {}
-        for step in range(0, len(dialog_case)):
-            user_question = dialog_case[step]['user']
+        for step, item in enumerate(test_case['sng0073']['log']):
+            user_question = item['user']
             print('user: {}'.format(user_question))
-            history_dialog_info = merge(history_dialog_info,
-                                        result) if step > 0 else {}
-            result = pipeline(user_question, history=history_dialog_info)
-
-            print('sys : {}'.format(result['pred_answer']))
+            # history_dialog_info = merge(history_dialog_info,
+            #                             result) if step > 0 else {}
+            # result = pipeline(user_question, history=history_dialog_info)
+            #
+            # print('sys : {}'.format(result['pred_answer']))
 
 
 if __name__ == '__main__':
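
--
Reviewer note (not part of the patch): the trickiest bookkeeping above is in generator.py, where BeamSearch flattens the scores to [batch_size, beam_size * vocab_size] before torch.topk and then recovers the source beam and the token id by integer division and modulo. A minimal standalone sketch of that index arithmetic, with illustrative toy values only:

import torch

# Toy setting: batch_size=1, beam_size=2, vocab_size=5.
# After flattening, flat index 7 means beam 7 // 5 = 1 and token 7 % 5 = 2.
scores = torch.tensor([[0.1, 0.0, 0.0, 0.0, 0.0,
                        0.0, 0.0, 0.9, 0.4, 0.0]])
topk_scores, topk_indices = torch.topk(scores, 2)  # indices: tensor([[7, 8]])
parent_idx = topk_indices // 5  # tensor([[1, 1]]): both candidates extend beam 1
preds = topk_indices % 5        # tensor([[2, 3]]): new token ids within the vocab
print(parent_idx, preds)

In the patch, parent_idx is additionally offset by pos_index (batch index * beam_size) so that gather() can reorder the flattened [batch_size * beam_size, ...] state tensors with a single index_select.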