| @@ -243,6 +243,8 @@ class DataSet(object): | |||
| :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||
| :return: | |||
| """ | |||
| if field_name not in self.field_arrays: | |||
| raise KeyError("There is no field named {}.".format(field_name)) | |||
| self.field_arrays[field_name].set_padder(padder) | |||
| def set_pad_val(self, field_name, pad_val): | |||
| @@ -253,6 +255,8 @@ class DataSet(object): | |||
| :param pad_val: int,该field的padder会以pad_val作为padding index | |||
| :return: | |||
| """ | |||
| if field_name not in self.field_arrays: | |||
| raise KeyError("There is no field named {}.".format(field_name)) | |||
| self.field_arrays[field_name].set_pad_val(pad_val) | |||
| def get_input_name(self): | |||
| @@ -206,7 +206,7 @@ class FieldArray(object): | |||
| if list in type_set: | |||
| if len(type_set) > 1: | |||
| # list 跟 非list 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
| # >1维list | |||
| inner_type_set = set() | |||
| for l in content: | |||
| @@ -229,7 +229,7 @@ class FieldArray(object): | |||
| return self._basic_type_detection(inner_inner_type_set) | |||
| else: | |||
| # list 跟 非list 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||
| else: | |||
| # 一维list | |||
| for content_type in type_set: | |||
| @@ -253,17 +253,17 @@ class FieldArray(object): | |||
| return float | |||
| else: | |||
| # str 跟 int 或者 float 混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
| else: | |||
| # str, int, float混在一起 | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
| def _1d_list_check(self, val): | |||
| """如果不是1D list就报错 | |||
| """ | |||
| type_set = set((type(obj) for obj in val)) | |||
| if any(obj not in self.BASIC_TYPES for obj in type_set): | |||
| raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
| raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
| self._basic_type_detection(type_set) | |||
| # otherwise: _basic_type_detection will raise error | |||
| return True | |||
| @@ -192,7 +192,7 @@ class ConditionalRandomField(nn.Module): | |||
| seq_len, batch_size, n_tags = logits.size() | |||
| alpha = logits[0] | |||
| if self.include_start_end_trans: | |||
| alpha += self.start_scores.view(1, -1) | |||
| alpha = alpha + self.start_scores.view(1, -1) | |||
| flip_mask = mask.eq(0) | |||
| @@ -204,7 +204,7 @@ class ConditionalRandomField(nn.Module): | |||
| alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) | |||
| if self.include_start_end_trans: | |||
| alpha += self.end_scores.view(1, -1) | |||
| alpha = alpha + self.end_scores.view(1, -1) | |||
| return log_sum_exp(alpha, 1) | |||
| @@ -233,7 +233,7 @@ class ConditionalRandomField(nn.Module): | |||
| st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
| last_idx = mask.long().sum(0) - 1 | |||
| ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||
| score += st_scores + ed_scores | |||
| score = score + st_scores + ed_scores | |||
| # return [B,] | |||
| return score | |||
| @@ -216,6 +216,11 @@ class TestDataSetMethods(unittest.TestCase): | |||
| self.assertTrue(isinstance(ds, DataSet)) | |||
| self.assertTrue(len(ds) > 0) | |||
| def test_add_null(self): | |||
| ds = DataSet() | |||
| ds.add_field('test', []) | |||
| ds.set_target('test') | |||
| class TestDataSetIter(unittest.TestCase): | |||
| def test__repr__(self): | |||
| @@ -101,4 +101,28 @@ class TestCRF(unittest.TestCase): | |||
| # # seq equal | |||
| # self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||
| def test_case3(self): | |||
| # 测试crf的loss不会出现负数 | |||
| import torch | |||
| from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||
| from fastNLP.core.utils import seq_lens_to_masks | |||
| from torch import optim | |||
| from torch import nn | |||
| num_tags, include_start_end_trans = 4, True | |||
| num_samples = 4 | |||
| lengths = torch.randint(3, 50, size=(num_samples, )).long() | |||
| max_len = lengths.max() | |||
| tags = torch.randint(num_tags, size=(num_samples, max_len)) | |||
| masks = seq_lens_to_masks(lengths) | |||
| feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | |||
| crf = ConditionalRandomField(num_tags, include_start_end_trans) | |||
| optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | |||
| for _ in range(10000): | |||
| loss = crf(feats, tags, masks).mean() | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| optimizer.step() | |||
| if _%1000==0: | |||
| print(loss) | |||
| assert loss.item()>0, "CRF loss cannot be less than 0." | |||