Merge pull request !3921 from vlne-v1/I1Q3KN-interface-changes-cause-network-training-failurs
@@ -340,7 +340,8 @@ class Parameter(MetaTensor):
                 Default: False.

         Returns:
-            Parameter, Parameter after init data.
+            Parameter, the `Parameter` after init data. If the current `Parameter` was already
+            initialized, returns the same initialized `Parameter`.
         """
         if self.init_mode is None:
             return self
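
Note: the reworded docstring pins down the new contract of `init_data()`: once a `Parameter` has been initialized, further calls hand back the same instance. A minimal sketch of the documented behavior (assuming `net` is any constructed Cell):

    p = net.trainable_params()[0]
    q = p.init_data()          # may return a different Parameter instance the first time
    assert q.init_data() is q  # already initialized, so the same instance comes back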
@@ -536,6 +536,10 @@ class Cell:
         """
         Init all parameters' data and replace the original saved parameters in cell.

+        Notes:
+            trainable_params() and other similar interfaces may return different parameter
+            instances after `init_parameters_data`; do not cache the results.
+
         Args:
             auto_parallel_mode (bool): If running in auto_parallel_mode.
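
Note: this warning is the crux of the training-failure fix. A sketch of the pitfall it describes (assuming `net` is a Cell whose parameters use deferred initialization):

    cached = net.trainable_params()   # references taken before init
    net.init_parameters_data()        # replaces the parameters stored in the cell
    fresh = net.trainable_params()
    # `cached` and `fresh` may now hold different Parameter instances;
    # reusing `cached` afterwards is exactly the bug this note warns about.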
@@ -425,12 +425,13 @@ class Optimizer(Cell):
             raise TypeError(f"The parameter only support 'Parameter' or 'list' type.")

         lr = []
+        ids = [id(p) for p in self.parameters]
         for p in param_list:
             validator.check_value_type("parameter", p, [Parameter], self.cls_name)
-            if p not in self.parameters:
+            if id(p) not in ids:
                 raise ValueError(f"The parameter {p.name} is not in optimizer.")
             if self.is_group_lr:
-                index = self.parameters.index(p)
+                index = ids.index(id(p))
                 lr.append(get_lr_value(self.learning_rate[index]))
             else:
                 lr.append(get_lr_value(self.learning_rate))
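
Note: the switch from `p in self.parameters` / `list.index(p)` to `id()` lookups avoids `Parameter.__eq__`, which after the interface change behaves tensor-like instead of returning a plain bool. A plain-Python illustration of why that breaks membership tests (`P` is a hypothetical stand-in, no MindSpore needed):

    class P:
        def __eq__(self, other):
            return [True, False]   # tensor-like elementwise result, always truthy

    params = [P(), P()]
    other = P()
    print(other in params)                       # True  -- wrong answer via __eq__
    print(id(other) in [id(p) for p in params])  # False -- correct identity test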
@@ -84,8 +84,14 @@ if __name__ == '__main__':
     lr = Tensor(lr)

     # optimizer
-    decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-    no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
     group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                     {'params': no_decayed_params},
                     {'order_params': net.trainable_params()}]
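
Note: the removed `param not in decayed_params` comprehension suffers from the same equality problem as the optimizer code above, and it rescans `decayed_params` for every parameter (O(n^2)). The single-pass loop avoids both. If the two-pass style were preferred, an id-based variant would also work (sketch):

    keys = ('beta', 'gamma', 'bias')
    decayed_params = [p for p in net.trainable_params()
                      if all(k not in p.name for k in keys)]
    decayed_ids = {id(p) for p in decayed_params}
    no_decayed_params = [p for p in net.trainable_params()
                         if id(p) not in decayed_ids]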
@@ -290,7 +290,6 @@ class MobileNetV3(nn.Cell):
                                   kernel_size=1, has_bias=True, pad_mode='pad')
         self.squeeze = P.Squeeze(axis=(2, 3))

-        self.init_parameters_data()
         self._initialize_weights()

     def construct(self, x):
@@ -320,6 +319,7 @@ class MobileNetV3(nn.Cell):
         Examples:
             >>> _initialize_weights()
         """
+        self.init_parameters_data()
         for _, m in self.cells_and_names():
             if isinstance(m, (nn.Conv2d)):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -101,12 +101,12 @@ if __name__ == '__main__':
     for _, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)
         if isinstance(cell, nn.Dense):
             cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)

     # init lr
     if args_opt.net == "resnet50":
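
Note: this hunk shows the pattern repeated across the rest of the PR: shape and dtype are read from the `Parameter` itself rather than from its materialized `default_input`, and the trailing `.to_tensor()` is dropped because the initializer result is assigned to `default_input` directly. Schematically (with `param` a hypothetical Parameter):

    # before: only works once the data has been materialized
    param.default_input = initializer(XavierUniform(),
                                      param.default_input.shape,
                                      param.default_input.dtype).to_tensor()
    # after: reads metadata from the Parameter, so deferred-init parameters are fine
    param.default_input = initializer(XavierUniform(), param.shape, param.dtype)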
@@ -123,8 +123,14 @@ if __name__ == '__main__':
     lr = Tensor(lr)

     # define opt
-    decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-    no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
     group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                     {'params': no_decayed_params},
                     {'order_params': net.trainable_params()}]
@@ -91,12 +91,12 @@ if __name__ == '__main__':
     for _, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)
         if isinstance(cell, nn.Dense):
             cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                                cell.weight.default_input.shape,
-                                                                cell.weight.default_input.dtype).to_tensor()
+                                                                cell.weight.shape,
+                                                                cell.weight.dtype)

     if not config.use_label_smooth:
         config.label_smooth_factor = 0.0
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
@@ -63,19 +63,19 @@ class Resnet(ImageClassificationNetwork):
             if isinstance(cell, nn.Conv2d):
                 cell.weight.default_input = init.initializer(
                     KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    cell.weight.shape, cell.weight.dtype)
             elif isinstance(cell, nn.BatchNorm2d):
-                cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
-                cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
+                cell.gamma.default_input = init.initializer('ones', cell.gamma.shape)
+                cell.beta.default_input = init.initializer('zeros', cell.beta.shape)

         # Zero-initialize the last BN in each residual branch,
         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
         for cell in self.cells_and_names():
             if isinstance(cell, backbones.resnet.Bottleneck):
-                cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor()
+                cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.shape)
             elif isinstance(cell, backbones.resnet.BasicBlock):
-                cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor()
+                cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.shape)
@@ -19,7 +19,6 @@ import math
 from functools import reduce

 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init


 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
                 np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
                 np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
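
Note: for the biases, the numpy round-trip (`np.random.uniform` plus a `Tensor` constructor) is replaced by `init.Uniform(bound)`, which fills the tensor with samples from U(-bound, bound) and so preserves the old distribution without calling `.asnumpy()` on possibly uninitialized data. The bound itself is the usual Kaiming-uniform bias bound:

    import math
    fan_in = 3 * 7 * 7             # e.g. a conv with 3 in-channels and 7x7 kernels
    bound = 1 / math.sqrt(fan_in)  # ~0.0825; bias ~ U(-bound, bound)

One thing a reviewer might still flag: the retained `np.random.seed(0)` line only pins numpy's RNG, so whether bias init stays bit-for-bit reproducible now depends on the initializer's internals.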
@@ -19,7 +19,6 @@ import math
 from functools import reduce

 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init


 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
                 np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_in_and_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
                 np.random.seed(0)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
@@ -102,16 +102,16 @@ class Vgg(nn.Cell):
             if isinstance(cell, nn.Conv2d):
                 cell.weight.default_input = init.initializer(
                     KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    cell.weight.shape, cell.weight.dtype)
                 if cell.bias is not None:
                     cell.bias.default_input = init.initializer(
-                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                        'zeros', cell.bias.shape, cell.bias.dtype)
             elif isinstance(cell, nn.Dense):
                 cell.weight.default_input = init.initializer(
-                    init.Normal(0.01), cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                    init.Normal(0.01), cell.weight.shape, cell.weight.dtype)
                 if cell.bias is not None:
                     cell.bias.default_input = init.initializer(
-                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                        'zeros', cell.bias.shape, cell.bias.dtype)


 cfg = {
@@ -14,11 +14,11 @@
 # ============================================================================
 """Parameter init."""
 import math
+from functools import reduce
 import numpy as np
 from mindspore.common import initializer as init
 from mindspore.common.initializer import Initializer as MeInitializer
 import mindspore.nn as nn
-from mindspore import Tensor

 np.random.seed(5)
@@ -134,7 +134,7 @@ def _calculate_fan_in_and_fan_out(arr):
     num_output_fmaps = arr.shape[0]
     receptive_field_size = 1
     if dimensions > 2:
-        receptive_field_size = arr[0][0].size
+        receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
     fan_in = num_input_fmaps * receptive_field_size
     fan_out = num_output_fmaps * receptive_field_size
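
Note: `arr[0][0].size` needed a materialized numpy array, while the `reduce` over `arr.shape[2:]` needs only the shape tuple, so the helper now works on a `Parameter` directly. Worked example for a 7x7 conv with 3 input and 64 output channels:

    from functools import reduce

    shape = (64, 3, 7, 7)                                         # (out, in, kH, kW)
    receptive_field_size = reduce(lambda x, y: x * y, shape[2:])  # 7 * 7 = 49
    fan_in = shape[1] * receptive_field_size                      # 3 * 49 = 147
    fan_out = shape[0] * receptive_field_size                     # 64 * 49 = 3136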
@@ -159,21 +159,23 @@ def default_recurisive_init(custom_cell):
     for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                         cell.weight.default_input.shape,
-                                                         cell.weight.default_input.dtype).to_tensor()
+                                                         cell.weight.shape,
+                                                         cell.weight.dtype)
             if cell.bias is not None:
-                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                 cell.bias.default_input.dtype)
+                cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype)
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
@@ -58,7 +58,7 @@ def load_backbone(net, ckpt_path, args):
     darknet_backbone_prefix = 'network.backbone'
     find_param = []
     not_found_param = []
+    net.init_parameters_data()

     for name, cell in net.cells_and_names():
         if name.startswith(yolo_backbone_prefix):
             name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
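
Note: the added `net.init_parameters_data()` materializes any deferred parameters before the backbone-matching loop reads them, mirroring the MobileNetV3 change above. Sketch of the surrounding intent (simplified, hypothetical flow):

    net.init_parameters_data()                   # make sure data exists before matching
    own_params = {p.name: p for p in net.get_parameters()}
    # ... entries found in the checkpoint go to find_param,
    #     the rest to not_found_param ...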