| @@ -15,29 +15,12 @@ | |||||
| """Generate WithLossCell suitable for BNN.""" | """Generate WithLossCell suitable for BNN.""" | ||||
| from .conv_variational import _ConvVariational | from .conv_variational import _ConvVariational | ||||
| from .dense_variational import _DenseVariational | from .dense_variational import _DenseVariational | ||||
| from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss | |||||
| from ...cell import Cell | |||||
| __all__ = ['WithBNNLossCell'] | __all__ = ['WithBNNLossCell'] | ||||
class ClassWrap:
    """Decorator of WithBNNLossCell.

    Calling the wrapper builds an instance of the wrapped class, invokes that
    instance right away and returns the result of the invocation instead of
    the instance itself.
    """

    def __init__(self, cls):
        self._cls = cls
        # Refreshed on every call from the instance that was just built.
        self.bnn_loss_file = None
        # Mirror the wrapped class's metadata onto the wrapper object.
        for attr in ('__doc__', '__name__', '__bases__'):
            setattr(self, attr, getattr(cls, attr))

    def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
        instance = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
        # The instance must be invoked before `bnn_loss_file` is read from it.
        result = instance()
        self.bnn_loss_file = instance.bnn_loss_file
        return result
| @ClassWrap | |||||
| class WithBNNLossCell: | |||||
| class WithBNNLossCell(Cell): | |||||
| r""" | r""" | ||||
| Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function. | Generate a suitable WithLossCell for BNN to wrap the bayesian network with loss function. | ||||
| @@ -68,6 +51,7 @@ class WithBNNLossCell: | |||||
| """ | """ | ||||
| def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): | def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): | ||||
| super(WithBNNLossCell, self).__init__(auto_prefix=False) | |||||
| if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): | if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): | ||||
| raise TypeError('The type of `dnn_factor` should be `int` or `float`') | raise TypeError('The type of `dnn_factor` should be `int` or `float`') | ||||
| if dnn_factor < 0: | if dnn_factor < 0: | ||||
| @@ -78,28 +62,36 @@ class WithBNNLossCell: | |||||
| if bnn_factor < 0: | if bnn_factor < 0: | ||||
| raise ValueError('The value of `bnn_factor` should >= 0') | raise ValueError('The value of `bnn_factor` should >= 0') | ||||
| self.backbone = backbone | |||||
| self.loss_fn = loss_fn | |||||
| self._backbone = backbone | |||||
| self._loss_fn = loss_fn | |||||
| self.dnn_factor = dnn_factor | self.dnn_factor = dnn_factor | ||||
| self.bnn_factor = bnn_factor | self.bnn_factor = bnn_factor | ||||
| self.bnn_loss_file = None | |||||
| def _generate_loss_cell(self): | |||||
| """Generate WithBNNLossCell by ast.""" | |||||
| layer_count = self._kl_loss_count(self.backbone) | |||||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||||
| self.dnn_factor, self.bnn_factor) | |||||
| return bnn_with_loss | |||||
| def _kl_loss_count(self, net): | |||||
| """ Calculate the number of Bayesian layers.""" | |||||
| count = 0 | |||||
| self.kl_loss = [] | |||||
| self._add_kl_loss(self._backbone) | |||||
| def construct(self, x, label): | |||||
| y_pred = self._backbone(x) | |||||
| backbone_loss = self._loss_fn(y_pred, label) | |||||
| kl_loss = 0 | |||||
| for i in range(len(self.kl_loss)): | |||||
| kl_loss += self.kl_loss[i]() | |||||
| loss = backbone_loss * self.dnn_factor + kl_loss * self.bnn_factor | |||||
| return loss | |||||
| def _add_kl_loss(self, net): | |||||
| """Collect kl loss of each Bayesian layer.""" | |||||
| for (_, layer) in net.name_cells().items(): | for (_, layer) in net.name_cells().items(): | ||||
| if isinstance(layer, (_DenseVariational, _ConvVariational)): | if isinstance(layer, (_DenseVariational, _ConvVariational)): | ||||
| count += 1 | |||||
| self.kl_loss.append(layer.compute_kl_loss) | |||||
| else: | else: | ||||
| count += self._kl_loss_count(layer) | |||||
| return count | |||||
| self._add_kl_loss(layer) | |||||
| @property | |||||
| def backbone_network(self): | |||||
| """ | |||||
| Returns the backbone network. | |||||
| def __call__(self): | |||||
| return self._generate_loss_cell() | |||||
| Returns: | |||||
| Cell, the backbone network. | |||||
| """ | |||||
| return self._backbone | |||||
| @@ -1,19 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| bnn loss. | |||||
| """ | |||||
| from . import generate_kl_loss | |||||
| from .generate_kl_loss import gain_bnn_with_loss | |||||
| @@ -1,89 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Gain bnn_with_loss by rewrite WithLossCell as WithBNNLossCell to suit for BNN model""" | |||||
| import ast | |||||
| import importlib | |||||
| import os | |||||
| import sys | |||||
| import tempfile | |||||
| import astunparse | |||||
| import mindspore | |||||
class _CodeTransformer(ast.NodeTransformer):
    """
    Add kl_loss computation by analyzing the python code structure with the help of the AST module.

    Args:
        layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
    """

    def __init__(self, layer_count):
        super().__init__()
        self.layer_count = layer_count

    def visit_FunctionDef(self, node):
        """Visit a function and, for `cal_kl_loss`, add one kl_loss accumulation per Bayesian layer.

        Args:
            node (ast.FunctionDef): The function definition being visited.

        Returns:
            ast.FunctionDef, the same node, modified in place when its name is `cal_kl_loss`.
        """
        self.generic_visit(node)
        if node.name == 'cal_kl_loss':
            for i in range(self.layer_count):
                # Parse the statement rather than hand-assembling nodes: an
                # `ast.Name` whose id is 'self.kl_loss[i]' is not a valid
                # identifier node and only survives a round trip through an
                # unparser. Parsing yields a proper attribute/subscript/call
                # tree that unparses to the same text and also compiles.
                statement = ast.parse('loss = loss + self.kl_loss[{}]()'.format(i)).body[0]
                # Insert before the last statement so the function's final
                # statement (the `return` in the template) stays last.
                node.body.insert(-1, statement)
        return node
