import os from fastNLP import cache_results from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.core.vocabulary import VocabularyOption from fastNLP import Trainer from torch.optim import Adam from fastNLP import BucketSampler from fastNLP import GradientClipCallback from reproduction.seqence_labelling.cws.model.metric import RelayMetric # 借助一下fastNLP的自动缓存机制,但是只能缓存4G以下的结果 @cache_results(None) def prepare_data(): data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt, bigram_vocab_opt=bigram_vocab_opt, bigram_embed_opt=bigram_embed_opt, L=L) return data #########hyper L = 4 hidden_size = 200 num_layers = 1 drop_p = 0.2 lr = 0.02 #########hyper device = 0 # !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 # 你们的reproduction路径下,然后设置.gitignore file_dir = '/path/to/' char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt' bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt' bigram_vocab_opt = VocabularyOption(min_freq=3) char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path) bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path) data_name = os.path.basename(file_dir) cache_fp = 'caches/{}.pkl'.format(data_name) data = prepare_data(_cache_fp=cache_fp, _refresh=True) model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'], hidden_size=hidden_size, num_layers=num_layers, L=L, num_bigram_per_char=1, drop_p=drop_p) sampler = BucketSampler(batch_size=32) optimizer = Adam(model.parameters(), lr=lr) clipper = GradientClipCallback(clip_value=5, clip_type='value') callbacks = [clipper] # if pretrain: # fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) # callbacks.append(fixer) trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, check_code_level=0) trainer.train()