From 11aa3f6f5f36b87074a50648bb62d11cb50c9acb Mon Sep 17 00:00:00 2001 From: bingyaweng Date: Tue, 27 Oct 2020 20:06:39 +0800 Subject: [PATCH] modify the method for calculating kl loss --- .../bnn_layers/bnn_cell_wrapper.py | 68 +++++++------- .../transforms/bnn_loss/__init__.py | 19 ---- .../transforms/bnn_loss/generate_kl_loss.py | 89 ------------------- .../transforms/bnn_loss/withLossCell.py | 66 -------------- .../probability/transforms/transform_bnn.py | 28 ++---- 5 files changed, 38 insertions(+), 232 deletions(-) delete mode 100644 mindspore/nn/probability/transforms/bnn_loss/__init__.py delete mode 100644 mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py delete mode 100644 mindspore/nn/probability/transforms/bnn_loss/withLossCell.py diff --git a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py index 6912beafdc..3801496dc4 100644 --- a/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +++ b/mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py @@ -15,29 +15,12 @@ """Generate WithLossCell suitable for BNN.""" from .conv_variational import _ConvVariational from .dense_variational import _DenseVariational -from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss +from ...cell import Cell __all__ = ['WithBNNLossCell'] -class ClassWrap: - """Decorator of WithBNNLossCell""" - def __init__(self, cls): - self._cls = cls - self.bnn_loss_file = None - self.__doc__ = cls.__doc__ - self.__name__ = cls.__name__ - self.__bases__ = cls.__bases__ - - def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor): - obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor) - bnn_with_loss = obj() - self.bnn_loss_file = obj.bnn_loss_file - return bnn_with_loss - - -@ClassWrap -class WithBNNLossCell: +class WithBNNLossCell(Cell): r""" Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function. @@ -68,6 +51,7 @@ class WithBNNLossCell: """ def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): + super(WithBNNLossCell, self).__init__(auto_prefix=False) if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): raise TypeError('The type of `dnn_factor` should be `int` or `float`') if dnn_factor < 0: @@ -78,28 +62,36 @@ class WithBNNLossCell: if bnn_factor < 0: raise ValueError('The value of `bnn_factor` should >= 0') - self.backbone = backbone - self.loss_fn = loss_fn + self._backbone = backbone + self._loss_fn = loss_fn self.dnn_factor = dnn_factor self.bnn_factor = bnn_factor - self.bnn_loss_file = None - - def _generate_loss_cell(self): - """Generate WithBNNLossCell by ast.""" - layer_count = self._kl_loss_count(self.backbone) - bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, - self.dnn_factor, self.bnn_factor) - return bnn_with_loss - - def _kl_loss_count(self, net): - """ Calculate the number of Bayesian layers.""" - count = 0 + self.kl_loss = [] + self._add_kl_loss(self._backbone) + + def construct(self, x, label): + y_pred = self._backbone(x) + backbone_loss = self._loss_fn(y_pred, label) + kl_loss = 0 + for i in range(len(self.kl_loss)): + kl_loss += self.kl_loss[i]() + loss = backbone_loss * self.dnn_factor + kl_loss * self.bnn_factor + return loss + + def _add_kl_loss(self, net): + """Collect kl loss of each Bayesian layer.""" for (_, layer) in net.name_cells().items(): if isinstance(layer, (_DenseVariational, _ConvVariational)): - count += 1 + self.kl_loss.append(layer.compute_kl_loss) else: - count += self._kl_loss_count(layer) - return count + self._add_kl_loss(layer) + + @property + def backbone_network(self): + """ + Returns the backbone network. - def __call__(self): - return self._generate_loss_cell() + Returns: + Cell, the backbone network. + """ + return self._backbone diff --git a/mindspore/nn/probability/transforms/bnn_loss/__init__.py b/mindspore/nn/probability/transforms/bnn_loss/__init__.py deleted file mode 100644 index c10f1a4578..0000000000 --- a/mindspore/nn/probability/transforms/bnn_loss/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -bnn loss. -""" -from . import generate_kl_loss -from .generate_kl_loss import gain_bnn_with_loss diff --git a/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py b/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py deleted file mode 100644 index 7ce57337ef..0000000000 --- a/mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Gain bnn_with_loss by rewrite WithLossCell as WithBNNLossCell to suit for BNN model""" -import ast -import importlib -import os -import sys -import tempfile -import astunparse -import mindspore - - -class _CodeTransformer(ast.NodeTransformer): - """ - Add kl_loss computation by analyzing the python code structure with the help of the AST module. - - Args: - layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers. - """ - - def __init__(self, layer_count): - self.layer_count = layer_count - - def visit_FunctionDef(self, node): - """visit function and add kl_loss computation.""" - self.generic_visit(node) - if node.name == 'cal_kl_loss': - for i in range(self.layer_count): - func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())], - value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(), - right=ast.Call(func=ast.Name(id='self.kl_loss' + '[' + str(i) + ']', - ctx=ast.Load()), - args=[], keywords=[]))) - node.body.insert(-1, func) - return node - - -def _generate_kl_loss_func(layer_count): - """Rewrite WithLossCell as WithBNNLossCell to suit for BNN model.""" - path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py' - with open(path, 'r') as fp: - srclines = fp.readlines() - src = ''.join(srclines) - if src.startswith((' ', '\t')): - src = 'if 1:\n' + src - expr_ast = ast.parse(src, mode='exec') - transformer = _CodeTransformer(layer_count) - modify = transformer.visit(expr_ast) - modify = ast.fix_missing_locations(modify) - func = astunparse.unparse(modify) - return func - - -def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor): - """ - Gain bnn_with_loss, which wraps bnn network with loss function and kl loss of each bayesian layer. - - Args: - layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers. - backbone (Cell): The target network to wrap. - loss_fn (Cell): The loss function used to compute loss. - dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. - bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer. - """ - bnn_loss_func = _generate_kl_loss_func(layer_count) - path = os.path.dirname(mindspore.__file__) - bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True, - dir=path + '/nn/probability/transforms/bnn_loss') - bnn_loss_file.write(bnn_loss_func) - bnn_loss_file.seek(0) - - sys.path.append(path + '/nn/probability/transforms/bnn_loss') - - module_name = os.path.basename(bnn_loss_file.name)[0:-3] - bnn_loss_module = importlib.import_module(module_name, __package__) - bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor) - return bnn_with_loss, bnn_loss_file diff --git a/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py b/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py deleted file mode 100644 index 14176e1d20..0000000000 --- a/mindspore/nn/probability/transforms/bnn_loss/withLossCell.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Original WithBNNLossCell for ast to rewrite.""" - -import mindspore.nn as nn -from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational -from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational - - -class WithBNNLossCell(nn.Cell): - """ - Cell with loss function. - - Wraps the network with loss function. This Cell accepts data, label, backbone_factor and kl_factor as inputs and - the computed loss will be returned. - """ - def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1): - super(WithBNNLossCell, self).__init__(auto_prefix=False) - self._backbone = backbone - self._loss_fn = loss_fn - self.backbone_factor = backbone_factor - self.kl_factor = kl_factor - self.kl_loss = [] - self._add_kl_loss(self._backbone) - - def construct(self, x, label): - y_pred = self._backbone(x) - backbone_loss = self._loss_fn(y_pred, label) - kl_loss = self.cal_kl_loss() - loss = backbone_loss*self.backbone_factor + kl_loss*self.kl_factor - return loss - - def cal_kl_loss(self): - """Calculate kl loss.""" - loss = 0.0 - return loss - - def _add_kl_loss(self, net): - """Collect kl loss of each Bayesian layer.""" - for (_, layer) in net.name_cells().items(): - if isinstance(layer, (_DenseVariational, _ConvVariational)): - self.kl_loss.append(layer.compute_kl_loss) - else: - self._add_kl_loss(layer) - - @property - def backbone_network(self): - """ - Returns the backbone network. - - Returns: - Cell, the backbone network. - """ - return self._backbone diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py index 9d23af8034..cb6cf0b48b 100644 --- a/mindspore/nn/probability/transforms/transform_bnn.py +++ b/mindspore/nn/probability/transforms/transform_bnn.py @@ -17,8 +17,8 @@ import mindspore.nn as nn from ...wrap.cell_wrapper import TrainOneStepCell from ....nn import optim from ....nn import layer -from .bnn_loss.generate_kl_loss import gain_bnn_with_loss from ...probability import bnn_layers +from ..bnn_layers.bnn_cell_wrapper import WithBNNLossCell from ..bnn_layers.conv_variational import ConvReparam from ..bnn_layers.dense_variational import DenseReparam @@ -77,7 +77,6 @@ class TransformToBNN: self.loss_fn = getattr(net_with_loss, "_loss_fn") self.dnn_factor = dnn_factor self.bnn_factor = bnn_factor - self.bnn_loss_file = None def transform_to_bnn_model(self, get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias, @@ -120,15 +119,13 @@ class TransformToBNN: if not add_conv_args: add_conv_args = {} - layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, - add_conv_args) + self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args) # rename layers of BNN model to prevent duplication of names for value, param in self.backbone.parameters_and_names(): param.name = value - bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, - self.dnn_factor, self.bnn_factor) + bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) bnn_optimizer = self._create_optimizer_with_bnn_params() train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) return train_bnn_network @@ -179,13 +176,11 @@ class TransformToBNN: if not add_args: add_args = {} - layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, - add_args) + self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, add_args) for value, param in self.backbone.parameters_and_names(): param.name = value - bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, - self.dnn_factor, self.bnn_factor) + bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) bnn_optimizer = self._create_optimizer_with_bnn_params() train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) @@ -228,32 +223,25 @@ class TransformToBNN: def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args): """Replace both dense layer and conv2d layer in DNN model to bayesian layers.""" - count = 0 for name, cell in backbone.name_cells().items(): if isinstance(cell, nn.Dense): dense_args = get_dense_args(cell) new_layer = DenseReparam(**dense_args, **add_dense_args) setattr(backbone, name, new_layer) - count += 1 elif isinstance(cell, nn.Conv2d): conv_args = get_conv_args(cell) new_layer = ConvReparam(**conv_args, **add_conv_args) setattr(backbone, name, new_layer) - count += 1 else: - count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, - add_conv_args) - return count + self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, + add_conv_args) def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args): """Convert a specific type of layers in DNN model to corresponding bayesian layers.""" - count = 0 for name, cell in backbone.name_cells().items(): if isinstance(cell, dnn_layer): args = get_args(cell) new_layer = bnn_layer(**args, **add_args) setattr(backbone, name, new_layer) - count += 1 else: - count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args) - return count + self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)