def _generate_kl_loss_func(layer_count):
    """Rewrite WithLossCell as WithBNNLossCell to suit for BNN model.

    Reads the template source shipped inside the installed mindspore package,
    lets `_CodeTransformer` add `layer_count` kl_loss terms to it and returns
    the rewritten source code as text.

    Args:
        layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.

    Returns:
        str, the source code of the rewritten module.
    """
    template_path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py'
    with open(template_path, 'r') as template:
        source = template.read()
    # Source that starts indented is wrapped in a dummy block so it parses.
    if source.startswith((' ', '\t')):
        source = 'if 1:\n' + source
    tree = ast.parse(source, mode='exec')
    rewritten = _CodeTransformer(layer_count).visit(tree)
    return astunparse.unparse(ast.fix_missing_locations(rewritten))
def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
    """
    Gain bnn_with_loss, which wraps bnn network with loss function and kl loss of each bayesian layer.

    The rewritten source is written to a temporary `.py` file inside the
    `bnn_loss` package directory and imported from there as a module.

    Args:
        layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
        bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.

    Returns:
        tuple, `(bnn_with_loss, bnn_loss_file)`: the generated loss cell and the open temporary file
        object that holds its source. The file is created with `delete=True`, so the caller keeps the
        handle for as long as the source file should exist.
    """
    bnn_loss_func = _generate_kl_loss_func(layer_count)
    bnn_loss_dir = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss'
    bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True, dir=bnn_loss_dir)
    bnn_loss_file.write(bnn_loss_func)
    # NOTE(review): relies on seek() pushing the buffered text to disk so the
    # importer below sees the complete source -- confirm.
    bnn_loss_file.seek(0)
    # Register the directory only once; appending on every call would let
    # sys.path grow without bound.
    if bnn_loss_dir not in sys.path:
        sys.path.append(bnn_loss_dir)
    # The module file was created after the interpreter started, so the
    # import system's cached directory listings may not include it yet.
    importlib.invalidate_caches()
    module_name = os.path.basename(bnn_loss_file.name)[0:-3]
    bnn_loss_module = importlib.import_module(module_name, __package__)
    bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor)
    return bnn_with_loss, bnn_loss_file
