You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_cauchy.py 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Test nn.probability.distribution.cauchy.
  17. """
  18. import pytest
  19. import mindspore.nn as nn
  20. import mindspore.nn.probability.distribution as msd
  21. from mindspore import dtype
  22. from mindspore import Tensor
  23. def test_cauchy_shape_errpr():
  24. """
  25. Invalid shapes.
  26. """
  27. with pytest.raises(ValueError):
  28. msd.Cauchy([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
  29. def test_type():
  30. with pytest.raises(TypeError):
  31. msd.Cauchy(0., 1., dtype=dtype.int32)
  32. def test_name():
  33. with pytest.raises(TypeError):
  34. msd.Cauchy(0., 1., name=1.0)
  35. def test_seed():
  36. with pytest.raises(TypeError):
  37. msd.Cauchy(0., 1., seed='seed')
  38. def test_scale():
  39. with pytest.raises(ValueError):
  40. msd.Cauchy(0., 0.)
  41. with pytest.raises(ValueError):
  42. msd.Cauchy(0., -1.)
  43. def test_arguments():
  44. """
  45. args passing during initialization.
  46. """
  47. l = msd.Cauchy()
  48. assert isinstance(l, msd.Distribution)
  49. l = msd.Cauchy([3.0], [4.0], dtype=dtype.float32)
  50. assert isinstance(l, msd.Distribution)
  51. class CauchyProb(nn.Cell):
  52. """
  53. Cauchy distribution: initialize with loc/scale.
  54. """
  55. def __init__(self):
  56. super(CauchyProb, self).__init__()
  57. self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)
  58. def construct(self, value):
  59. prob = self.cauchy.prob(value)
  60. log_prob = self.cauchy.log_prob(value)
  61. cdf = self.cauchy.cdf(value)
  62. log_cdf = self.cauchy.log_cdf(value)
  63. sf = self.cauchy.survival_function(value)
  64. log_sf = self.cauchy.log_survival(value)
  65. return prob + log_prob + cdf + log_cdf + sf + log_sf
  66. def test_cauchy_prob():
  67. """
  68. Test probability functions: passing value through construct.
  69. """
  70. net = CauchyProb()
  71. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  72. ans = net(value)
  73. assert isinstance(ans, Tensor)
  74. class CauchyProb1(nn.Cell):
  75. """
  76. Cauchy distribution: initialize without loc/scale.
  77. """
  78. def __init__(self):
  79. super(CauchyProb1, self).__init__()
  80. self.cauchy = msd.Cauchy()
  81. def construct(self, value, mu, s):
  82. prob = self.cauchy.prob(value, mu, s)
  83. log_prob = self.cauchy.log_prob(value, mu, s)
  84. cdf = self.cauchy.cdf(value, mu, s)
  85. log_cdf = self.cauchy.log_cdf(value, mu, s)
  86. sf = self.cauchy.survival_function(value, mu, s)
  87. log_sf = self.cauchy.log_survival(value, mu, s)
  88. return prob + log_prob + cdf + log_cdf + sf + log_sf
  89. def test_cauchy_prob1():
  90. """
  91. Test probability functions: passing loc/scale, value through construct.
  92. """
  93. net = CauchyProb1()
  94. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  95. mu = Tensor([0.0], dtype=dtype.float32)
  96. s = Tensor([1.0], dtype=dtype.float32)
  97. ans = net(value, mu, s)
  98. assert isinstance(ans, Tensor)
  99. class KL(nn.Cell):
  100. """
  101. Test kl_loss and cross entropy.
  102. """
  103. def __init__(self):
  104. super(KL, self).__init__()
  105. self.cauchy = msd.Cauchy(3.0, 4.0)
  106. self.cauchy1 = msd.Cauchy()
  107. def construct(self, mu, s, mu_a, s_a):
  108. kl = self.cauchy.kl_loss('Cauchy', mu, s)
  109. kl1 = self.cauchy1.kl_loss('Cauchy', mu, s, mu_a, s_a)
  110. cross_entropy = self.cauchy.cross_entropy('Cauchy', mu, s)
  111. cross_entropy1 = self.cauchy.cross_entropy('Cauchy', mu, s, mu_a, s_a)
  112. return kl + kl1 + cross_entropy + cross_entropy1
  113. def test_kl_cross_entropy():
  114. """
  115. Test kl_loss and cross_entropy.
  116. """
  117. net = KL()
  118. mu = Tensor([0.0], dtype=dtype.float32)
  119. s = Tensor([1.0], dtype=dtype.float32)
  120. mu_a = Tensor([0.0], dtype=dtype.float32)
  121. s_a = Tensor([1.0], dtype=dtype.float32)
  122. ans = net(mu, s, mu_a, s_a)
  123. assert isinstance(ans, Tensor)
  124. class CauchyBasics(nn.Cell):
  125. """
  126. Test class: basic loc/scale function.
  127. """
  128. def __init__(self):
  129. super(CauchyBasics, self).__init__()
  130. self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)
  131. def construct(self):
  132. mode = self.cauchy.mode()
  133. entropy = self.cauchy.entropy()
  134. return mode + entropy
  135. class CauchyMean(nn.Cell):
  136. """
  137. Test class: basic loc/scale function.
  138. """
  139. def __init__(self):
  140. super(CauchyMean, self).__init__()
  141. self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)
  142. def construct(self):
  143. return self.cauchy.mean()
  144. class CauchyVar(nn.Cell):
  145. """
  146. Test class: basic loc/scale function.
  147. """
  148. def __init__(self):
  149. super(CauchyVar, self).__init__()
  150. self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)
  151. def construct(self):
  152. return self.cauchy.var()
  153. class CauchySd(nn.Cell):
  154. """
  155. Test class: basic loc/scale function.
  156. """
  157. def __init__(self):
  158. super(CauchySd, self).__init__()
  159. self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)
  160. def construct(self):
  161. return self.cauchy.sd()
  162. def test_bascis():
  163. """
  164. Test mean/sd/var/mode/entropy functionality of Cauchy.
  165. """
  166. net = CauchyBasics()
  167. ans = net()
  168. assert isinstance(ans, Tensor)
  169. with pytest.raises(ValueError):
  170. net = CauchyMean()
  171. ans = net()
  172. with pytest.raises(ValueError):
  173. net = CauchyVar()
  174. ans = net()
  175. with pytest.raises(ValueError):
  176. net = CauchySd()
  177. ans = net()
  178. class CauchyConstruct(nn.Cell):
  179. """
  180. Cauchy distribution: going through construct.
  181. """
  182. def __init__(self):
  183. super(CauchyConstruct, self).__init__()
  184. self.cauchy = msd.Cauchy(3.0, 4.0)
  185. self.cauchy1 = msd.Cauchy()
  186. def construct(self, value, mu, s):
  187. prob = self.cauchy('prob', value)
  188. prob1 = self.cauchy('prob', value, mu, s)
  189. prob2 = self.cauchy1('prob', value, mu, s)
  190. return prob + prob1 + prob2
  191. def test_cauchy_construct():
  192. """
  193. Test probability function going through construct.
  194. """
  195. net = CauchyConstruct()
  196. value = Tensor([0.5, 1.0], dtype=dtype.float32)
  197. mu = Tensor([0.0], dtype=dtype.float32)
  198. s = Tensor([1.0], dtype=dtype.float32)
  199. ans = net(value, mu, s)
  200. assert isinstance(ans, Tensor)