| @@ -15,29 +15,12 @@ | |||
| """Generate WithLossCell suitable for BNN.""" | |||
| from .conv_variational import _ConvVariational | |||
| from .dense_variational import _DenseVariational | |||
| from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss | |||
| from ...cell import Cell | |||
| __all__ = ['WithBNNLossCell'] | |||
| class ClassWrap: | |||
| """Decorator of WithBNNLossCell""" | |||
| def __init__(self, cls): | |||
| self._cls = cls | |||
| self.bnn_loss_file = None | |||
| self.__doc__ = cls.__doc__ | |||
| self.__name__ = cls.__name__ | |||
| self.__bases__ = cls.__bases__ | |||
| def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor): | |||
| obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor) | |||
| bnn_with_loss = obj() | |||
| self.bnn_loss_file = obj.bnn_loss_file | |||
| return bnn_with_loss | |||
| @ClassWrap | |||
| class WithBNNLossCell: | |||
| class WithBNNLossCell(Cell): | |||
| r""" | |||
| Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function. | |||
| @@ -68,6 +51,7 @@ class WithBNNLossCell: | |||
| """ | |||
| def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): | |||
| super(WithBNNLossCell, self).__init__(auto_prefix=False) | |||
| if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): | |||
| raise TypeError('The type of `dnn_factor` should be `int` or `float`') | |||
| if dnn_factor < 0: | |||
| @@ -78,28 +62,36 @@ class WithBNNLossCell: | |||
| if bnn_factor < 0: | |||
| raise ValueError('The value of `bnn_factor` should >= 0') | |||
| self.backbone = backbone | |||
| self.loss_fn = loss_fn | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| self.dnn_factor = dnn_factor | |||
| self.bnn_factor = bnn_factor | |||
| self.bnn_loss_file = None | |||
| def _generate_loss_cell(self): | |||
| """Generate WithBNNLossCell by ast.""" | |||
| layer_count = self._kl_loss_count(self.backbone) | |||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||
| self.dnn_factor, self.bnn_factor) | |||
| return bnn_with_loss | |||
| def _kl_loss_count(self, net): | |||
| """ Calculate the number of Bayesian layers.""" | |||
| count = 0 | |||
| self.kl_loss = [] | |||
| self._add_kl_loss(self._backbone) | |||
| def construct(self, x, label): | |||
| y_pred = self._backbone(x) | |||
| backbone_loss = self._loss_fn(y_pred, label) | |||
| kl_loss = 0 | |||
| for i in range(len(self.kl_loss)): | |||
| kl_loss += self.kl_loss[i]() | |||
| loss = backbone_loss * self.dnn_factor + kl_loss * self.bnn_factor | |||
| return loss | |||
| def _add_kl_loss(self, net): | |||
| """Collect kl loss of each Bayesian layer.""" | |||
| for (_, layer) in net.name_cells().items(): | |||
| if isinstance(layer, (_DenseVariational, _ConvVariational)): | |||
| count += 1 | |||
| self.kl_loss.append(layer.compute_kl_loss) | |||
| else: | |||
| count += self._kl_loss_count(layer) | |||
| return count | |||
| self._add_kl_loss(layer) | |||
| @property | |||
| def backbone_network(self): | |||
| """ | |||
| Returns the backbone network. | |||
| def __call__(self): | |||
| return self._generate_loss_cell() | |||
| Returns: | |||
| Cell, the backbone network. | |||
| """ | |||
| return self._backbone | |||
| @@ -1,19 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| bnn loss. | |||
| """ | |||
| from . import generate_kl_loss | |||
| from .generate_kl_loss import gain_bnn_with_loss | |||
| @@ -1,89 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Gain bnn_with_loss by rewrite WithLossCell as WithBNNLossCell to suit for BNN model""" | |||
| import ast | |||
| import importlib | |||
| import os | |||
| import sys | |||
| import tempfile | |||
| import astunparse | |||
| import mindspore | |||
| class _CodeTransformer(ast.NodeTransformer): | |||
| """ | |||
| Add kl_loss computation by analyzing the python code structure with the help of the AST module. | |||
| Args: | |||
| layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers. | |||
| """ | |||
| def __init__(self, layer_count): | |||
| self.layer_count = layer_count | |||
| def visit_FunctionDef(self, node): | |||
| """visit function and add kl_loss computation.""" | |||
| self.generic_visit(node) | |||
| if node.name == 'cal_kl_loss': | |||
| for i in range(self.layer_count): | |||
| func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())], | |||
| value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(), | |||
| right=ast.Call(func=ast.Name(id='self.kl_loss' + '[' + str(i) + ']', | |||
| ctx=ast.Load()), | |||
| args=[], keywords=[]))) | |||
| node.body.insert(-1, func) | |||
| return node | |||
| def _generate_kl_loss_func(layer_count): | |||
| """Rewrite WithLossCell as WithBNNLossCell to suit for BNN model.""" | |||
| path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py' | |||
| with open(path, 'r') as fp: | |||
| srclines = fp.readlines() | |||
| src = ''.join(srclines) | |||
| if src.startswith((' ', '\t')): | |||
| src = 'if 1:\n' + src | |||
| expr_ast = ast.parse(src, mode='exec') | |||
| transformer = _CodeTransformer(layer_count) | |||
| modify = transformer.visit(expr_ast) | |||
| modify = ast.fix_missing_locations(modify) | |||
| func = astunparse.unparse(modify) | |||
| return func | |||
| def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor): | |||
| """ | |||
| Gain bnn_with_loss, which wraps bnn network with loss function and kl loss of each bayesian layer. | |||
| Args: | |||
| layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers. | |||
| backbone (Cell): The target network to wrap. | |||
| loss_fn (Cell): The loss function used to compute loss. | |||
| dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function. | |||
| bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer. | |||
| """ | |||
| bnn_loss_func = _generate_kl_loss_func(layer_count) | |||
| path = os.path.dirname(mindspore.__file__) | |||
| bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True, | |||
| dir=path + '/nn/probability/transforms/bnn_loss') | |||
| bnn_loss_file.write(bnn_loss_func) | |||
| bnn_loss_file.seek(0) | |||
| sys.path.append(path + '/nn/probability/transforms/bnn_loss') | |||
| module_name = os.path.basename(bnn_loss_file.name)[0:-3] | |||
| bnn_loss_module = importlib.import_module(module_name, __package__) | |||
| bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor) | |||
| return bnn_with_loss, bnn_loss_file | |||
| @@ -1,66 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Original WithBNNLossCell for ast to rewrite.""" | |||
| import mindspore.nn as nn | |||
| from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational | |||
| from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational | |||
| class WithBNNLossCell(nn.Cell): | |||
| """ | |||
| Cell with loss function. | |||
| Wraps the network with loss function. This Cell accepts data, label, backbone_factor and kl_factor as inputs and | |||
| the computed loss will be returned. | |||
| """ | |||
| def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1): | |||
| super(WithBNNLossCell, self).__init__(auto_prefix=False) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| self.backbone_factor = backbone_factor | |||
| self.kl_factor = kl_factor | |||
| self.kl_loss = [] | |||
| self._add_kl_loss(self._backbone) | |||
| def construct(self, x, label): | |||
| y_pred = self._backbone(x) | |||
| backbone_loss = self._loss_fn(y_pred, label) | |||
| kl_loss = self.cal_kl_loss() | |||
| loss = backbone_loss*self.backbone_factor + kl_loss*self.kl_factor | |||
| return loss | |||
| def cal_kl_loss(self): | |||
| """Calculate kl loss.""" | |||
| loss = 0.0 | |||
| return loss | |||
| def _add_kl_loss(self, net): | |||
| """Collect kl loss of each Bayesian layer.""" | |||
| for (_, layer) in net.name_cells().items(): | |||
| if isinstance(layer, (_DenseVariational, _ConvVariational)): | |||
| self.kl_loss.append(layer.compute_kl_loss) | |||
| else: | |||
| self._add_kl_loss(layer) | |||
| @property | |||
| def backbone_network(self): | |||
| """ | |||
| Returns the backbone network. | |||
| Returns: | |||
| Cell, the backbone network. | |||
| """ | |||
| return self._backbone | |||
| @@ -17,8 +17,8 @@ import mindspore.nn as nn | |||
| from ...wrap.cell_wrapper import TrainOneStepCell | |||
| from ....nn import optim | |||
| from ....nn import layer | |||
| from .bnn_loss.generate_kl_loss import gain_bnn_with_loss | |||
| from ...probability import bnn_layers | |||
| from ..bnn_layers.bnn_cell_wrapper import WithBNNLossCell | |||
| from ..bnn_layers.conv_variational import ConvReparam | |||
| from ..bnn_layers.dense_variational import DenseReparam | |||
| @@ -77,7 +77,6 @@ class TransformToBNN: | |||
| self.loss_fn = getattr(net_with_loss, "_loss_fn") | |||
| self.dnn_factor = dnn_factor | |||
| self.bnn_factor = bnn_factor | |||
| self.bnn_loss_file = None | |||
| def transform_to_bnn_model(self, | |||
| get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias, | |||
| @@ -120,15 +119,13 @@ class TransformToBNN: | |||
| if not add_conv_args: | |||
| add_conv_args = {} | |||
| layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, | |||
| add_conv_args) | |||
| self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args) | |||
| # rename layers of BNN model to prevent duplication of names | |||
| for value, param in self.backbone.parameters_and_names(): | |||
| param.name = value | |||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||
| self.dnn_factor, self.bnn_factor) | |||
| bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) | |||
| bnn_optimizer = self._create_optimizer_with_bnn_params() | |||
| train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | |||
| return train_bnn_network | |||
| @@ -179,13 +176,11 @@ class TransformToBNN: | |||
| if not add_args: | |||
| add_args = {} | |||
| layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, | |||
| add_args) | |||
| self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, add_args) | |||
| for value, param in self.backbone.parameters_and_names(): | |||
| param.name = value | |||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||
| self.dnn_factor, self.bnn_factor) | |||
| bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) | |||
| bnn_optimizer = self._create_optimizer_with_bnn_params() | |||
| train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | |||
| @@ -228,32 +223,25 @@ class TransformToBNN: | |||
| def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args): | |||
| """Replace both dense layer and conv2d layer in DNN model to bayesian layers.""" | |||
| count = 0 | |||
| for name, cell in backbone.name_cells().items(): | |||
| if isinstance(cell, nn.Dense): | |||
| dense_args = get_dense_args(cell) | |||
| new_layer = DenseReparam(**dense_args, **add_dense_args) | |||
| setattr(backbone, name, new_layer) | |||
| count += 1 | |||
| elif isinstance(cell, nn.Conv2d): | |||
| conv_args = get_conv_args(cell) | |||
| new_layer = ConvReparam(**conv_args, **add_conv_args) | |||
| setattr(backbone, name, new_layer) | |||
| count += 1 | |||
| else: | |||
| count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, | |||
| add_conv_args) | |||
| return count | |||
| self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, | |||
| add_conv_args) | |||
| def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args): | |||
| """Convert a specific type of layers in DNN model to corresponding bayesian layers.""" | |||
| count = 0 | |||
| for name, cell in backbone.name_cells().items(): | |||
| if isinstance(cell, dnn_layer): | |||
| args = get_args(cell) | |||
| new_layer = bnn_layer(**args, **add_args) | |||
| setattr(backbone, name, new_layer) | |||
| count += 1 | |||
| else: | |||
| count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args) | |||
| return count | |||
| self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args) | |||