Merge pull request !3921 from vlne-v1/I1Q3KN-interface-changes-cause-network-training-failurs
Tag: v0.7.0-beta
@@ -340,7 +340,8 @@ class Parameter(MetaTensor):
             Default: False.

         Returns:
-            Parameter, Parameter after init data.
+            Parameter, the `Parameter` after init data. If the current `Parameter` was
+            already initialized, returns the same initialized `Parameter`.
         """
         if self.init_mode is None:
             return self
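In other words, `init_data` is now documented as idempotent. A minimal sketch of the contract (the initializer choice and shape are illustrative; names follow the MindSpore API of this release):

    from mindspore import Parameter
    from mindspore.common.initializer import initializer

    # A lazily initialized parameter: `initializer(...)` records the init mode
    # rather than materializing data immediately.
    w = Parameter(initializer('normal', [2, 3]), name='w')

    first = w.init_data()   # materializes the data
    second = w.init_data()  # already initialized: the same object comes back
    assert first is second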
@@ -536,6 +536,10 @@ class Cell:
         """
         Init all parameters' data and replace the original saved parameters in cell.

+        Notes:
+            `trainable_params()` and other similar interfaces may return different
+            parameter instances after `init_parameters_data`; do not save their results.
+
         Args:
             auto_parallel_mode (bool): If running in auto_parallel_mode.
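The note exists because a reference captured before `init_parameters_data` can go stale once the cell swaps in initialized parameters. A hedged sketch of the pitfall (`MyNet` is a hypothetical `Cell` subclass):

    net = MyNet()                     # hypothetical Cell with lazily initialized parameters

    cached = net.trainable_params()   # instances captured before data init
    net.init_parameters_data()        # may replace the parameters saved in the cell
    fresh = net.trainable_params()    # always re-query after init

    # Compare by identity: entries in `cached` may no longer be the live objects.
    stale = [p for p, q in zip(cached, fresh) if p is not q]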
@@ -425,12 +425,13 @@ class Optimizer(Cell):
             raise TypeError(f"The parameter only support 'Parameter' or 'list' type.")

         lr = []
+        ids = [id(p) for p in self.parameters]
         for p in param_list:
             validator.check_value_type("parameter", p, [Parameter], self.cls_name)
-            if p not in self.parameters:
+            if id(p) not in ids:
                 raise ValueError(f"The parameter {p.name} is not in optimizer.")
             if self.is_group_lr:
-                index = self.parameters.index(p)
+                index = ids.index(id(p))
                 lr.append(get_lr_value(self.learning_rate[index]))
             else:
                 lr.append(get_lr_value(self.learning_rate))
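The `id()` lookup sidesteps `__eq__`: `in` and `list.index` compare with `==`, which for tensor-like parameters can be elementwise (or ambiguous) rather than a clean identity test. A framework-free sketch of why the identity check is the safe pattern:

    class FakeTensor:
        """Stand-in for a tensor-like Parameter whose `==` is elementwise."""
        def __init__(self, values):
            self.values = values

        def __eq__(self, other):
            # Elementwise result (a list, not a bool), so `in`/`index` misbehave:
            # any non-empty list is truthy, making every membership test pass.
            return [a == b for a, b in zip(self.values, other.values)]

    params = [FakeTensor([1, 2]), FakeTensor([3, 4])]
    ids = [id(p) for p in params]

    target = params[1]
    assert id(target) in ids           # membership by identity
    assert ids.index(id(target)) == 1  # position by identity, __eq__ never called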
@@ -84,8 +84,14 @@ if __name__ == '__main__':
     lr = Tensor(lr)

     # optimizer
-    decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-    no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
     group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                     {'params': no_decayed_params},
                     {'order_params': net.trainable_params()}]
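Both partition rewrites in this PR serve the same end as the optimizer fix above: `param not in decayed_params` invokes `Parameter.__eq__`, exactly the comparison the interface change made unreliable, while the single loop classifies by `name` alone, never compares parameters, and walks `trainable_params()` once instead of twice.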
@@ -290,7 +290,6 @@ class MobileNetV3(nn.Cell):
                               kernel_size=1, has_bias=True, pad_mode='pad')
         self.squeeze = P.Squeeze(axis=(2, 3))

-        self.init_parameters_data()
         self._initialize_weights()

     def construct(self, x):
@@ -320,6 +319,7 @@ class MobileNetV3(nn.Cell):
         Examples:
             >>> _initialize_weights()
         """
+        self.init_parameters_data()
         for _, m in self.cells_and_names():
             if isinstance(m, (nn.Conv2d)):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -101,12 +101,12 @@ if __name__ == '__main__':
     for _, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)
         if isinstance(cell, nn.Dense):
             cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)

     # init lr
     if args_opt.net == "resnet50":
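This rewrite recurs across the model scripts below: after the interface change, `Parameter` exposes `shape` and `dtype` directly, and `default_input` accepts the initializer's result without an eager `.to_tensor()`. A hedged sketch of the new re-init idiom (`net` is any `Cell`; the initializer choice is illustrative):

    import mindspore.nn as nn
    from mindspore.common import initializer as weight_init

    def reinit_conv_weights(net):
        """Re-initialize Conv2d weights with the post-change Parameter interface."""
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                # shape/dtype are read from the Parameter itself, and no
                # .to_tensor() is needed, so the data can stay lazy.
                cell.weight.default_input = weight_init.initializer(
                    weight_init.XavierUniform(), cell.weight.shape, cell.weight.dtype)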
@@ -123,8 +123,14 @@ if __name__ == '__main__':
     lr = Tensor(lr)

     # define opt
-    decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-    no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
     group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                     {'params': no_decayed_params},
                     {'order_params': net.trainable_params()}]
@@ -91,12 +91,12 @@ if __name__ == '__main__':
     for _, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)
         if isinstance(cell, nn.Dense):
             cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)

     if not config.use_label_smooth:
         config.label_smooth_factor = 0.0
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
@@ -63,19 +63,19 @@ class Resnet(ImageClassificationNetwork):
             if isinstance(cell, nn.Conv2d):
                 cell.weight.default_input = init.initializer(
                     KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    cell.weight.shape, cell.weight.dtype)
             elif isinstance(cell, nn.BatchNorm2d):
-                cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
-                cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
+                cell.gamma.default_input = init.initializer('ones', cell.gamma.shape)
+                cell.beta.default_input = init.initializer('zeros', cell.beta.shape)

         # Zero-initialize the last BN in each residual branch,
         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
         for cell in self.cells_and_names():
             if isinstance(cell, backbones.resnet.Bottleneck):
-                cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor()
+                cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.shape)
             elif isinstance(cell, backbones.resnet.BasicBlock):
-                cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor()
+                cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.shape)
@@ -19,7 +19,6 @@ import math
 from functools import reduce
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init

 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
            if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
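The bias rewrite swaps an eagerly built `Tensor(np.random.uniform(...))` for the lazy `init.Uniform(bound)` initializer, and drops the hard-coded `np.random.seed(0)` that had forced every bias to an identical draw. The bound is the usual Kaiming-uniform bias bound; a small worked check (the shape is illustrative):

    import math

    # Conv2d weight of shape (out_channels, in_channels, kh, kw) = (64, 3, 3, 3):
    fan_in = 3 * 3 * 3            # in_channels * receptive field size = 27
    bound = 1 / math.sqrt(fan_in)
    print(round(bound, 4))        # 0.1925 -> bias ~ U(-0.1925, 0.1925)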
@@ -19,7 +19,6 @@ import math
 from functools import reduce
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init

 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
@@ -102,16 +102,16 @@ class Vgg(nn.Cell):
             if isinstance(cell, nn.Conv2d):
                 cell.weight.default_input = init.initializer(
                     KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    cell.weight.shape, cell.weight.dtype)
                 if cell.bias is not None:
                     cell.bias.default_input = init.initializer(
-                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                        'zeros', cell.bias.shape, cell.bias.dtype)
             elif isinstance(cell, nn.Dense):
                 cell.weight.default_input = init.initializer(
-                    init.Normal(0.01), cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    init.Normal(0.01), cell.weight.shape, cell.weight.dtype)
                 if cell.bias is not None:
                     cell.bias.default_input = init.initializer(
-                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                        'zeros', cell.bias.shape, cell.bias.dtype)

 cfg = {
@@ -14,11 +14,11 @@
 # ============================================================================
 """Parameter init."""
 import math
+from functools import reduce

 import numpy as np
 from mindspore.common import initializer as init
 from mindspore.common.initializer import Initializer as MeInitializer
 import mindspore.nn as nn
-from mindspore import Tensor

 np.random.seed(5)
@@ -134,7 +134,7 @@ def _calculate_fan_in_and_fan_out(arr):
     num_output_fmaps = arr.shape[0]
     receptive_field_size = 1
     if dimensions > 2:
-        receptive_field_size = arr[0][0].size
+        receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
     fan_in = num_input_fmaps * receptive_field_size
     fan_out = num_output_fmaps * receptive_field_size
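Computing the receptive field from `arr.shape[2:]` instead of `arr[0][0].size` lets the function accept anything with a `.shape`, such as a still-uninitialized `Parameter`, where indexing into numpy data would fail. A worked check with an illustrative conv shape:

    from functools import reduce

    shape = (64, 3, 3, 3)  # (out_channels, in_channels, kh, kw)

    receptive_field_size = reduce(lambda x, y: x * y, shape[2:])  # 3*3 = 9
    fan_in = shape[1] * receptive_field_size                      # 3*9 = 27
    fan_out = shape[0] * receptive_field_size                     # 64*9 = 576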
@@ -159,21 +159,23 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
@@ -58,7 +58,7 @@ def load_backbone(net, ckpt_path, args):
     darknet_backbone_prefix = 'network.backbone'
     find_param = []
     not_found_param = []
+    net.init_parameters_data()
     for name, cell in net.cells_and_names():
         if name.startswith(yolo_backbone_prefix):
             name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
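The one-line addition here presumably guards checkpoint loading against lazily initialized parameters: materializing parameter data before walking the cells ensures the backbone weights being matched and replaced actually exist, rather than being uninitialized placeholders.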