You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.nn.probability.distribution.Distribution.rst 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. mindspore.nn.probability.distribution.Distribution
  2. ===================================================
  3. .. py:class:: mindspore.nn.probability.distribution.Distribution(seed, dtype, name, param)
  4. 所有分布的基类。
  5. **参数:**
  6. - **seed** (int) - 采样时使用的种子。如果为None,则使用0。
  7. - **dtype** (mindspore.dtype) - 事件样例的类型。
  8. - **name** (str) - 分布的名称。
  9. - **param** (dict) - 用于初始化分布的参数。
  10. **支持平台:**
  11. ``Ascend`` ``GPU``
  12. .. note::
  13. 派生类必须重写 `_mean` 、 `_prob` 和 `_log_prob` 等操作。必填参数必须通过 `args` 或 `kwargs` 传入,如 `_prob` 的 `value` 。
  14. .. py:method:: cdf(value, *args, **kwargs)
  15. 在给定值下评估累积分布函数(Cumulatuve Distribution Function, CDF)。
  16. **参数:**
  17. - **value** (Tensor) - 要评估的值。
  18. - **args** (list) - 传递给子类的位置参数列表。
  19. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  20. .. py:method:: construct(name, *args, **kwargs)
  21. 重写Cell中的 `construct` 。
  22. .. note::
  23. 支持的函数包括:'prob'、'log_prob'、'cdf', 'log_cdf'、'survival_function'、'log_survival'、'var'、
  24. 'sd'、'mode'、'mean'、'entropy'、'kl_loss'、'cross_entropy'、'sample'、'get_dist_args'、'get_dist_type'。
  25. **参数:**
  26. - **name** (str) - 函数名称。
  27. - **args** (list) - 函数所需的位置参数列表。
  28. - **kwargs** (dict) - 函数所需的关键字参数字典。
  29. .. py:method:: cross_entropy(dist, *args, **kwargs)
  30. 评估分布a和b之间的交叉熵。
  31. **参数:**
  32. - **dist** (str) - 分布的类型。
  33. - **args** (list) - 传递给子类的位置参数列表。
  34. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  35. .. py:method:: entropy(*args, **kwargs)
  36. 计算熵。
  37. **参数:**
  38. - **args** (list) - 传递给子类的位置参数列表。
  39. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  40. .. py:method:: get_dist_args(*args, **kwargs)
  41. 检查默认参数的可用性和有效性。
  42. **参数:**
  43. - **args** (list) - 传递给子类的位置参数列表。
  44. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  45. .. note::
  46. 传递给字类的参数的顺序应该与通过 `_add_parameter` 初始化默认参数的顺序相同。
  47. .. py:method:: get_dist_type()
  48. 返回分布类型。
  49. .. py:method:: kl_loss(dist, *args, **kwargs)
  50. 评估KL散度,即KL(a||b)。
  51. **参数:**
  52. - **dist** (str) - 分布的类型。
  53. - **args** (list) - 传递给子类的位置参数列表。
  54. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  55. .. py:method:: log_cdf(value, *args, **kwargs)
  56. 计算给定值对于的cdf的对数。
  57. **参数:**
  58. - **value** (Tensor) - 要评估的值。
  59. - **args** (list) - 传递给子类的位置参数列表。
  60. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  61. .. py:method:: log_prob(value, *args, **kwargs)
  62. 计算给定值对应的概率的对数(pdf或pmf)。
  63. **参数:**
  64. - **value** (Tensor) - 要评估的值。
  65. - **args** (list) - 传递给子类的位置参数列表。
  66. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  67. .. py:method:: log_survival(value, *args, **kwargs)
  68. 计算给定值对应的剩余函数的对数。
  69. **参数:**
  70. - **value** (Tensor) - 要评估的值。
  71. - **args** (list) - 传递给子类的位置参数列表。
  72. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  73. .. py:method:: mean(*args, **kwargs)
  74. 评估平均值。
  75. **参数:**
  76. - **args** (list) - 传递给子类的位置参数列表。
  77. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  78. .. py:method:: mode(*args, **kwargs)
  79. 评估模式。
  80. **参数:**
  81. - **args** (list) - 传递给子类的位置参数列表。
  82. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  83. .. py:method:: prob(value, *args, **kwargs)
  84. 评估给定值下的概率(Probability Density Function或Probability Mass Function)。
  85. **参数:**
  86. - **value** (Tensor) - 要评估的值。
  87. - **args** (list) - 传递给子类的位置参数列表。
  88. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  89. .. py:method:: sample(*args, **kwargs)
  90. 采样函数。
  91. **参数:**
  92. - **shape** (tuple) - 样本的shape。
  93. - **args** (list) - 传递给子类的位置参数列表。
  94. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  95. .. py:method:: sd(*args, **kwargs)
  96. 标准差评估。
  97. **参数:**
  98. - **args** (list) - 传递给子类的位置参数列表。
  99. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  100. .. py:method:: survival_function(value, *args, **kwargs)
  101. 计算给定值对应的剩余函数。
  102. **参数:**
  103. - **value** (Tensor) - 要评估的值。
  104. - **args** (list) - 传递给子类的位置参数列表。
  105. - **kwargs** (dict) - 传递给子类的关键字参数字典。
  106. .. py:method:: var(*args, **kwargs)
  107. 评估方差。
  108. **参数:**
  109. - **args** (list) - 传递给子类的位置参数列表。
  110. - **kwargs** (dict) - 传递给子类的关键字参数字典。