| @@ -51,7 +51,7 @@ from fastNLP.transformers.torch import BertTokenizer | |||||
| # 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。 | # 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。 | ||||
| @cache_results('caches/cache.pkl') | @cache_results('caches/cache.pkl') | ||||
| def prepare_data(): | def prepare_data(): | ||||
| # 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 | |||||
| # 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 | |||||
| data_bundle = ChnSentiCorpLoader().load() | data_bundle = ChnSentiCorpLoader().load() | ||||
| # 使用tokenizer对数据进行tokenize | # 使用tokenizer对数据进行tokenize | ||||
| tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm') | tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm') | ||||
| @@ -130,7 +130,7 @@ evaluator.run() | |||||
| from fastNLP.io import ChnSentiCorpLoader | from fastNLP.io import ChnSentiCorpLoader | ||||
| from functools import partial | from functools import partial | ||||
| # 会自动下载 SST2 数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 | |||||
| # 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的 | |||||
| data_bundle = ChnSentiCorpLoader().load() | data_bundle = ChnSentiCorpLoader().load() | ||||
| # 使用tokenizer对数据进行tokenize | # 使用tokenizer对数据进行tokenize | ||||
| @@ -50,6 +50,8 @@ class Saver: | |||||
| self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' | self.save_fn_name = 'save_checkpoint' if save_object == 'trainer' else 'save_model' | ||||
| self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
| # 打印这次运行时 checkpoint 所保存在的文件夹,因为这个文件夹是根据时间实时生成的,因此需要打印出来防止用户混淆; | |||||
| logger.info(f"The checkpoint will be saved in this folder for this time: {self.timestamp_path}.") | |||||
| def save(self, trainer, folder_name): | def save(self, trainer, folder_name): | ||||
| """ | """ | ||||
| @@ -199,7 +199,8 @@ class TorchDriver(Driver): | |||||
| f"`only_state_dict=False`") | f"`only_state_dict=False`") | ||||
| if not isinstance(res, dict): | if not isinstance(res, dict): | ||||
| res = res.state_dict() | res = res.state_dict() | ||||
| model.load_state_dict(res) | |||||
| _strict = kwargs.get("strict", True) | |||||
| model.load_state_dict(res, _strict) | |||||
| @rank_zero_call | @rank_zero_call | ||||
| def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | ||||