You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_logistic.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Test nn.probability.distribution.logistic.
  17. """
  18. import pytest
  19. import mindspore.nn as nn
  20. import mindspore.nn.probability.distribution as msd
  21. from mindspore import dtype
  22. from mindspore import Tensor
  23. def test_logistic_shape_errpr():
  24. """
  25. Invalid shapes.
  26. """
  27. with pytest.raises(ValueError):
  28. msd.Logistic([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
  29. def test_type():
  30. with pytest.raises(TypeError):
  31. msd.Logistic(0., 1., dtype=dtype.int32)
  32. def test_name():
  33. with pytest.raises(TypeError):
  34. msd.Logistic(0., 1., name=1.0)
  35. def test_seed():
  36. with pytest.raises(TypeError):
  37. msd.Logistic(0., 1., seed='seed')
  38. def test_scale():
  39. with pytest.raises(ValueError):
  40. msd.Logistic(0., 0.)
  41. with pytest.raises(ValueError):
  42. msd.Logistic(0., -1.)
  43. def test_arguments():
  44. """
  45. args passing during initialization.
  46. """
  47. l = msd.Logistic()
  48. assert isinstance(l, msd.Distribution)
  49. l = msd.Logistic([3.0], [4.0], dtype=dtype.float32)
  50. assert isinstance(l, msd.Distribution)
  51. class LogisticProb(nn.Cell):
  52. """
  53. logistic distribution: initialize with loc/scale.
  54. """
  55. def __init__(self):
  56. super(LogisticProb, self).__init__()
  57. self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)
  58. def construct(self, value):
  59. prob = self.logistic.prob(value)
  60. log_prob = self.logistic.log_prob(value)
  61. cdf = self.logistic.cdf(value)
  62. log_cdf = self.logistic.log_cdf(value)
  63. sf = self.logistic.survival_function(value)
  64. log_sf = self.logistic.log_survival(value)
  65. return prob + log_prob + cdf + log_cdf + sf + log_sf
  66. def test_logistic_prob():
  67. """
  68. Test probability functions: passing value through construct.
  69. """
  70. net = LogisticProb()
  71. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  72. ans = net(value)
  73. assert isinstance(ans, Tensor)
  74. class LogisticProb1(nn.Cell):
  75. """
  76. logistic distribution: initialize without loc/scale.
  77. """
  78. def __init__(self):
  79. super(LogisticProb1, self).__init__()
  80. self.logistic = msd.Logistic()
  81. def construct(self, value, mu, s):
  82. prob = self.logistic.prob(value, mu, s)
  83. log_prob = self.logistic.log_prob(value, mu, s)
  84. cdf = self.logistic.cdf(value, mu, s)
  85. log_cdf = self.logistic.log_cdf(value, mu, s)
  86. sf = self.logistic.survival_function(value, mu, s)
  87. log_sf = self.logistic.log_survival(value, mu, s)
  88. return prob + log_prob + cdf + log_cdf + sf + log_sf
  89. def test_logistic_prob1():
  90. """
  91. Test probability functions: passing loc/scale, value through construct.
  92. """
  93. net = LogisticProb1()
  94. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  95. mu = Tensor([0.0], dtype=dtype.float32)
  96. s = Tensor([1.0], dtype=dtype.float32)
  97. ans = net(value, mu, s)
  98. assert isinstance(ans, Tensor)
  99. class KL(nn.Cell):
  100. """
  101. Test kl_loss. Should raise NotImplementedError.
  102. """
  103. def __init__(self):
  104. super(KL, self).__init__()
  105. self.logistic = msd.Logistic(3.0, 4.0)
  106. def construct(self, mu, s):
  107. kl = self.logistic.kl_loss('Logistic', mu, s)
  108. return kl
  109. class Crossentropy(nn.Cell):
  110. """
  111. Test cross entropy. Should raise NotImplementedError.
  112. """
  113. def __init__(self):
  114. super(Crossentropy, self).__init__()
  115. self.logistic = msd.Logistic(3.0, 4.0)
  116. def construct(self, mu, s):
  117. cross_entropy = self.logistic.cross_entropy('Logistic', mu, s)
  118. return cross_entropy
  119. class LogisticBasics(nn.Cell):
  120. """
  121. Test class: basic loc/scale function.
  122. """
  123. def __init__(self):
  124. super(LogisticBasics, self).__init__()
  125. self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)
  126. def construct(self):
  127. mean = self.logistic.mean()
  128. sd = self.logistic.sd()
  129. mode = self.logistic.mode()
  130. entropy = self.logistic.entropy()
  131. return mean + sd + mode + entropy
  132. def test_bascis():
  133. """
  134. Test mean/sd/mode/entropy functionality of logistic.
  135. """
  136. net = LogisticBasics()
  137. ans = net()
  138. assert isinstance(ans, Tensor)
  139. mu = Tensor(1.0, dtype=dtype.float32)
  140. s = Tensor(1.0, dtype=dtype.float32)
  141. with pytest.raises(NotImplementedError):
  142. kl = KL()
  143. ans = kl(mu, s)
  144. with pytest.raises(NotImplementedError):
  145. crossentropy = Crossentropy()
  146. ans = crossentropy(mu, s)
  147. class LogisticConstruct(nn.Cell):
  148. """
  149. logistic distribution: going through construct.
  150. """
  151. def __init__(self):
  152. super(LogisticConstruct, self).__init__()
  153. self.logistic = msd.Logistic(3.0, 4.0)
  154. self.logistic1 = msd.Logistic()
  155. def construct(self, value, mu, s):
  156. prob = self.logistic('prob', value)
  157. prob1 = self.logistic('prob', value, mu, s)
  158. prob2 = self.logistic1('prob', value, mu, s)
  159. return prob + prob1 + prob2
  160. def test_logistic_construct():
  161. """
  162. Test probability function going through construct.
  163. """
  164. net = LogisticConstruct()
  165. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  166. mu = Tensor([0.0], dtype=dtype.float32)
  167. s = Tensor([1.0], dtype=dtype.float32)
  168. ans = net(value, mu, s)
  169. assert isinstance(ans, Tensor)