From 21c663178ca64de738f18d801ac5711beddd94b7 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Wed, 2 Sep 2020 14:32:15 -0400 Subject: [PATCH] Address with 'param not None' case --- mindspore/nn/probability/bijector/bijector.py | 2 ++ mindspore/nn/probability/bijector/power_transform.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/mindspore/nn/probability/bijector/bijector.py b/mindspore/nn/probability/bijector/bijector.py index 79fe5d129d..4760efd4b6 100644 --- a/mindspore/nn/probability/bijector/bijector.py +++ b/mindspore/nn/probability/bijector/bijector.py @@ -50,6 +50,8 @@ class Bijector(Cell): self._parameters = {} # parsing parameters for k in param.keys(): + if k == 'param': + continue if not(k == 'self' or k.startswith('_')): self._parameters[k] = param[k] self._is_constant_jacobian = is_constant_jacobian diff --git a/mindspore/nn/probability/bijector/power_transform.py b/mindspore/nn/probability/bijector/power_transform.py index e67f676238..696749692d 100644 --- a/mindspore/nn/probability/bijector/power_transform.py +++ b/mindspore/nn/probability/bijector/power_transform.py @@ -35,6 +35,9 @@ class PowerTransform(Bijector): Args: power (int or float): scale factor. Default: 0. name (str): name of the bijector. Default: 'PowerTransform'. + param (dict): parameters used to initialize the bijector. This is only used when other bijectors that inherits + from powertransform passing in parameters. In this case the derived bijector may overwrite the param args. + Default: None. Examples: >>> # To initialize a PowerTransform bijector of power 0.5