| @@ -1,66 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Original WithBNNLossCell for ast to rewrite.""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational | |||||
| from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational | |||||
class WithBNNLossCell(nn.Cell):
    """
    Cell with loss function.

    Wraps the network with loss function. This Cell accepts data, label, backbone_factor and kl_factor as inputs and
    the computed loss will be returned.
    """

    def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self.backbone_factor = backbone_factor
        self.kl_factor = kl_factor
        # Filled by the recursive walk over the backbone below.
        self.kl_loss = []
        self._add_kl_loss(self._backbone)

    def construct(self, x, label):
        prediction = self._backbone(x)
        data_loss = self._loss_fn(prediction, label)
        kl_total = self.cal_kl_loss()
        return data_loss*self.backbone_factor + kl_total*self.kl_factor

    def cal_kl_loss(self):
        """Calculate kl loss."""
        # This body is a template: the AST rewriter inserts statements in
        # front of the final `return`, so the local name `loss` and the
        # trailing return must stay exactly as they are.
        loss = 0.0
        return loss

    def _add_kl_loss(self, net):
        """Collect kl loss of each Bayesian layer."""
        for layer in net.name_cells().values():
            if not isinstance(layer, (_DenseVariational, _ConvVariational)):
                # Not a Bayesian layer itself: look inside it.
                self._add_kl_loss(layer)
                continue
            self.kl_loss.append(layer.compute_kl_loss)

    @property
    def backbone_network(self):
        """
        Returns the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone
| @@ -17,8 +17,8 @@ import mindspore.nn as nn | |||||
| from ...wrap.cell_wrapper import TrainOneStepCell | from ...wrap.cell_wrapper import TrainOneStepCell | ||||
| from ....nn import optim | from ....nn import optim | ||||
| from ....nn import layer | from ....nn import layer | ||||
| from .bnn_loss.generate_kl_loss import gain_bnn_with_loss | |||||
| from ...probability import bnn_layers | from ...probability import bnn_layers | ||||
| from ..bnn_layers.bnn_cell_wrapper import WithBNNLossCell | |||||
| from ..bnn_layers.conv_variational import ConvReparam | from ..bnn_layers.conv_variational import ConvReparam | ||||
| from ..bnn_layers.dense_variational import DenseReparam | from ..bnn_layers.dense_variational import DenseReparam | ||||
| @@ -77,7 +77,6 @@ class TransformToBNN: | |||||
| self.loss_fn = getattr(net_with_loss, "_loss_fn") | self.loss_fn = getattr(net_with_loss, "_loss_fn") | ||||
| self.dnn_factor = dnn_factor | self.dnn_factor = dnn_factor | ||||
| self.bnn_factor = bnn_factor | self.bnn_factor = bnn_factor | ||||
| self.bnn_loss_file = None | |||||
| def transform_to_bnn_model(self, | def transform_to_bnn_model(self, | ||||
| get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias, | get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias, | ||||
| @@ -120,15 +119,13 @@ class TransformToBNN: | |||||
| if not add_conv_args: | if not add_conv_args: | ||||
| add_conv_args = {} | add_conv_args = {} | ||||
| layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, | |||||
| add_conv_args) | |||||
| self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args) | |||||
| # rename layers of BNN model to prevent duplication of names | # rename layers of BNN model to prevent duplication of names | ||||
| for value, param in self.backbone.parameters_and_names(): | for value, param in self.backbone.parameters_and_names(): | ||||
| param.name = value | param.name = value | ||||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||||
| self.dnn_factor, self.bnn_factor) | |||||
| bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) | |||||
| bnn_optimizer = self._create_optimizer_with_bnn_params() | bnn_optimizer = self._create_optimizer_with_bnn_params() | ||||
| train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | ||||
| return train_bnn_network | return train_bnn_network | ||||
| @@ -179,13 +176,11 @@ class TransformToBNN: | |||||
| if not add_args: | if not add_args: | ||||
| add_args = {} | add_args = {} | ||||
| layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, | |||||
| add_args) | |||||
| self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args, add_args) | |||||
| for value, param in self.backbone.parameters_and_names(): | for value, param in self.backbone.parameters_and_names(): | ||||
| param.name = value | param.name = value | ||||
| bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn, | |||||
| self.dnn_factor, self.bnn_factor) | |||||
| bnn_with_loss = WithBNNLossCell(self.backbone, self.loss_fn, self.dnn_factor, self.bnn_factor) | |||||
| bnn_optimizer = self._create_optimizer_with_bnn_params() | bnn_optimizer = self._create_optimizer_with_bnn_params() | ||||
| train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer) | ||||
| @@ -228,32 +223,25 @@ class TransformToBNN: | |||||
| def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args): | def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args): | ||||
| """Replace both dense layer and conv2d layer in DNN model to bayesian layers.""" | """Replace both dense layer and conv2d layer in DNN model to bayesian layers.""" | ||||
| count = 0 | |||||
| for name, cell in backbone.name_cells().items(): | for name, cell in backbone.name_cells().items(): | ||||
| if isinstance(cell, nn.Dense): | if isinstance(cell, nn.Dense): | ||||
| dense_args = get_dense_args(cell) | dense_args = get_dense_args(cell) | ||||
| new_layer = DenseReparam(**dense_args, **add_dense_args) | new_layer = DenseReparam(**dense_args, **add_dense_args) | ||||
| setattr(backbone, name, new_layer) | setattr(backbone, name, new_layer) | ||||
| count += 1 | |||||
| elif isinstance(cell, nn.Conv2d): | elif isinstance(cell, nn.Conv2d): | ||||
| conv_args = get_conv_args(cell) | conv_args = get_conv_args(cell) | ||||
| new_layer = ConvReparam(**conv_args, **add_conv_args) | new_layer = ConvReparam(**conv_args, **add_conv_args) | ||||
| setattr(backbone, name, new_layer) | setattr(backbone, name, new_layer) | ||||
| count += 1 | |||||
| else: | else: | ||||
| count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, | |||||
| add_conv_args) | |||||
| return count | |||||
| self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args, | |||||
| add_conv_args) | |||||
| def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args): | def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args): | ||||
| """Convert a specific type of layers in DNN model to corresponding bayesian layers.""" | """Convert a specific type of layers in DNN model to corresponding bayesian layers.""" | ||||
| count = 0 | |||||
| for name, cell in backbone.name_cells().items(): | for name, cell in backbone.name_cells().items(): | ||||
| if isinstance(cell, dnn_layer): | if isinstance(cell, dnn_layer): | ||||
| args = get_args(cell) | args = get_args(cell) | ||||
| new_layer = bnn_layer(**args, **add_args) | new_layer = bnn_layer(**args, **add_args) | ||||
| setattr(backbone, name, new_layer) | setattr(backbone, name, new_layer) | ||||
| count += 1 | |||||
| else: | else: | ||||
| count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args) | |||||
| return count | |||||
| self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args) | |||||