You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tutorial_7_metrics.rst 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. ===============================
  2. 使用Metric快速评测你的模型
  3. ===============================
  4. 在进行训练时,fastNLP提供了各种各样的 :mod:`~fastNLP.core.metrics` 。
  5. 如 :doc:`/user/quickstart` 中所介绍的,:class:`~fastNLP.AccuracyMetric` 类的对象被直接传到 :class:`~fastNLP.Trainer` 中用于训练
  6. .. code-block:: python
  7. from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
  8. trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
  9. loss=CrossEntropyLoss(), metrics=AccuracyMetric())
  10. trainer.train()
  11. 除了 :class:`~fastNLP.AccuracyMetric` 之外,:class:`~fastNLP.SpanFPreRecMetric` 也是一种非常见的评价指标,
  12. 例如在序列标注问题中,常以span的方式计算 F-measure, precision, recall。
  13. 另外,fastNLP 还实现了用于抽取式QA(如SQuAD)的metric :class:`~fastNLP.ExtractiveQAMetric`。
  14. 用户可以参考下面这个表格,点击第一列查看各个 :mod:`~fastNLP.core.metrics` 的详细文档。
  15. .. csv-table::
  16. :header: 名称, 介绍
  17. :class:`~fastNLP.core.metrics.MetricBase` , 自定义metrics需继承的基类
  18. :class:`~fastNLP.core.metrics.AccuracyMetric` , 简单的正确率metric
  19. :class:`~fastNLP.core.metrics.SpanFPreRecMetric` , "同时计算 F-measure, precision, recall 值的 metric"
  20. :class:`~fastNLP.core.metrics.ExtractiveQAMetric` , 用于抽取式QA任务 的metric
  21. 更多的 :mod:`~fastNLP.core.metrics` 正在被添加到 fastNLP 当中,敬请期待。
  22. ------------------------------
  23. 定义自己的metrics
  24. ------------------------------
  25. 在定义自己的metrics类时需继承 fastNLP 的 :class:`~fastNLP.core.metrics.MetricBase`,
  26. 并覆盖写入 ``evaluate`` 和 ``get_metric`` 方法。
  27. evaluate(xxx) 中传入一个批次的数据,将针对一个批次的预测结果做评价指标的累计
  28. get_metric(xxx) 当所有数据处理完毕时调用该方法,它将根据 evaluate函数累计的评价指标统计量来计算最终的评价结果
  29. 以分类问题中,Accuracy计算为例,假设model的forward返回dict中包含 `pred` 这个key, 并且该key需要用于Accuracy::
  30. class Model(nn.Module):
  31. def __init__(xxx):
  32. # do something
  33. def forward(self, xxx):
  34. # do something
  35. return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
  36. 假设dataset中 `label` 这个field是需要预测的值,并且该field被设置为了target
  37. 对应的AccMetric可以按如下的定义, version1, 只使用这一次::
  38. class AccMetric(MetricBase):
  39. def __init__(self):
  40. super().__init__()
  41. # 根据你的情况自定义指标
  42. self.corr_num = 0
  43. self.total = 0
  44. def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value
  45. # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
  46. self.total += label.size(0)
  47. self.corr_num += label.eq(pred).sum().item()
  48. def get_metric(self, reset=True): # 在这里定义如何计算metric
  49. acc = self.corr_num/self.total
  50. if reset: # 是否清零以便重新计算
  51. self.corr_num = 0
  52. self.total = 0
  53. return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
  54. version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred::
  55. class AccMetric(MetricBase):
  56. def __init__(self, label=None, pred=None):
  57. # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,
  58. # acc_metric = AccMetric(label='y', pred='pred_y')即可。
  59. # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对
  60. # 应的的值
  61. super().__init__()
  62. self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可
  63. # 如果没有注册该则效果与version1就是一样的
  64. # 根据你的情况自定义指标
  65. self.corr_num = 0
  66. self.total = 0
  67. def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。
  68. # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
  69. self.total += label.size(0)
  70. self.corr_num += label.eq(pred).sum().item()
  71. def get_metric(self, reset=True): # 在这里定义如何计算metric
  72. acc = self.corr_num/self.total
  73. if reset: # 是否清零以便重新计算
  74. self.corr_num = 0
  75. self.total = 0
  76. return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中
  77. ``MetricBase`` 将会在输入的字典 ``pred_dict`` 和 ``target_dict`` 中进行检查.
  78. ``pred_dict`` 是模型当中 ``forward()`` 函数或者 ``predict()`` 函数的返回值.
  79. ``target_dict`` 是DataSet当中的ground truth, 判定ground truth的条件是field的 ``is_target`` 被设置为True.
  80. ``MetricBase`` 会进行以下的类型检测:
  81. 1. self.evaluate当中是否有varargs, 这是不支持的.
  82. 2. self.evaluate当中所需要的参数是否既不在 ``pred_dict`` 也不在 ``target_dict`` .
  83. 3. self.evaluate当中所需要的参数是否既在 ``pred_dict`` 也在 ``target_dict`` .
  84. 除此以外,在参数被传入self.evaluate以前,这个函数会检测 ``pred_dict`` 和 ``target_dict`` 当中没有被用到的参数
  85. 如果kwargs是self.evaluate的参数,则不会检测
  86. self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值
  87. self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值