# Evaluate Your Model Quickly with Metrics

The setup code is the same as in the previous tutorial.

In [1]:
import torch
from torch.optim import Adam

from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.io import SST2Pipe
from fastNLP.models import CNNText

databundle = SST2Pipe().process_from_file()
vocab = databundle.get_vocab('words')
train_data = databundle.get_dataset('train')[:5000]
train_data, test_data = train_data.split(0.015)
dev_data = databundle.get_dataset('dev')

model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)
loss = CrossEntropyLoss()
metric = AccuracyMetric()
optimizer = Adam(model.parameters(), lr=0.001)
device = 0 if torch.cuda.is_available() else 'cpu'



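fastNLP provides a variety of metrics for use during training. As introduced in the previous tutorial, an AccuracyMetric object is passed directly to the Trainer, which uses it to evaluate the model on the dev set after each epoch.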

In [2]:
trainer = Trainer(train_data=train_data, model=model, loss=loss,
                  optimizer=optimizer, batch_size=32, dev_data=dev_data,
                  metrics=metric, device=device)
trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 

training epochs started 2020-02-28-00-11-51


Evaluate data in 0.16 seconds!
Evaluation on dev at Epoch 1/10. Step:154/1540: 
AccuracyMetric: acc=0.722477



Evaluate data in 0.36 seconds!
Evaluation on dev at Epoch 2/10. Step:308/1540: 
AccuracyMetric: acc=0.762615



Evaluate data in 0.16 seconds!
Evaluation on dev at Epoch 3/10. Step:462/1540: 
AccuracyMetric: acc=0.771789



Evaluate data in 0.44 seconds!
Evaluation on dev at Epoch 4/10. Step:616/1540: 
AccuracyMetric: acc=0.759174



Evaluate data in 0.29 seconds!
Evaluation on dev at Epoch 5/10. Step:770/1540: 
AccuracyMetric: acc=0.75344



Evaluate data in 0.33 seconds!
Evaluation on dev at Epoch 6/10. Step:924/1540: 
AccuracyMetric: acc=0.75



Evaluate data in 0.19 seconds!
Evaluation on dev at Epoch 7/10. Step:1078/1540: 
AccuracyMetric: acc=0.741972



Evaluate data in 0.49 seconds!
Evaluation on dev at Epoch 8/10. Step:1232/1540: 
AccuracyMetric: acc=0.740826



Evaluate data in 0.15 seconds!
Evaluation on dev at Epoch 9/10. Step:1386/1540: 
AccuracyMetric: acc=0.75



Evaluate data in 0.16 seconds!
Evaluation on dev at Epoch 10/10. Step:1540/1540: 
AccuracyMetric: acc=0.752294


In Epoch:3/Step:462, got best dev performance:
AccuracyMetric: acc=0.771789
Reloaded the best model.


{'best_eval': {'AccuracyMetric': {'acc': 0.771789}},
 'best_epoch': 3,
 'best_step': 462,
 'seconds': 30.04}
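
The dictionary printed above is what `trainer.train()` returns, so the best score can be read out programmatically. A minimal sketch, using a literal copy of the printed values as a stand-in for the return value (the variable name `results` is an assumption):

```python
# Stand-in for the value returned by trainer.train(), copied from the output above.
results = {
    'best_eval': {'AccuracyMetric': {'acc': 0.771789}},
    'best_epoch': 3,
    'best_step': 462,
    'seconds': 30.04,
}

# The outer key under 'best_eval' is the metric's class name;
# the inner key is the name under which the metric reports its value.
best_acc = results['best_eval']['AccuracyMetric']['acc']
summary = (f"best acc {best_acc:.4f} at epoch {results['best_epoch']}, "
           f"step {results['best_step']}")
print(summary)
```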

Besides AccuracyMetric, SpanFPreRecMetric is another very common metric. In sequence labeling tasks, for example, F-measure, precision, and recall are usually computed over spans.
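
To make "computed over spans" concrete, here is a small plain-Python sketch of span-level scoring for BIO tags. It is not the implementation of SpanFPreRecMetric; the helper names `bio_spans` and `span_prf`, the result keys, and the example tags are all illustrative assumptions.

```python
def bio_spans(tags):
    """Return the set of (label, start, end) spans in a BIO tag sequence (end exclusive)."""
    spans, start, label = set(), None, None
    for i, tag in enumerate(list(tags) + ['O']):  # sentinel 'O' closes a trailing span
        continues = tag.startswith('I-') and label == tag[2:]
        if label is not None and not continues:
            spans.add((label, start, i))
            start, label = None, None
        if tag.startswith('B-'):
            start, label = i, tag[2:]
    return spans


def span_prf(pred_tags, gold_tags):
    """Precision, recall and F1 where a span counts only if label and boundaries match."""
    pred, gold = bio_spans(pred_tags), bio_spans(gold_tags)
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f = 2 * precision * recall / (precision + recall) if tp else 0.0
    return {'f': f, 'pre': precision, 'rec': recall}


gold = ['B-PER', 'I-PER', 'O', 'B-LOC', 'O']
pred = ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
scores = span_prf(pred, gold)
print(scores)
```

Four of the five tokens are tagged correctly here, but only the PER span matches exactly, so all three span-level scores come out at 0.5. This is why span-based scores are usually preferred over token accuracy for sequence labeling.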

In addition, fastNLP implements ExtractiveQAMetric, a metric for extractive QA tasks such as SQuAD. The table below summarizes the available metrics.

| Name                 | Description                                                    |
| -------------------- | -------------------------------------------------------------- |
| `MetricBase`         | Base class that custom metrics must inherit from               |
| `AccuracyMetric`     | Simple accuracy metric                                         |
| `SpanFPreRecMetric`  | Metric that computes F-measure, precision, and recall together |
| `ExtractiveQAMetric` | Metric for extractive QA tasks                                 |



## Defining Your Own Metrics

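To define your own metric class, inherit from fastNLP's MetricBase and override the evaluate and get_metric methods.

- evaluate(xxx) receives one batch of data and accumulates the evaluation statistics for that batch's predictions

- get_metric(xxx) is called once all data has been processed; it computes the final result from the statistics accumulated by evaluate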
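
The split between the two methods can be sketched without fastNLP at all. The class below is a hypothetical stand-in (it does not inherit from MetricBase and works on plain lists); it only illustrates why counts are accumulated per batch and the ratio is computed once at the end.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a metric: accumulate per batch, finalize once."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def evaluate(self, pred, target):
        # Called once per batch: only update the running counts.
        self.correct += sum(p == t for p, t in zip(pred, target))
        self.total += len(target)

    def get_metric(self, reset=True):
        # Called after all batches: turn the counts into the final value.
        acc = self.correct / self.total
        if reset:
            self.correct = 0
            self.total = 0
        return {'acc': acc}


batches = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),  # 3 of 4 correct
    ([0], [1]),                    # 0 of 1 correct (short final batch)
]
metric = RunningAccuracy()
for pred, target in batches:
    metric.evaluate(pred, target)
result = metric.get_metric()

# Averaging per-batch accuracies instead would over-weight the short batch.
mean_of_batch_accs = (3 / 4 + 0 / 1) / 2
print(result, mean_of_batch_accs)
```

Accumulating counts gives 3/5 = 0.6, the accuracy over all five examples, whereas the mean of the per-batch accuracies would be 0.375.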

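Take accuracy in a classification problem as an example. Suppose the dict returned by the model's forward contains the key pred, and that this key is needed to compute accuracy: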

```python
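import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_dim, num_classes):  # placeholder arguments
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)  # placeholder layer

    def forward(self, x):
        pred = self.fc(x)  # pred's shape: batch_size x num_classes
        # other keys may be returned alongside 'pred'
        return {'pred': pred}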
```
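
Accuracy compares class labels, so scores of shape batch_size x num_classes have to be reduced to one label per example before they are compared with the targets. A sketch of that step on plain Python lists (with tensors it would typically be `pred.argmax(dim=-1)`); the helper name `argmax_rows` and the scores are made up for illustration:

```python
def argmax_rows(scores):
    """Return the index of the largest score in each row (ties go to the first index)."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Hypothetical scores for a batch of 3 examples and 2 classes.
scores = [[0.2, 1.3], [2.0, -0.5], [0.1, 0.4]]
target = [1, 0, 0]

pred_labels = argmax_rows(scores)
num_correct = sum(p == t for p, t in zip(pred_labels, target))
print(pred_labels, num_correct)
```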

### Version 1

Suppose the `target` field of the dataset holds the values to be predicted and has been set as a target. The corresponding `AccMetric` can then be defined as follows:

In [3]:
from fastNLP import MetricBase

class AccMetric(MetricBase):

    def __init__(self):
        super().__init__()
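        # Define whatever running statistics your metric needs
        self.total = 0
        self.acc_count = 0

    # The parameter names of evaluate() must match the field names in the DataSet and the
    # keys of the model's output dict; otherwise the corresponding values cannot be found.
    # `pred` and `target` are fastNLP's default names.
    def evaluate(self, pred, target):
        # Called once after every batch during dev or test; accumulate the statistics here.
        # target.eq(pred) compares element-wise, so pred must hold predicted labels
        # with the same shape as target.
        self.total += target.size(0)
        self.acc_count += target.eq(pred).sum().item()

    def get_metric(self, reset=True):  # define how the final metric is computed
        acc = self.acc_count / self.total
        if reset:  # clear the counters so the metric can be computed afresh
            self.acc_count = 0
            self.total = 0
        # Must return a dict; each key is a metric name, which is shown in the Trainer's progress bar
        return {'acc': acc}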

In [4]:
trainer = Trainer(train_data=train_data, model=model, loss=loss,
                  optimizer=optimizer, batch_size=32, dev_data=dev_data,
                  metrics=AccMetric(), device=device)
trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 

training epochs started 2020-02-28-00-12-21


Evaluate data in 0.33 seconds!
Evaluation on dev at Epoch 1/10. Step:154/1540: 
AccMetric: acc=0.7419724770642202



Evaluate data in 0.19 seconds!
Evaluation on dev at Epoch 2/10. Step:308/1540: 
AccMetric: acc=0.7660550458715596



Evaluate data in 0.27 seconds!
Evaluation on dev at Epoch 3/10. Step:462/1540: 
AccMetric: acc=0.75



Evaluate data in 0.24 seconds!
Evaluation on dev at Epoch 4/10. Step:616/1540: 
AccMetric: acc=0.7534403669724771



Evaluate data in 0.29 seconds!
Evaluation on dev at Epoch 5/10. Step:770/1540: 
AccMetric: acc=0.7488532110091743



Evaluate data in 0.14 seconds!
Evaluation on dev at Epoch 6/10. Step:924/1540: 
AccMetric: acc=0.7488532110091743



Evaluate data in 0.27 seconds!
Evaluation on dev at Epoch 7/10. Step:1078/1540: 
AccMetric: acc=0.7568807339449541



Evaluate data in 0.42 seconds!
Evaluation on dev at Epoch 8/10. Step:1232/1540: 
AccMetric: acc=0.7488532110091743



Evaluate data in 0.16 seconds!
Evaluation on dev at Epoch 9/10. Step:1386/1540: 
AccMetric: acc=0.7408256880733946



Evaluate data in 0.28 seconds!
Evaluation on dev at Epoch 10/10. Step:1540/1540: 
AccMetric: acc=0.7408256880733946


In Epoch:2/Step:308, got best dev performance:
AccMetric: acc=0.7660550458715596
Reloaded the best model.


{'best_eval': {'AccMetric': {'acc': 0.7660550458715596}},
 'best_epoch': 2,
 'best_step': 308,
 'seconds': 29.74}

### Version 2

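To reuse the metric, for example when the target field of the next dataset is called `y` rather than `target`, or the model's output key is not `pred`, let the metric take the key names as constructor arguments: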


In [5]:
class AccMetric(MetricBase):
    def __init__(self, pred=None, target=None):
        """
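        Suppose that in another setting the target field is called y and the model's output key is pred_y.
        Then it is enough to initialize the metric as acc_metric = AccMetric(pred='pred_y', target='y').
        When initialized as acc_metric = AccMetric(), fastNLP uses 'pred' and 'target' directly as the
        keys for looking up the corresponding values.
        """

        super().__init__()

        # If nothing is registered here, the behavior is the same as in Version 1.
        # This call registers the names for pred and target; only the parameters
        # that evaluate() actually uses need to be registered.
        self._init_param_map(pred=pred, target=target)

        # Define whatever running statistics your metric needs
        self.total = 0
        self.acc_count = 0

    # evaluate() receives the values found under the keys registered above. Without a mapping,
    # its parameter names must match the DataSet field names and the model's output keys;
    # `pred` and `target` are fastNLP's default names.
    def evaluate(self, pred, target):
        # Called once after every batch during dev or test; accumulate the statistics here.
        self.total += target.size(0)
        self.acc_count += target.eq(pred).sum().item()

    def get_metric(self, reset=True):  # define how the final metric is computed
        acc = self.acc_count / self.total
        if reset:  # clear the counters so the metric can be computed afresh
            self.acc_count = 0
            self.total = 0
        # Must return a dict; each key is a metric name, which is shown in the Trainer's progress bar
        return {'acc': acc}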
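
The idea behind the name mapping can be sketched in plain Python. This is not fastNLP's actual implementation; `call_with_param_map`, `batch_y`, and `pred_dict` are hypothetical names used only to show how a mapping from parameter names to dict keys could route values into evaluate.

```python
import inspect


def call_with_param_map(func, param_map, batch_y, pred_dict):
    """Call func, looking up each of its parameters under the mapped key name."""
    available = {**batch_y, **pred_dict}  # target fields plus model outputs
    kwargs = {}
    for name in inspect.signature(func).parameters:
        # An unmapped (or None) entry falls back to the parameter's own name.
        key = param_map.get(name) or name
        if key not in available:
            raise KeyError(f"evaluate() needs '{name}', looked up as '{key}', but it is missing")
        kwargs[name] = available[key]
    return func(**kwargs)


def evaluate(pred, target):
    # Toy evaluate: count matching labels in one batch.
    return sum(p == t for p, t in zip(pred, target))


# Renamed keys: target field 'y', model output key 'pred_y'.
n = call_with_param_map(evaluate, {'pred': 'pred_y', 'target': 'y'},
                        {'y': [1, 0, 1]}, {'pred_y': [1, 1, 1]})
# No mapping: the default names 'pred' and 'target' are used as keys.
m = call_with_param_map(evaluate, {}, {'target': [0, 0]}, {'pred': [0, 1]})
print(n, m)
```

Keeping the mapping outside evaluate means the body of the metric never changes when field names do; only the constructor arguments differ between datasets.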

In [6]:
trainer = Trainer(train_data=train_data, model=model, loss=loss,
                  optimizer=optimizer, batch_size=32, dev_data=dev_data,
                  metrics=AccMetric(pred="pred", target="target"), device=device)
trainer.train()

input fields after batch(if batch size is 2):
	words: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 13]) 
	seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 
target fields after batch(if batch size is 2):
	target: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) 

training epochs started 2020-02-28-00-12-51


Evaluate data in 0.24 seconds!
Evaluation on dev at Epoch 1/10. Step:154/1540: 
AccMetric: acc=0.7545871559633027



Evaluate data in 0.24 seconds!
Evaluation on dev at Epoch 2/10. Step:308/1540: 
AccMetric: acc=0.7534403669724771



Evaluate data in 0.18 seconds!
Evaluation on dev at Epoch 3/10. Step:462/1540: 
AccMetric: acc=0.7557339449541285



Evaluate data in 0.11 seconds!
Evaluation on dev at Epoch 4/10. Step:616/1540: 
AccMetric: acc=0.7511467889908257



Evaluate data in 0.19 seconds!
Evaluation on dev at Epoch 5/10. Step:770/1540: 
AccMetric: acc=0.7465596330275229



Evaluate data in 0.14 seconds!
Evaluation on dev at Epoch 6/10. Step:924/1540: 
AccMetric: acc=0.7454128440366973



Evaluate data in 0.43 seconds!
Evaluation on dev at Epoch 7/10. Step:1078/1540: 
AccMetric: acc=0.7488532110091743



Evaluate data in 0.21 seconds!
Evaluation on dev at Epoch 8/10. Step:1232/1540: 
AccMetric: acc=0.7431192660550459



Evaluate data in 0.1 seconds!
Evaluation on dev at Epoch 9/10. Step:1386/1540: 
AccMetric: acc=0.7477064220183486



Evaluate data in 0.29 seconds!
Evaluation on dev at Epoch 10/10. Step:1540/1540: 
AccMetric: acc=0.7465596330275229


In Epoch:3/Step:462, got best dev performance:
AccMetric: acc=0.7557339449541285
Reloaded the best model.


{'best_eval': {'AccMetric': {'acc': 0.7557339449541285}},
 'best_epoch': 3,
 'best_step': 462,
 'seconds': 28.68}