| @@ -26,7 +26,7 @@ from .utils import _build_args | |||
| from .utils import _check_arg_dict_list | |||
| from .utils import _check_function_or_method | |||
| from .utils import _get_func_signature | |||
| from .utils import seq_len_to_mask | |||
| class LossBase(object): | |||
| """ | |||
| @@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase): | |||
| :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
| :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
| :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 | |||
| :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | |||
| :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
| 传入seq_len. | |||
| Example:: | |||
| @@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase): | |||
| """ | |||
| def __init__(self, pred=None, target=None, padding_idx=-100): | |||
| def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): | |||
| super(CrossEntropyLoss, self).__init__() | |||
| self._init_param_map(pred=pred, target=target) | |||
| self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
| self.padding_idx = padding_idx | |||
| def get_loss(self, pred, target): | |||
| def get_loss(self, pred, target, seq_len=None): | |||
| if pred.dim()>2: | |||
| if pred.size()[:2]==target.size(): | |||
| # F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 | |||
| pred = pred.transpose(1, 2) | |||
| pred = pred.view(-1, pred.size(-1)) | |||
| target = target.view(-1) | |||
| if seq_len is not None: | |||
| mask = seq_len_to_mask(seq_len).view(-1).eq(0) | |||
| target = target.masked_fill(mask, self.padding_idx) | |||
| return F.cross_entropy(input=pred, target=target, | |||
| ignore_index=self.padding_idx) | |||
| @@ -452,17 +452,15 @@ class Trainer(object): | |||
| else: | |||
| raise TypeError("train_data type {} not support".format(type(train_data))) | |||
| self.model = _move_model_to_device(model, device=device) | |||
| if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
| _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
| _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
| metric_key=metric_key, check_level=check_code_level, | |||
| batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
| # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
| self.model = _move_model_to_device(model, device=device) | |||
| self.train_data = train_data | |||
| self.dev_data = dev_data # If None, No validation. | |||
| self.model = model | |||
| self.losser = losser | |||
| self.metrics = metrics | |||
| self.n_epochs = int(n_epochs) | |||
| @@ -480,16 +478,16 @@ class Trainer(object): | |||
| if isinstance(optimizer, torch.optim.Optimizer): | |||
| self.optimizer = optimizer | |||
| elif isinstance(optimizer, Optimizer): | |||
| self.optimizer = optimizer.construct_from_pytorch(model.parameters()) | |||
| self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||
| elif optimizer is None: | |||
| self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) | |||
| self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
| else: | |||
| raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
| self.use_tqdm = use_tqdm | |||
| self.pbar = None | |||
| self.print_every = abs(self.print_every) | |||
| if self.dev_data is not None: | |||
| self.tester = Tester(model=self.model, | |||
| data=self.dev_data, | |||