
test_loss.py 3.3 kB

import unittest

import torch
import torch.nn.functional as F

import fastNLP as loss
from fastNLP.core.losses import squash, unpad


class TestLoss(unittest.TestCase):
    def test_CrossEntropyLoss(self):
        ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
        a = torch.randn(3, 5, requires_grad=False)
        b = torch.empty(3, dtype=torch.long).random_(5)
        ans = ce({"my_predict": a}, {"my_truth": b})
        self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

    def test_BCELoss(self):
        bce = loss.BCELoss(pred="my_predict", target="my_truth")
        a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
        b = torch.randn((3, 5), requires_grad=False)
        ans = bce({"my_predict": a}, {"my_truth": b})
        self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))

    def test_L1Loss(self):
        l1 = loss.L1Loss(pred="my_predict", target="my_truth")
        a = torch.randn(3, 5, requires_grad=False)
        b = torch.randn(3, 5)
        ans = l1({"my_predict": a}, {"my_truth": b})
        self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))

    def test_NLLLoss(self):
        l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
        a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
        b = torch.tensor([1, 0, 4])
        ans = l1({"my_predict": a}, {"my_truth": b})
        self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))


class TestLosserError(unittest.TestCase):
    def test_losser1(self):
        # (1) only pred, target passed
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4).long()}
        los = loss.CrossEntropyLoss()
        print(los(pred_dict=pred_dict, target_dict=target_dict))

    def test_losser2(self):
        # (2) with corrupted size
        pred_dict = {"pred": torch.zeros(16, 3)}
        target_dict = {'target': torch.zeros(16, 3).long()}
        los = loss.CrossEntropyLoss()
        with self.assertRaises(RuntimeError):
            print(los(pred_dict=pred_dict, target_dict=target_dict))

    def test_losser3(self):
        # (3) with 'stop_fast_param' to skip the fast-parameter shortcut
        pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
        target_dict = {'target': torch.zeros(16).long()}
        los = loss.CrossEntropyLoss()
        print(los(pred_dict=pred_dict, target_dict=target_dict))

    def test_check_error(self):
        l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
        a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
        b = torch.tensor([1, 0, 4])
        with self.assertRaises(Exception):
            ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b})
        with self.assertRaises(Exception):
            ans = l1({"my_predict": a}, {"truth": b, "my": a})


class TestLossUtils(unittest.TestCase):
    def test_squash(self):
        a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
        self.assertEqual(tuple(a.size()), (3, 5))
        self.assertEqual(tuple(b.size()), (15,))

    def test_unpad(self):
        a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
        self.assertEqual(tuple(a.size()), (5, 8, 3))
        self.assertEqual(tuple(b.size()), (5, 8))
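
Usage note (not part of the original file): the suite runs with Python's standard unittest runner, either by invoking python -m unittest test_loss -v from the file's directory or by appending a main guard such as the sketch below.

# Assumed addition for convenience; not part of the original test_loss.py
if __name__ == "__main__":
    unittest.main()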