| @@ -52,28 +52,24 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| def test_AccuaryMetric4(self): | |||
| # (5) check reset | |||
| metric = AccuracyMetric() | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3) + 1} | |||
| pred_dict = {"pred": torch.randn(4, 3, 2)} | |||
| target_dict = {'target': torch.ones(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 0}) | |||
| ans = torch.argmax(pred_dict["pred"], dim=2).to(target_dict["target"]) == target_dict["target"] | |||
| res = metric.get_metric() | |||
| self.assertTrue(isinstance(res, dict)) | |||
| self.assertTrue("acc" in res) | |||
| self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) | |||
| def test_AccuaryMetric5(self): | |||
| # (5) check reset | |||
| metric = AccuracyMetric() | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| pred_dict = {"pred": torch.randn(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | |||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
| target_dict = {'target': torch.zeros(4, 3) + 1} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) | |||
| res = metric.get_metric(reset=False) | |||
| ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() | |||
| self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
| def test_AccuaryMetric6(self): | |||
| # (6) check numpy array is not acceptable | |||
| @@ -90,10 +86,12 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| def test_AccuaryMetric7(self): | |||
| # (7) check map, match | |||
| metric = AccuracyMetric(pred='predictions', target='targets') | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
| pred_dict = {"predictions": torch.randn(4, 3, 2)} | |||
| target_dict = {'targets': torch.zeros(4, 3)} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
| res = metric.get_metric() | |||
| ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() | |||
| self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
| def test_AccuaryMetric8(self): | |||
| # (8) check map, does not match. use stop_fast_param to stop fast param map | |||