Browse Source

Address with 'param not None' case

tags/v1.0.0
peixu_ren 5 years ago
parent
commit
21c663178c
2 changed files with 5 additions and 0 deletions
  1. +2
    -0
      mindspore/nn/probability/bijector/bijector.py
  2. +3
    -0
      mindspore/nn/probability/bijector/power_transform.py

+ 2
- 0
mindspore/nn/probability/bijector/bijector.py View File

@@ -50,6 +50,8 @@ class Bijector(Cell):
self._parameters = {} self._parameters = {}
# parsing parameters # parsing parameters
for k in param.keys(): for k in param.keys():
if k == 'param':
continue
if not(k == 'self' or k.startswith('_')): if not(k == 'self' or k.startswith('_')):
self._parameters[k] = param[k] self._parameters[k] = param[k]
self._is_constant_jacobian = is_constant_jacobian self._is_constant_jacobian = is_constant_jacobian


+ 3
- 0
mindspore/nn/probability/bijector/power_transform.py View File

@@ -35,6 +35,9 @@ class PowerTransform(Bijector):
Args: Args:
power (int or float): scale factor. Default: 0. power (int or float): scale factor. Default: 0.
name (str): name of the bijector. Default: 'PowerTransform'. name (str): name of the bijector. Default: 'PowerTransform'.
param (dict): parameters used to initialize the bijector. This is only used when other bijectors that inherits
from powertransform passing in parameters. In this case the derived bijector may overwrite the param args.
Default: None.


Examples: Examples:
>>> # To initialize a PowerTransform bijector of power 0.5 >>> # To initialize a PowerTransform bijector of power 0.5


Loading…
Cancel
Save