| @@ -978,7 +978,7 @@ class SaveModelCallback(Callback): | |||||
| return save_pair, delete_pair | return save_pair, delete_pair | ||||
| def _save_this_model(self, metric_value): | def _save_this_model(self, metric_value): | ||||
| name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
| name = "epoch-{}_step-{}_{}-{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value) | |||||
| save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name)) | ||||
| if save_pair: | if save_pair: | ||||
| try: | try: | ||||
| @@ -995,7 +995,7 @@ class SaveModelCallback(Callback): | |||||
| def on_exception(self, exception): | def on_exception(self, exception): | ||||
| if self.save_on_exception: | if self.save_on_exception: | ||||
| name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
| name = "epoch-{}_step-{}_Exception-{}.pt".format(self.epoch, self.step, exception.__class__.__name__) | |||||
| _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | ||||
| @@ -148,7 +148,7 @@ class TestDistTrainer(unittest.TestCase): | |||||
| def run_dist(self, run_id): | def run_dist(self, run_id): | ||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| ngpu = min(2, torch.cuda.device_count()) | ngpu = min(2, torch.cuda.device_count()) | ||||
| path = __file__ | |||||
| path = os.path.abspath(__file__) | |||||
| cmd = ['python', '-m', 'torch.distributed.launch', | cmd = ['python', '-m', 'torch.distributed.launch', | ||||
| '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | ||||
| print(' '.join(cmd)) | print(' '.join(cmd)) | ||||
| @@ -6,12 +6,24 @@ from fastNLP.models.cnn_text_classification import CNNText | |||||
| class TestCNNText(unittest.TestCase): | class TestCNNText(unittest.TestCase): | ||||
| def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): | |||||
| model = CNNText((VOCAB_SIZE, 30), | |||||
| NUM_CLS, | |||||
| kernel_nums=kernel_nums, | |||||
| kernel_sizes=kernel_sizes) | |||||
| return model | |||||
| def test_case1(self): | def test_case1(self): | ||||
| # 测试能否正常运行CNN | # 测试能否正常运行CNN | ||||
| init_emb = (VOCAB_SIZE, 30) | |||||
| model = CNNText(init_emb, | |||||
| NUM_CLS, | |||||
| kernel_nums=(1, 3, 5), | |||||
| kernel_sizes=(1, 3, 5), | |||||
| dropout=0.5) | |||||
| model = self.init_model((1,3,5)) | |||||
| RUNNER.run_model_with_task(TEXT_CLS, model) | |||||
| def test_init_model(self): | |||||
| self.assertRaises(Exception, self.init_model, (2,4)) | |||||
| self.assertRaises(Exception, self.init_model, (2,)) | |||||
| def test_output(self): | |||||
| model = self.init_model((3,), (1,)) | |||||
| global MAX_LEN | |||||
| MAX_LEN = 2 | |||||
| RUNNER.run_model_with_task(TEXT_CLS, model) | RUNNER.run_model_with_task(TEXT_CLS, model) | ||||