@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args

-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -131,15 +131,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
+
         if self.use_rdrop:
             construct_rdrop_sample(sample)

+        output = model.model(**sample['net_input'])
         loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
-            'loss': loss.data,
+            'loss': loss.data / 100,
             'nll_loss': nll_loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['nsentences'],
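`sample_size` from this hunk becomes the denominator used to normalize gradients, so the `sentence_avg` flag switches between per-sentence and per-token averaging. A minimal sketch of that choice with toy tensors (`target`, `padding_idx`, and the shapes are illustrative, not taken from the patch):

```python
import torch

# Toy target batch: 2 sentences padded with padding_idx = 1.
target = torch.tensor([[5, 6, 7, 1, 1],
                       [5, 6, 7, 8, 9]])
padding_idx = 1
ntokens = int((target != padding_idx).sum())  # 8 non-pad tokens

sentence_avg = False
sample_size = target.size(0) if sentence_avg else ntokens
print(sample_size)  # 8 (per-token); would be 2 with sentence_avg = True
```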
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         }
         return loss, sample_size, logging_output

-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
         conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
             'conf'] is not None else 1
         constraint_masks = None
         if 'constraint_masks' in sample and sample[
                 'constraint_masks'] is not None:
             constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
         if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
         target = sample['target']
         if self.ignore_prefix_size > 0:
             lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
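The core of this hunk is that banned vocabulary ids are pushed to `-math.inf` *before* `log_softmax`, so they end up with exactly zero probability mass. A self-contained sketch of the `constraint_start`/`constraint_end` branch (vocab size and bounds are made-up values for illustration):

```python
import math

import torch
import torch.nn.functional as F

# Toy logits: batch=1, seq=1, vocab=10; all values illustrative.
logits = torch.zeros(1, 1, 10)

constraint_start, constraint_end = 6, 8
# As in get_lprobs_and_target: ids 4..constraint_start-1 and ids >=
# constraint_end are banned; ids 0-3 (special tokens) and the range
# [constraint_start, constraint_end) stay available.
logits[:, :, 4:constraint_start] = -math.inf
logits[:, :, constraint_end:] = -math.inf

lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
print(lprobs.exp().squeeze())
# -> uniform mass over ids {0, 1, 2, 3, 6, 7}; exactly 0 elsewhere
```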
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return lprobs.view(-1,
                            lprobs.size(-1)), target.view(-1), constraint_masks

-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
         lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
         if constraint_masks is not None:
             constraint_masks = constraint_masks[target != self.padding_idx]
         lprobs = lprobs[target != self.padding_idx]
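Taken together, the patch moves the forward pass into the criterion: `forward` now receives the model wrapper, calls `model.model(**sample['net_input'])` itself, and hands `output.logits` to `compute_loss`. A hypothetical stand-in showing the object shape the new code expects (the dummy classes, names, and sizes are assumptions, not part of the patch):

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class DummyInner(nn.Module):
    """Mimics a HF-style net whose output exposes `.logits`."""

    def forward(self, input_ids=None):
        batch, seq, vocab = input_ids.size(0), input_ids.size(1), 16
        return SimpleNamespace(logits=torch.randn(batch, seq, vocab))


class DummyWrapper(nn.Module):
    """Mimics the wrapper: the criterion reaches the net via `.model`."""

    def __init__(self):
        super().__init__()
        self.model = DummyInner()


wrapper = DummyWrapper()
sample = {'net_input': {'input_ids': torch.ones(2, 5, dtype=torch.long)}}
output = wrapper.model(**sample['net_input'])
print(output.logits.shape)  # torch.Size([2, 5, 16]) -> fed to compute_loss
```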