加入space模型在banking数据集上的微调代码
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10006792
master
| @@ -241,6 +241,7 @@ class Trainers(object): | |||||
| # nlp trainers | # nlp trainers | ||||
| bert_sentiment_analysis = 'bert-sentiment-analysis' | bert_sentiment_analysis = 'bert-sentiment-analysis' | ||||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||||
| nlp_base_trainer = 'nlp-base-trainer' | nlp_base_trainer = 'nlp-base-trainer' | ||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | from modelscope.utils.import_utils import LazyImportModule | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .data_loader import DataLoader | |||||
| from .dialog_intent_prediction_preprocessor import \ | from .dialog_intent_prediction_preprocessor import \ | ||||
| DialogIntentPredictionPreprocessor | DialogIntentPredictionPreprocessor | ||||
| from .dialog_modeling_preprocessor import DialogModelingPreprocessor | from .dialog_modeling_preprocessor import DialogModelingPreprocessor | ||||
| @@ -13,6 +14,7 @@ if TYPE_CHECKING: | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'data_loader': ['DataLoader'], | |||||
| 'dialog_intent_prediction_preprocessor': | 'dialog_intent_prediction_preprocessor': | ||||
| ['DialogIntentPredictionPreprocessor'], | ['DialogIntentPredictionPreprocessor'], | ||||
| 'dialog_modeling_preprocessor': ['DialogModelingPreprocessor'], | 'dialog_modeling_preprocessor': ['DialogModelingPreprocessor'], | ||||
| @@ -0,0 +1,66 @@ | |||||
| """ | |||||
| Parse argument. | |||||
| """ | |||||
| import argparse | |||||
| import json | |||||
| def str2bool(v): | |||||
| if v.lower() in ('yes', 'true', 't', 'y', '1'): | |||||
| return True | |||||
| elif v.lower() in ('no', 'false', 'f', 'n', '0'): | |||||
| return False | |||||
| else: | |||||
| raise argparse.ArgumentTypeError('Unsupported value encountered.') | |||||
| class HParams(dict): | |||||
| """ Hyper-parameters class | |||||
| Store hyper-parameters in training / infer / ... scripts. | |||||
| """ | |||||
| def __getattr__(self, name): | |||||
| if name in self.keys(): | |||||
| return self[name] | |||||
| for v in self.values(): | |||||
| if isinstance(v, HParams): | |||||
| if name in v: | |||||
| return v[name] | |||||
| raise AttributeError(f"'HParams' object has no attribute '{name}'") | |||||
| def __setattr__(self, name, value): | |||||
| self[name] = value | |||||
| def save(self, filename): | |||||
| with open(filename, 'w', encoding='utf-8') as fp: | |||||
| json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) | |||||
| def load(self, filename): | |||||
| with open(filename, 'r', encoding='utf-8') as fp: | |||||
| params_dict = json.load(fp) | |||||
| for k, v in params_dict.items(): | |||||
| if isinstance(v, dict): | |||||
| self[k].update(HParams(v)) | |||||
| else: | |||||
| self[k] = v | |||||
| def parse_args(parser): | |||||
| """ Parse hyper-parameters from cmdline. """ | |||||
| parsed = parser.parse_args() | |||||
| args = HParams() | |||||
| optional_args = parser._action_groups[1] | |||||
| for action in optional_args._group_actions[1:]: | |||||
| arg_name = action.dest | |||||
| args[arg_name] = getattr(parsed, arg_name) | |||||
| for group in parser._action_groups[2:]: | |||||
| group_args = HParams() | |||||
| for action in group._group_actions: | |||||
| arg_name = action.dest | |||||
| group_args[arg_name] = getattr(parsed, arg_name) | |||||
| if len(group_args) > 0: | |||||
| args[group.title] = group_args | |||||
| return args | |||||
| @@ -0,0 +1,55 @@ | |||||
| def batch(reader, batch_size, drop_last=False): | |||||
| """ | |||||
| This operator creates a batched reader which combines the data from the | |||||
| input reader to batched data. | |||||
| Args: | |||||
| reader(generator): the data reader to read from. | |||||
| batch_size(int): size of each mini-batch. | |||||
| drop_last(bool, optional): If set to True, the last batch is dropped when | |||||
| the size of last batch is not equal to batch_size, if set to False, | |||||
| it will not. Default: False. | |||||
| Returns: | |||||
| The batched reader. | |||||
| Return Type: | |||||
| generator | |||||
| Examples: | |||||
| .. code-block:: python | |||||
| import paddle.fluid as fluid | |||||
| def reader(): | |||||
| for i in range(10): | |||||
| yield i | |||||
| batch_reader = fluid.io.batch(reader, batch_size=2) | |||||
| for data in batch_reader(): | |||||
| print(data) | |||||
| # Output is | |||||
| # [0, 1] | |||||
| # [2, 3] | |||||
| # [4, 5] | |||||
| # [6, 7] | |||||
| # [8, 9] | |||||
| """ | |||||
| def batch_reader(): | |||||
| r = reader() | |||||
| b = [] | |||||
| for instance in r: | |||||
| b.append(instance) | |||||
| if len(b) == batch_size: | |||||
| yield b | |||||
| b = [] | |||||
| if drop_last is False and len(b) != 0: | |||||
| yield b | |||||
| # Batch size check | |||||
| batch_size = int(batch_size) | |||||
| if batch_size <= 0: | |||||
| raise ValueError('batch_size should be a positive integeral value, ' | |||||
| 'but got batch_size={}'.format(batch_size)) | |||||
| return batch_reader | |||||
| @@ -0,0 +1,112 @@ | |||||
| """ | |||||
| DataLoader class | |||||
| """ | |||||
| import math | |||||
| import os | |||||
| import numpy as np | |||||
| from modelscope.preprocessors.space.args import str2bool | |||||
| from modelscope.preprocessors.space.batch import batch | |||||
| from modelscope.preprocessors.space.lazy_dataset import LazyDataset | |||||
| from modelscope.preprocessors.space.sampler import (RandomSampler, | |||||
| SequentialSampler, | |||||
| SortedSampler) | |||||
| def get_data_loader(batch_size, reader, hparams, file, collate_fn, is_test): | |||||
| assert os.path.exists(file), f"{file} doesn't exist" | |||||
| dataset = LazyDataset(file, reader=reader) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size, | |||||
| hparams.Trainer, | |||||
| collate_fn=collate_fn, | |||||
| is_test=is_test) | |||||
| return data_loader | |||||
| def get_sequential_data_loader(batch_size, reader, hparams, data_paths, | |||||
| collate_fn, data_type): | |||||
| data_loaders = [] | |||||
| for data_path in data_paths: | |||||
| file = os.path.join( | |||||
| data_path, | |||||
| f'{data_type}.{hparams.BPETextField.tokenizer_type}.jsonl') | |||||
| data_loaders.append( | |||||
| get_data_loader( | |||||
| batch_size=batch_size, | |||||
| reader=reader, | |||||
| hparams=hparams, | |||||
| file=file, | |||||
| collate_fn=collate_fn, | |||||
| is_test=(data_type != 'train'))) | |||||
| data_loader = SequentialDataLoaderWrapper(data_loaders) | |||||
| return data_loader | |||||
| class DataLoader(object): | |||||
| """ Implement of DataLoader. """ | |||||
| @classmethod | |||||
| def add_cmdline_argument(cls, group): | |||||
| group.add_argument('--shuffle', type=str2bool, default=True) | |||||
| group.add_argument('--sort_pool_size', type=int, default=0) | |||||
| return group | |||||
| def __init__(self, | |||||
| dataset, | |||||
| batch_size, | |||||
| hparams, | |||||
| collate_fn=None, | |||||
| sampler=None, | |||||
| is_test=False): | |||||
| self.dataset = dataset | |||||
| self.collate_fn = collate_fn | |||||
| self.gpu = hparams.gpu | |||||
| self.sort_pool_size = hparams.sort_pool_size | |||||
| if sampler is None: | |||||
| if hparams.shuffle and not is_test: | |||||
| sampler = RandomSampler(dataset) | |||||
| else: | |||||
| sampler = SequentialSampler(dataset) | |||||
| if self.sort_pool_size > 0 and not is_test: | |||||
| sampler = SortedSampler(sampler, self.sort_pool_size) | |||||
| def reader(): | |||||
| for idx in sampler: | |||||
| yield idx | |||||
| drop_last = False if self.gpu <= 1 or is_test else True | |||||
| self.reader = batch(reader, batch_size=batch_size, drop_last=drop_last) | |||||
| self.num_batches = math.floor(len(dataset) / batch_size) if drop_last \ | |||||
| else math.ceil(len(dataset) / batch_size) | |||||
| def __len__(self): | |||||
| return self.num_batches | |||||
| def __iter__(self): | |||||
| for batch_indices in self.reader(): | |||||
| samples = [self.dataset[idx] for idx in batch_indices] | |||||
| yield self.collate_fn(samples) | |||||
| class SequentialDataLoaderWrapper: | |||||
| def __init__(self, data_loaders): | |||||
| self.data_loaders = data_loaders | |||||
| self.data_file_to_dataset = { | |||||
| data_loader.dataset.data_file: data_loader.dataset | |||||
| for data_loader in self.data_loaders | |||||
| } | |||||
| def __iter__(self): | |||||
| for data_loader in self.data_loaders: | |||||
| for tmp_batch in data_loader: | |||||
| yield data_loader.dataset.data_file, tmp_batch | |||||
| def __len__(self): | |||||
| return np.sum([len(data_loader) for data_loader in self.data_loaders]) | |||||
| @@ -791,7 +791,6 @@ class BPETextField(object): | |||||
| user_or_sys = [self.sos_r_id] | user_or_sys = [self.sos_r_id] | ||||
| tmp = [self.sos_u_id | tmp = [self.sos_u_id | ||||
| ] + self.numericalize(s) + user_or_sys | ] + self.numericalize(s) + user_or_sys | ||||
| tmp = tmp + self.numericalize(s) + [self.eos_r_id] | |||||
| new_src.append(tmp) | new_src.append(tmp) | ||||
| src_span_mask = [[0] + list(map(int, s)) + [0] | src_span_mask = [[0] + list(map(int, s)) + [0] | ||||
| @@ -0,0 +1,47 @@ | |||||
| """ | |||||
| Dataset class | |||||
| """ | |||||
| import json | |||||
| from modelscope.preprocessors.space.args import str2bool | |||||
| class LazyDataset(object): | |||||
| """ | |||||
| Lazy load dataset from disk. | |||||
| Each line of data file is a preprocessed example. | |||||
| """ | |||||
| def __init__(self, data_file, reader, transform=lambda s: json.loads(s)): | |||||
| """ | |||||
| Initialize lazy dataset. | |||||
| By default, loading .jsonl format. | |||||
| :param data_file | |||||
| :type str | |||||
| :param transform | |||||
| :type callable | |||||
| """ | |||||
| self.data_file = data_file | |||||
| self.transform = transform | |||||
| self.reader = reader | |||||
| self.offsets = [0] | |||||
| with open(data_file, 'r', encoding='utf-8') as fp: | |||||
| while fp.readline() != '': | |||||
| self.offsets.append(fp.tell()) | |||||
| self.offsets.pop() | |||||
| self.fp = open(data_file, 'r', encoding='utf-8') | |||||
| def __len__(self): | |||||
| return len(self.offsets) | |||||
| def __getitem__(self, idx): | |||||
| self.fp.seek(self.offsets[idx], 0) | |||||
| sample = self.transform(self.fp.readline().strip()) | |||||
| if self.reader.with_mlm: | |||||
| sample = self.reader.create_token_masked_lm_predictions(sample) | |||||
| return sample | |||||
| @@ -0,0 +1,48 @@ | |||||
| """ | |||||
| Preprocess script. | |||||
| """ | |||||
| import glob | |||||
| import os | |||||
| from modelscope.preprocessors.space.args import parse_args | |||||
| from modelscope.preprocessors.space.fields.intent_field import \ | |||||
| IntentBPETextField | |||||
| FILE_NAME = 'train.json' | |||||
| def intent_preprocess(path, cfg): | |||||
| bpe = IntentBPETextField(path, cfg) | |||||
| args = cfg.Dataset | |||||
| build_examples_fn = bpe.build_examples_multi_turn if args.trigger_role == 'system' \ | |||||
| else bpe.build_examples_single_turn | |||||
| build_score_matrix_fn = bpe.build_score_matrix | |||||
| build_score_matrix_multiprocessing_fn = bpe.build_score_matrix_multiprocessing | |||||
| data_paths = list( | |||||
| os.path.dirname(c) for c in sorted( | |||||
| glob.glob(args.data_dir + '/**/' + FILE_NAME, recursive=True))) | |||||
| data_paths = bpe.filter_data_path(data_paths=data_paths) | |||||
| for mode in ['train', 'valid', 'test']: | |||||
| for data_path in data_paths: | |||||
| input_file = os.path.join(data_path, f'{mode}.json') | |||||
| output_file = os.path.join(data_path, | |||||
| f'{mode}.{bpe.tokenizer_type}.jsonl') | |||||
| output_score_file = os.path.join(data_path, f'{mode}.Score.npy') | |||||
| if os.path.exists(input_file) and not os.path.exists(output_file): | |||||
| examples = build_examples_fn(input_file, data_type=mode) | |||||
| if examples: | |||||
| bpe.save_examples(examples, output_file) | |||||
| else: | |||||
| continue | |||||
| if os.path.exists(output_file) and not os.path.exists(output_score_file) and \ | |||||
| not args.dynamic_score and 'AnPreDial' in data_path: | |||||
| examples = bpe.load_examples(output_file) | |||||
| if args.num_process >= 2: | |||||
| score_matrix = build_score_matrix_multiprocessing_fn( | |||||
| examples) | |||||
| else: | |||||
| score_matrix = build_score_matrix_fn(examples) | |||||
| bpe.save_examples(score_matrix, output_score_file) | |||||
| @@ -0,0 +1,75 @@ | |||||
| """ | |||||
| Sampler class. | |||||
| """ | |||||
| import numpy as np | |||||
| class Sampler(object): | |||||
| def __init__(self): | |||||
| return | |||||
| def __len__(self): | |||||
| raise NotImplementedError | |||||
| def __iter__(self): | |||||
| raise NotImplementedError | |||||
| class SequentialSampler(Sampler): | |||||
| def __init__(self, dataset): | |||||
| self.dataset = dataset | |||||
| return | |||||
| def __len__(self): | |||||
| return len(self.dataset) | |||||
| def __iter__(self): | |||||
| return iter(range(len(self))) | |||||
| class RandomSampler(Sampler): | |||||
| def __init__(self, dataset): | |||||
| self.dataset = dataset | |||||
| self.epoch = 0 | |||||
| return | |||||
| def __len__(self): | |||||
| return len(self.dataset) | |||||
| def __iter__(self): | |||||
| np.random.seed(self.epoch) | |||||
| self.epoch += 1 | |||||
| return iter(np.random.permutation(len(self))) | |||||
| class SortedSampler(Sampler): | |||||
| """ Sorted Sampler. | |||||
| Sort each block of examples by key. | |||||
| """ | |||||
| def __init__(self, sampler, sort_pool_size, key='src'): | |||||
| self.sampler = sampler | |||||
| self.sort_pool_size = sort_pool_size | |||||
| self.key = lambda idx: len(self.sampler.dataset[idx][key]) | |||||
| return | |||||
| def __len__(self): | |||||
| return len(self.sampler) | |||||
| def __iter__(self): | |||||
| pool = [] | |||||
| for idx in self.sampler: | |||||
| pool.append(idx) | |||||
| if len(pool) == self.sort_pool_size: | |||||
| pool = sorted(pool, key=self.key) | |||||
| for i in pool: | |||||
| yield i | |||||
| pool = [] | |||||
| if len(pool) > 0: | |||||
| pool = sorted(pool, key=self.key) | |||||
| for i in pool: | |||||
| yield i | |||||
| @@ -0,0 +1,134 @@ | |||||
| import os | |||||
| import time | |||||
| from typing import Callable, Dict, Optional, Tuple, Union | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.nlp.space.model.generator import Generator | |||||
| from modelscope.models.nlp.space.model.model_base import SpaceModelBase | |||||
| from modelscope.preprocessors.space.data_loader import \ | |||||
| get_sequential_data_loader | |||||
| from modelscope.preprocessors.space.fields.intent_field import \ | |||||
| IntentBPETextField | |||||
| from modelscope.preprocessors.space.preprocess import intent_preprocess | |||||
| from modelscope.trainers.base import BaseTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.logger import get_logger | |||||
| PATH = None | |||||
| logger = get_logger(PATH) | |||||
| @TRAINERS.register_module(module_name=Trainers.dialog_intent_trainer) | |||||
| class DialogIntentTrainer(BaseTrainer): | |||||
| def __init__(self, | |||||
| cfg_file: Optional[str] = None, | |||||
| cfg_modify_fn: Optional[Callable] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) | |||||
| def to_tensor(array): | |||||
| """ | |||||
| numpy array -> tensor | |||||
| """ | |||||
| import torch | |||||
| array = torch.tensor(array) | |||||
| return array.cuda() if self.cfg.use_gpu else array | |||||
| def setup_seed(seed): | |||||
| import random | |||||
| import torch | |||||
| torch.manual_seed(seed) | |||||
| torch.cuda.manual_seed_all(seed) | |||||
| np.random.seed(seed) | |||||
| random.seed(seed) | |||||
| torch.backends.cudnn.deterministic = True | |||||
| self.cfg_modify_fn = cfg_modify_fn | |||||
| self.cfg = self.rebuild_config(self.cfg) | |||||
| setup_seed(self.cfg.Trainer.seed) | |||||
| # preprocess data | |||||
| intent_preprocess(self.cfg.Model.init_checkpoint, self.cfg) | |||||
| # set reader and evaluator | |||||
| bpe = IntentBPETextField(self.cfg.Model.init_checkpoint, self.cfg) | |||||
| self.cfg.Model.num_token_embeddings = bpe.vocab_size | |||||
| self.cfg.Model.num_turn_embeddings = bpe.max_ctx_turn + 1 | |||||
| dataset_paths = [ | |||||
| os.path.join(self.cfg.Dataset.data_dir, | |||||
| self.cfg.Dataset.trigger_data) | |||||
| ] | |||||
| # set data and data status | |||||
| collate_fn = bpe.collate_fn_multi_turn | |||||
| self.train_label_loader = get_sequential_data_loader( | |||||
| batch_size=self.cfg.Trainer.batch_size_label, | |||||
| reader=bpe, | |||||
| hparams=self.cfg, | |||||
| data_paths=dataset_paths, | |||||
| collate_fn=collate_fn, | |||||
| data_type='train') | |||||
| self.valid_label_loader = get_sequential_data_loader( | |||||
| batch_size=self.cfg.Trainer.batch_size_label, | |||||
| reader=bpe, | |||||
| hparams=self.cfg, | |||||
| data_paths=dataset_paths, | |||||
| collate_fn=collate_fn, | |||||
| data_type='valid') | |||||
| self.test_label_loader = get_sequential_data_loader( | |||||
| batch_size=self.cfg.Trainer.batch_size_label, | |||||
| reader=bpe, | |||||
| hparams=self.cfg, | |||||
| data_paths=dataset_paths, | |||||
| collate_fn=collate_fn, | |||||
| data_type='test') | |||||
| # set generator | |||||
| generator = Generator.create(self.cfg, reader=bpe) | |||||
| # construct model | |||||
| self.model = SpaceModelBase.create( | |||||
| self.cfg.Model.init_checkpoint, | |||||
| self.cfg, | |||||
| reader=bpe, | |||||
| generator=generator) | |||||
| import torch | |||||
| # multi-gpu | |||||
| if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: | |||||
| self.model = torch.nn.DataParallel(self.model) | |||||
| # construct trainer | |||||
| self.trainer = IntentTrainer( | |||||
| self.model, to_tensor, self.cfg, reader=bpe) | |||||
| num_batches = len(self.train_label_loader) | |||||
| self.trainer.set_optimizers(num_training_steps_per_epoch=num_batches) | |||||
| # load model, optimizer and lr_scheduler | |||||
| self.trainer.load() | |||||
| def rebuild_config(self, cfg: Config): | |||||
| if self.cfg_modify_fn is not None: | |||||
| return self.cfg_modify_fn(cfg) | |||||
| return cfg | |||||
| def train(self, *args, **kwargs): | |||||
| logger.info('Train') | |||||
| self.trainer.train( | |||||
| train_label_iter=self.train_label_loader, | |||||
| valid_label_iter=self.valid_label_loader) | |||||
| def evaluate(self, | |||||
| checkpoint_path: Optional[str] = None, | |||||
| *args, | |||||
| **kwargs) -> Dict[str, float]: | |||||
| logger.info('Evaluate') | |||||
| self.trainer.infer( | |||||
| data_iter=self.test_label_loader, | |||||
| ex_data_iter=self.train_label_loader) | |||||
| @@ -0,0 +1,101 @@ | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| import json | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import DownloadMode, ModelFile, Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TestDialogIntentTrainer(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.save_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.save_dir): | |||||
| os.mkdir(self.save_dir) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.save_dir) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | |||||
| model_id = 'damo/nlp_space_pretrained-dialog-model' | |||||
| data_banking = MsDataset.load('banking77') | |||||
| self.data_dir = data_banking._hf_ds.config_kwargs['split_config'][ | |||||
| 'train'] | |||||
| self.model_dir = snapshot_download(model_id) | |||||
| self.debugging = True | |||||
| kwargs = dict( | |||||
| model_dir=self.model_dir, | |||||
| cfg_name='intent_train_config.json', | |||||
| cfg_modify_fn=self.cfg_modify_fn) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.dialog_intent_trainer, default_args=kwargs) | |||||
| trainer.train() | |||||
| def cfg_modify_fn(self, cfg): | |||||
| config = { | |||||
| 'num_intent': 77, | |||||
| 'BPETextField': { | |||||
| 'vocab_path': '', | |||||
| 'data_name': 'banking77', | |||||
| 'data_root': self.data_dir, | |||||
| 'understand': True, | |||||
| 'generation': False, | |||||
| 'max_len': 256 | |||||
| }, | |||||
| 'Dataset': { | |||||
| 'data_dir': self.data_dir, | |||||
| 'with_contrastive': False, | |||||
| 'trigger_role': 'user', | |||||
| 'trigger_data': 'banking' | |||||
| }, | |||||
| 'Trainer': { | |||||
| 'can_norm': True, | |||||
| 'seed': 11, | |||||
| 'gpu': 1, | |||||
| 'save_dir': self.save_dir, | |||||
| 'batch_size_label': 128, | |||||
| 'batch_size_nolabel': 0, | |||||
| 'log_steps': 20 | |||||
| }, | |||||
| 'Model': { | |||||
| 'init_checkpoint': self.model_dir, | |||||
| 'model': 'IntentUnifiedTransformer', | |||||
| 'example': False, | |||||
| 'num_intent': 77, | |||||
| 'with_rdrop': True, | |||||
| 'num_turn_embeddings': 21, | |||||
| 'dropout': 0.25, | |||||
| 'kl_ratio': 5.0, | |||||
| 'embed_dropout': 0.25, | |||||
| 'attn_dropout': 0.25, | |||||
| 'ff_dropout': 0.25, | |||||
| 'with_pool': False, | |||||
| 'warmup_steps': -1 | |||||
| } | |||||
| } | |||||
| cfg.BPETextField.vocab_path = os.path.join(self.model_dir, | |||||
| ModelFile.VOCAB_FILE) | |||||
| cfg.num_intent = 77 | |||||
| cfg.Trainer.update(config['Trainer']) | |||||
| cfg.BPETextField.update(config['BPETextField']) | |||||
| cfg.Dataset.update(config['Dataset']) | |||||
| cfg.Model.update(config['Model']) | |||||
| if self.debugging: | |||||
| cfg.Trainer.save_checkpoint = False | |||||
| cfg.Trainer.num_epochs = 5 | |||||
| cfg.Trainer.batch_size_label = 64 | |||||
| return cfg | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||