|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- import unittest
-
- import fastNLP.core.loss as loss
- import math
- import torch as tc
- import pdb
-
class TestLoss(unittest.TestCase):
    """Check ``loss.Loss("nll")`` against NLL values worked out by hand.

    Every test builds a small tensor of class probabilities, takes its log,
    feeds it together with gold class indices to the loss object, and compares
    the result with the mean of ``-log(p)`` over the gold-class probabilities
    that are expected to count.
    """

    # Absolute tolerance for loss comparisons. The original tests compared the
    # values after truncating them to three decimals, so 1e-3 keeps the same
    # precision without the bucket-boundary brittleness of truncation.
    _TOLERANCE = 1e-3

    def _assert_loss_close(self, actual, expected, tol=None):
        """Assert that ``actual`` is within ``tol`` of ``expected``.

        Args:
            actual: the computed loss; anything ``float()`` accepts, e.g. a
                single-element tensor.
            expected: the hand-computed reference value (a Python float).
            tol: absolute tolerance; defaults to ``_TOLERANCE``.

        Raises:
            AssertionError: if the two values differ by more than ``tol``
                (this includes a NaN loss).
        """
        if tol is None:
            tol = self._TOLERANCE
        self.assertAlmostEqual(float(actual), expected, delta=tol)

    @staticmethod
    def _mean_nll(probs):
        """Return the mean of ``-log(p)`` over the probabilities in ``probs``."""
        return sum(-math.log(p) for p in probs) / len(probs)

    @staticmethod
    def _padded_batch():
        """Build the padded fixture shared by test cases 4, 6 and 7.

        Returns:
            A ``(log_probs, gold, lens)`` tuple: a 3x4x3 tensor of
            log-probabilities, a 3x4 tensor of gold class indices, and the
            sequence lengths ``[4, 2, 1]``. Positions at or beyond a
            sequence's length hold all-zero probability rows, which become
            ``-inf`` after the log.
        """
        probs = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1], [.6, .3, .1]],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0], [.0, .0, .0]],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0], [.0, .0, .0]],
            ]
        )
        gold = tc.LongTensor(
            [
                [0, 2, 1, 2],
                [1, 2, 0, 0],
                [2, 0, 0, 0],
            ]
        )
        lens = [4, 2, 1]
        return tc.log(probs), gold, lens

    def test_case_1(self):
        """Verify the basic NLL computation on a 2-D (batch x class) input."""
        print(".----------------------------------")

        loss_func = loss.Loss("nll")

        y = tc.Tensor(
            [
                [.3, .4, .3],
                [.5, .3, .2],
                [.3, .6, .1],
            ]
        )
        gy = tc.LongTensor([0, 1, 2])

        y = tc.log(y)
        los = loss_func(y, gy)

        # Gold-class probabilities, one per row.
        r = self._mean_nll([.3, .3, .1])
        print("loss = %f" % (los))
        print("r = %f" % (r))

        self._assert_loss_close(los, r)

    def test_case_2(self):
        """Verify squash(): a 3-D input with 2-D gold labels.

        NOTE(review): the original author's comment names ``squash()``; it is
        presumably applied inside ``loss.Loss`` -- confirm against that module.
        """
        print("----------------------------------")

        loss_func = loss.Loss("nll")

        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .4, .3]],
                [[.5, .3, .2], [.1, .2, .7]],
                [[.3, .6, .1], [.2, .1, .7]],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2],
                [1, 2],
                [2, 1],
            ]
        )

        y = tc.log(y)
        los = loss_func(y, gy)
        print("loss = %f" % (los))

        # All six positions contribute.
        r = self._mean_nll([.3, .3, .1, .3, .7, .1])
        print("r = %f" % (r))

        self._assert_loss_close(los, r)

    def test_case_3(self):
        """Verify the loss on data flattened by ``pack_padded_sequence``."""
        print("----------------------------------")

        loss_func = loss.Loss("nll")

        y = tc.Tensor(
            [
                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1]],
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0]],
                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0]],
            ]
        )
        gy = tc.LongTensor(
            [
                [0, 2, 1],
                [1, 2, 0],
                [2, 0, 0],
            ]
        )
        lens = [3, 2, 1]

        y = tc.log(y)

        # Packing drops the positions past each sequence's length, so only
        # the 3 + 2 + 1 = 6 real positions reach the loss.
        pack = tc.nn.utils.rnn.pack_padded_sequence
        yy = pack(y, lens, batch_first=True).data
        gyy = pack(gy, lens, batch_first=True).data
        los = loss_func(yy, gyy)
        print("loss = %f" % (los))

        r = self._mean_nll([.3, .5, .5, .3, .7, .1])
        print("r = %f" % (r))

        self._assert_loss_close(los, r)

    def test_case_4(self):
        """Verify the ``"unpad"`` pre-processing step driven by ``lens``."""
        print("----------------------------------")

        y, gy, lens = self._padded_batch()

        loss_func = loss.Loss("nll", pre_pro=["unpad"])
        los = loss_func(y, gy, lens=lens)
        print("loss = %f" % (los))

        # Expected value counts the 4 + 2 + 1 = 7 in-length positions only.
        r = self._mean_nll([.1, .3, .5, .5, .3, .7, .1])
        print("r = %f" % (r))

        self._assert_loss_close(los, r)

    def test_case_5(self):
        """Verify the ``"mask"`` step and ``loss.make_mask``.

        The loss is computed twice: once with a hand-written mask and once
        with the mask built by ``loss.make_mask`` from ``lens``; both must
        match the hand-computed value.
        """
        print("----------------------------------")

        y = tc.Tensor(
            [
                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0], [.0, .0, .0]],
                [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1], [.6, .1, .3]],
                [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0], [.0, .0, .0]],
            ]
        )
        gy = tc.LongTensor(
            [
                [1, 2, 0, 0],
                [0, 2, 1, 2],
                [2, 1, 0, 0],
            ]
        )
        # Ones mark the first 2, 4 and 2 positions of the rows, matching lens.
        mask = tc.ByteTensor(
            [
                [1, 1, 0, 0],
                [1, 1, 1, 1],
                [1, 1, 0, 0],
            ]
        )

        y = tc.log(y)

        lens = [2, 4, 2]

        loss_func = loss.Loss("nll", pre_pro=["mask"])
        los = loss_func(y, gy, mask=mask)
        print("loss = %f" % (los))

        los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
        print("loss2 = %f" % (los2))

        # Expected value counts the 2 + 4 + 2 = 8 unmasked positions.
        r = self._mean_nll([.3, .7, .5, .5, .5, .3, .1, .2])
        print("r = %f" % (r))

        self._assert_loss_close(los, r)
        self._assert_loss_close(los2, r)

    def test_case_6(self):
        """Verify the ``"unpad_mask"`` pre-processing step driven by ``lens``."""
        print("----------------------------------")

        y, gy, lens = self._padded_batch()

        loss_func = loss.Loss("nll", pre_pro=["unpad_mask"])
        los = loss_func(y, gy, lens=lens)
        print("loss = %f" % (los))

        # Same expectation as test_case_4: the 7 in-length positions.
        r = self._mean_nll([.1, .3, .5, .5, .3, .7, .1])
        print("r = %f" % (r))

        self._assert_loss_close(los, r)

    def test_case_7(self):
        """Verify miscellaneous features: class weights and ``add_pre_pro``.

        The loss is built with an empty ``pre_pro`` list and class weights
        ``[1, 1, 0]``; ``"unpad_mask"`` is added afterwards through
        ``add_pre_pro``.
        """
        print("----------------------------------")

        y, gy, lens = self._padded_batch()

        loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
        loss_func.add_pre_pro("unpad_mask")
        los = loss_func(y, gy, lens=lens)
        print("loss = %f" % (los))

        # The expected value averages only the three in-length positions whose
        # gold class is 0 or 1; positions with gold class 2 (weight 0) are
        # left out of both the sum and the divisor.
        r = self._mean_nll([.3, .5, .3])
        print("r = %f" % (r))
        self._assert_loss_close(los, r)
-
# Script entry point: run every test in this module when the file is executed
# directly (importing the module does not trigger the run).
if __name__ == "__main__":
    unittest.main()
|