
enas_trainer.py 15 kB

Dev0.4.0 (#149)
* 1. add support for BMESO-type tags to CRF 2. add comments in vocabulary
* add an error check to BucketSampler
* 1. fix a bug in ClipGradientCallback; remove the print in LRSchedulerCallback (printing should later go through a pbar) 2. add comments to MLP
* update MLP module
* add comments to metrics; fix a bug in the trainer save process
* Update README.md fix tutorial link
* Add ENAS (Efficient Neural Architecture Search)
* add ignore_type in DataSet.add_field
* AutoPadder will not pad when dtype is None; add ignore_type in DataSet.apply
* fix a latent padder bug in fieldarray
* fix a typo in CRF, plus a spot that could cause numerical instability
* fix a possible bug in CRF
* change two default init arguments of Trainer into None
* Changes to Callbacks: add several read-only properties to callbacks, set them through the manager, and optimize the code to lighten the load on @transfer
* move the ENAS-related code into the automl directory; fix a bug in fast_param_mapping
* Trainer now creates the save directory automatically
* printing a Vocabulary now shows its contents
* add an iteration method to Vocabulary; fix a bug where CRF could be negative
* add SQuAD metric
* add sigmoid activation function in MLP
* add star transformer model; add ConllLoader for all kinds of conll-format files; add JsonLoader for json-format files; add SSTLoader for SST-2 & SST-5; change Callback interface; fix batch multi-process when killed; add README to list models and their performance
* fix test
* fix callback & tests
* update README
* fix several bugs; adjust callback
* prepare the 0.4.0 release
* update readme
* support parallel loss
* prevent multi-GPU setups from computing the loss incorrectly
* update advance_tutorial jupyter notebook
* 1. add new reading functions load_with_vocab() and load_without_vocab() to embedding_loader; compared with the old functions, (1) embed_dim no longer needs to be passed in and (2) word2vec vs. glove format is detected automatically 2. add from_dataset() and index_dataset() to Vocabulary, so indexing a dataset no longer takes multiple lines 3. add a cache_result() decorator to utils for caching function return values 4. add an update_every attribute to callback
* 1. DataSet.apply() reports the offending index on error 2. Vocabulary.from_dataset() and index_dataset() report the vocab order on error 3. embedloader skips irregular lines when reading embeddings
* update attention
* doc tools
* fix some doc errors
* switch to Chinese comments; add a Viterbi decoding method
* sample version
* add pad sequence for lstm; add csv, conll, json filereader; update dataloader; remove useless dataloader; fix trainer loss print; fix tests
* fix test_tutorial
* add comments
* test the docs
* local stash
* local stash
* reorder the docs
* add document
* local stash
* update pooling
* update bert
* update documents in MLP
* update documents in snli
* combine self attention module to attention.py
* update documents on losses.py
* update the DataSet documentation
* update documents on metrics
* 1. remove the print in LSTM 2. change use_cuda in Trainer and Tester to device 3. extend the Trainer documentation
* add comments for Trainer
* polish the docs for trainer, callback, etc.; rename some code so it is hidden from the docs
* update char level encoder
* update documents on embedding.py
* update doc
* add comments and revise some code
* update doc; add get_embeddings
* change the documentation config options
* change embedding to init_embed initialization
* add multi-GPU support to Trainer and Tester
* add test; fix jsonloader
* remove the comments tutorial
* add get_field_names to DataSet
* fix a bug
* add Const; fix bugs
* revise some comments
* add a model runner to make testing models easier; add model tests
* restructure the docs configuration and architecture
* revise a large part of the core docs; TODO: 1. finish the trainer and tester docs 2. look into comment samples and tests
* comments in the core part are mostly checked
* revise the comments in the io part
* switch everything to relative imports
* switch everything to relative imports
* small change
* 1. remove api/automl from the install files 2. fix a seq_len bug in metric 3. fix a naming error in sampler
* fix a bug: compatibility with the CPU-only build of PyTorch; TODO: similar bugs may exist elsewhere
* fix the references in the docs
* replace tqdm.autonotebook with tqdm.auto
* fix batch & vocab
* upload the *.rst documentation files
* upload documentation files and several TODOs
* discuss and consolidate several modules
* tests for the core part and some small fixes
* remove some redundant docs
* update init files
* update const files
* update const files
* add tests for cnn
* fix a little bug
* update attention; fix tests
* improve the tests
* finish the quick-start tutorial
* update the docs for the rename of sequence_modeling to sequence_labeling
* re-run apidoc to clean up leftovers from the rename
* fix the doc formatting
* unify the scattered seq_len_to_mask implementations into core.utils.seq_len_to_mask
* add a one-line hint
* show dataset_loader in the docs
* note that Dataset.read_csv will be replaced by CSVLoader
* finish the docs linking Callback and Trainer
* partially update the index
* remove redundant prints
* remove the word-segmentation metric because it could cause errors
* fix the Chinese names in the docs
* finish the detailed introduction docs
* ipynb files for the tutorials
* revise some introduction docs
* revise the home-page introductions for models and modules
* add the titlesonly setting
* revise the titles shown in the module docs
* revise the opening introductions for core and io
* revise the opening introductions for modules and models
* use .. todo:: to hide TODO comments that might be pulled into the docs
* revise some comments
* delete an old metric in test
* revise the tutorials test file
* move features not yet released into the legacy folder
* remove tests that cannot run
* revise the callback test file
* remove outdated tutorials and test files
* change the cache_results parameters
* revise the io test files; remove some outdated tests
* fix a bug
* fix failures in the test_utils.py tests
* fix compatibility with pad_sequence in PyTorch 1.1; revise Trainer's pbar
* 1. fix a bug in metric 2. add metric tests
* add model summary
* add aliases
* remove nested layers in encoder
* reorder imports in the core part and what __all__ exposes
* reorder imports in the models part and what __all__ exposes
* rename files
* revise __all__ and the imports in the modules package
* fix var runn
* add a clear method to Vocabulary
* some minor PEP8 adjustments
* update the cache_results example
* 1. warn about potentially-None indices in callback 2. DataSet supports indexing with a List
* fix a typo
* update README.md
* update documents on bert
* update documents on encoder/bert
* add a fitlog callback for experiment logging with fitlog
* typo
* update dataset_loader
* add a link to the fitlog docs
* add docs for DataSet Loader
* add star-transformer reproduction
6 years ago
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
import math
import time
from datetime import datetime
from datetime import timedelta

import numpy as np
import torch

try:
    from tqdm.auto import tqdm
except ImportError:
    # fall back to fastNLP's stub when tqdm is not installed
    from fastNLP.core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
from . import enas_utils as utils
from fastNLP.core.utils import _build_args

from torch.optim import Adam
def _get_no_grad_ctx_mgr():
    """Returns the `torch.no_grad` context manager for PyTorch version >=
    0.4, or a no-op context manager otherwise.
    """
    return torch.no_grad()
class ENASTrainer(fastNLP.Trainer):
    """A class to wrap training code."""

    def __init__(self, train_data, model, controller, **kwargs):
        """Constructor for training algorithm.

        :param DataSet train_data: the training data
        :param torch.nn.modules.module model: a PyTorch model
        :param torch.nn.modules.module controller: a PyTorch model
        """
        self.final_epochs = kwargs.pop('final_epochs')
        super(ENASTrainer, self).__init__(train_data, model, **kwargs)
        self.controller_step = 0
        self.shared_step = 0
        self.max_length = 35

        self.shared = model
        self.controller = controller

        # optimizer for the shared (child-model) parameters
        self.shared_optim = Adam(
            self.shared.parameters(),
            lr=20.0,
            weight_decay=1e-7)

        # optimizer for the controller parameters
        self.controller_optim = Adam(
            self.controller.parameters(),
            lr=3.5e-4)
    def train(self, load_best_model=True):
        """
        :param bool load_best_model: only takes effect when dev_data was provided at
            initialization. If True, the trainer reloads the parameters of the model that
            performed best on dev before returning.
        :return results: a dict with the following contents::

            seconds: float, the training time in seconds
            The following three entries are only present when dev_data was provided.
            best_eval: Dict of Dict, the evaluation results
            best_epoch: int, the epoch in which the best value was obtained
            best_step: int, the step (batch update) at which the best value was obtained
        """
        results = {}
        if self.n_epochs <= 0:
            print(f"training epoch is {self.n_epochs}, nothing was done.")
            results['seconds'] = 0.
            return results
        try:
            if torch.cuda.is_available() and self.use_cuda:
                self.model = self.model.cuda()
            self._model_device = self.model.parameters().__next__().device

            self._mode(self.model, is_test=False)

            self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
            start_time = time.time()
            print("training epochs started " + self.start_time, flush=True)

            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end(self.model)
            except (CallbackException, KeyboardInterrupt) as e:
                self.callback_manager.on_exception(e, self.model)

            if self.dev_data is not None:
                print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                      self.tester._format_eval_results(self.best_dev_perf))
                results['best_eval'] = self.best_dev_perf
                results['best_epoch'] = self.best_dev_epoch
                results['best_step'] = self.best_dev_step
                if load_best_model:
                    model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                    load_succeed = self._load_model(self.model, model_name)
                    if load_succeed:
                        print("Reloaded the best model.")
                    else:
                        print("Failed to reload the best model.")
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
        total_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs + 1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
                if epoch == self.n_epochs + 1 - self.final_epochs:
                    print('Entering the final stage. (Only train the selected structure)')
                # early stopping
                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)

                # 2. Training the controller parameters theta
                if not last_stage:
                    self.train_controller()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
                    if not last_stage:
                        self.derive()
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                total_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str)

                # lr decay; early stopping
                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
            # =============== epochs end =================== #
            pbar.close()
        # ============ tqdm end ============== #
    def get_loss(self, inputs, targets, hidden, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]
        loss = 0
        for dag in dags:
            self.shared.setDAG(dag)
            inputs = _build_args(self.shared.forward, **inputs)
            inputs['hidden'] = hidden
            result = self.shared(**inputs)
            output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out']

            self.callback_manager.on_loss_begin(targets, result)
            sample_loss = self._compute_loss(result, targets)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out
    def train_shared(self, pbar=None, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.batch_size)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_idx = 0
        avg_loss = 0
        start = time.time()  # used by the progress report in the non-tqdm branch
        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
            indices = data_iterator.get_batch_indices()
            # negative sampling; replace unknown; re-weight batch_y
            self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
            # prediction = self._data_forward(self.model, batch_x)

            dags = self.controller.sample(1)
            inputs, targets = batch_x, batch_y
            # self.callback_manager.on_loss_begin(batch_y, prediction)
            loss, hidden, extra_out = self.get_loss(inputs,
                                                    targets,
                                                    hidden,
                                                    dags)
            hidden.detach_()

            avg_loss += loss.item()

            # Is loss NaN or inf? requires_grad = False
            self.callback_manager.on_backward_begin(loss, self.model)
            self._grad_backward(loss)
            self.callback_manager.on_backward_end(self.model)

            self._update()
            self.callback_manager.on_step_end(self.optimizer)

            if (self.step + 1) % self.print_every == 0:
                if self.use_tqdm:
                    print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                    pbar.update(self.print_every)
                else:
                    end = time.time()
                    diff = timedelta(seconds=round(end - start))
                    # the original referenced an undefined `epoch` here; report the step only
                    print_output = "[step: {:>4}] train loss: {:>4.6} time: {}".format(
                        self.step, avg_loss, diff)
                pbar.set_postfix_str(print_output)
                avg_loss = 0
            self.step += 1
            step += 1
            self.shared_step += 1
            self.callback_manager.on_batch_end()
        # ================= mini-batch end ==================== #
    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for inputs, targets in data_iterator:
            valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
            valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        # reward is c / valid_ppl (here c = 80), plus a small entropy bonus
        R = 80 / valid_ppl

        rewards = R + 1e-4 * entropies

        return rewards, hidden
    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()
        # Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.batch_size)
        total_loss = 0
        valid_idx = 0
        # NOTE: despite the docstring, this port runs only 20 controller steps per call
        for step in range(20):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags,
                                                  np_entropies,
                                                  hidden,
                                                  valid_idx)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = 0.95
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss (REINFORCE): -log_prob * advantage
            loss = -log_probs * utils.get_variable(adv,
                                                   self.use_cuda,
                                                   requires_grad=False)

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % 50) == 0) and (step > 0):
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1
            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
            #              (self.valid_data.size(0) - 1))
            # # Whenever we wrap around to the beginning of the
            # # validation data, we reset the hidden states.
            # if prev_valid_idx > valid_idx:
            #     hidden = self.shared.init_hidden(self.batch_size)
    def derive(self, sample_num=10, valid_idx=0):
        """We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.batch_size)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        self.model.setDAG(best_dag)
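
For context, a minimal usage sketch of this trainer (not part of the file above). It assumes the companion ENAS modules that ship alongside this trainer in the automl package; the import paths, constructor arguments, and the train_data/dev_data variables are illustrative placeholders, and final_epochs is the extra keyword popped in __init__ before the rest is forwarded to fastNLP.Trainer:

    # Hedged sketch: ENASModel / Controller stand in for the companion modules
    # next to this file; adjust names and constructor arguments to the actual
    # automl package layout. train_data / dev_data are prepared fastNLP DataSets.
    from fastNLP.automl.enas_model import ENASModel
    from fastNLP.automl.enas_controller import Controller

    shared = ENASModel(...)    # child-model weights (omega); must provide setDAG()/init_hidden()
    controller = Controller()  # samples architectures (theta) via sample()

    trainer = ENASTrainer(train_data=train_data,
                          model=shared,
                          controller=controller,
                          dev_data=dev_data,  # needed for get_reward() and validation
                          n_epochs=10,
                          final_epochs=5,     # last epochs train only the derived structure
                          batch_size=64)
    results = trainer.train()  # {'seconds': ..., 'best_eval': ..., 'best_epoch': ..., 'best_step': ...}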