Merge pull request !6852 from 吴书全/yolov3_dark_0924tags/v1.1.0
| @@ -16,3 +16,4 @@ pandas >= 1.0.2 # for ut test | |||||
| bs4 | bs4 | ||||
| astunparse | astunparse | ||||
| packaging >= 20.0 | packaging >= 20.0 | ||||
| pycocotools >= 2.0.0 # for st test | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Config parameters for Darknet based yolov3_darknet53 models.""" | |||||
| class ConfigYOLOV3DarkNet53: | |||||
| """ | |||||
| Config parameters for the yolov3_darknet53. | |||||
| Examples: | |||||
| ConfigYOLOV3DarkNet53() | |||||
| """ | |||||
| # train_param | |||||
| # data augmentation related | |||||
| hue = 0.1 | |||||
| saturation = 1.5 | |||||
| value = 1.5 | |||||
| jitter = 0.3 | |||||
| resize_rate = 1 | |||||
| multi_scale = [[320, 320], | |||||
| [352, 352], | |||||
| [384, 384], | |||||
| [416, 416], | |||||
| [448, 448], | |||||
| [480, 480], | |||||
| [512, 512], | |||||
| [544, 544], | |||||
| [576, 576], | |||||
| [608, 608] | |||||
| ] | |||||
| num_classes = 80 | |||||
| max_box = 50 | |||||
| backbone_input_shape = [32, 64, 128, 256, 512] | |||||
| backbone_shape = [64, 128, 256, 512, 1024] | |||||
| backbone_layers = [1, 2, 8, 8, 4] | |||||
| # confidence under ignore_threshold means no object when training | |||||
| ignore_threshold = 0.7 | |||||
| # h->w | |||||
| anchor_scales = [(10, 13), | |||||
| (16, 30), | |||||
| (33, 23), | |||||
| (30, 61), | |||||
| (62, 45), | |||||
| (59, 119), | |||||
| (116, 90), | |||||
| (156, 198), | |||||
| (373, 326)] | |||||
| out_channel = 255 | |||||
| # test_param | |||||
| test_img_shape = [416, 416] | |||||
| label_smooth = 0 | |||||
| label_smooth_factor = 0.1 | |||||
| @@ -0,0 +1,211 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """DarkNet model.""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| def conv_block(in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| stride, | |||||
| dilation=1): | |||||
| """Get a conv2d batchnorm and relu layer""" | |||||
| pad_mode = 'same' | |||||
| padding = 0 | |||||
| return nn.SequentialCell( | |||||
| [nn.Conv2d(in_channels, | |||||
| out_channels, | |||||
| kernel_size=kernel_size, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| pad_mode=pad_mode), | |||||
| nn.BatchNorm2d(out_channels, momentum=0.1), | |||||
| nn.ReLU()] | |||||
| ) | |||||
| class ResidualBlock(nn.Cell): | |||||
| """ | |||||
| DarkNet V1 residual block definition. | |||||
| Args: | |||||
| in_channels: Integer. Input channel. | |||||
| out_channels: Integer. Output channel. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| ResidualBlock(3, 208) | |||||
| """ | |||||
| expansion = 4 | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels): | |||||
| super(ResidualBlock, self).__init__() | |||||
| out_chls = out_channels//2 | |||||
| self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) | |||||
| self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, x): | |||||
| identity = x | |||||
| out = self.conv1(x) | |||||
| out = self.conv2(out) | |||||
| out = self.add(out, identity) | |||||
| return out | |||||
| class DarkNet(nn.Cell): | |||||
| """ | |||||
| DarkNet V1 network. | |||||
| Args: | |||||
| block: Cell. Block for network. | |||||
| layer_nums: List. Numbers of different layers. | |||||
| in_channels: Integer. Input channel. | |||||
| out_channels: Integer. Output channel. | |||||
| detect: Bool. Whether detect or not. Default:False. | |||||
| Returns: | |||||
| Tuple, tuple of output tensor,(f1,f2,f3,f4,f5). | |||||
| Examples: | |||||
| DarkNet(ResidualBlock, | |||||
| [1, 2, 8, 8, 4], | |||||
| [32, 64, 128, 256, 512], | |||||
| [64, 128, 256, 512, 1024], | |||||
| 100) | |||||
| """ | |||||
| def __init__(self, | |||||
| block, | |||||
| layer_nums, | |||||
| in_channels, | |||||
| out_channels, | |||||
| detect=False): | |||||
| super(DarkNet, self).__init__() | |||||
| self.outchannel = out_channels[-1] | |||||
| self.detect = detect | |||||
| if not len(layer_nums) == len(in_channels) == len(out_channels) == 5: | |||||
| raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!") | |||||
| self.conv0 = conv_block(3, | |||||
| in_channels[0], | |||||
| kernel_size=3, | |||||
| stride=1) | |||||
| self.conv1 = conv_block(in_channels[0], | |||||
| out_channels[0], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv2 = conv_block(in_channels[1], | |||||
| out_channels[1], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv3 = conv_block(in_channels[2], | |||||
| out_channels[2], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv4 = conv_block(in_channels[3], | |||||
| out_channels[3], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv5 = conv_block(in_channels[4], | |||||
| out_channels[4], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer1 = self._make_layer(block, | |||||
| layer_nums[0], | |||||
| in_channel=out_channels[0], | |||||
| out_channel=out_channels[0]) | |||||
| self.layer2 = self._make_layer(block, | |||||
| layer_nums[1], | |||||
| in_channel=out_channels[1], | |||||
| out_channel=out_channels[1]) | |||||
| self.layer3 = self._make_layer(block, | |||||
| layer_nums[2], | |||||
| in_channel=out_channels[2], | |||||
| out_channel=out_channels[2]) | |||||
| self.layer4 = self._make_layer(block, | |||||
| layer_nums[3], | |||||
| in_channel=out_channels[3], | |||||
| out_channel=out_channels[3]) | |||||
| self.layer5 = self._make_layer(block, | |||||
| layer_nums[4], | |||||
| in_channel=out_channels[4], | |||||
| out_channel=out_channels[4]) | |||||
| def _make_layer(self, block, layer_num, in_channel, out_channel): | |||||
| """ | |||||
| Make Layer for DarkNet. | |||||
| :param block: Cell. DarkNet block. | |||||
| :param layer_num: Integer. Layer number. | |||||
| :param in_channel: Integer. Input channel. | |||||
| :param out_channel: Integer. Output channel. | |||||
| Examples: | |||||
| _make_layer(ConvBlock, 1, 128, 256) | |||||
| """ | |||||
| layers = [] | |||||
| darkblk = block(in_channel, out_channel) | |||||
| layers.append(darkblk) | |||||
| for _ in range(1, layer_num): | |||||
| darkblk = block(out_channel, out_channel) | |||||
| layers.append(darkblk) | |||||
| return nn.SequentialCell(layers) | |||||
| def construct(self, x): | |||||
| c1 = self.conv0(x) | |||||
| c2 = self.conv1(c1) | |||||
| c3 = self.layer1(c2) | |||||
| c4 = self.conv2(c3) | |||||
| c5 = self.layer2(c4) | |||||
| c6 = self.conv3(c5) | |||||
| c7 = self.layer3(c6) | |||||
| c8 = self.conv4(c7) | |||||
| c9 = self.layer4(c8) | |||||
| c10 = self.conv5(c9) | |||||
| c11 = self.layer5(c10) | |||||
| if self.detect: | |||||
| return c7, c9, c11 | |||||
| return c11 | |||||
| def get_out_channels(self): | |||||
| return self.outchannel | |||||
| def darknet53(): | |||||
| """ | |||||
| Get DarkNet53 neural network. | |||||
| Returns: | |||||
| Cell, cell instance of DarkNet53 neural network. | |||||
| Examples: | |||||
| darknet53() | |||||
| """ | |||||
| return DarkNet(ResidualBlock, [1, 2, 8, 8, 4], | |||||
| [32, 64, 128, 256, 512], | |||||
| [64, 128, 256, 512, 1024]) | |||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Yolo dataset distributed sampler.""" | |||||
| from __future__ import division | |||||
| import math | |||||
| import numpy as np | |||||
| class DistributedSampler: | |||||
| """Distributed sampler.""" | |||||
| def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True): | |||||
| if num_replicas is None: | |||||
| print("***********Setting world_size to 1 since it is not passed in ******************") | |||||
| num_replicas = 1 | |||||
| if rank is None: | |||||
| print("***********Setting rank to 0 since it is not passed in ******************") | |||||
| rank = 0 | |||||
| self.dataset_size = dataset_size | |||||
| self.num_replicas = num_replicas | |||||
| self.rank = rank | |||||
| self.epoch = 0 | |||||
| self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas)) | |||||
| self.total_size = self.num_samples * self.num_replicas | |||||
| self.shuffle = shuffle | |||||
| def __iter__(self): | |||||
| # deterministically shuffle based on epoch | |||||
| if self.shuffle: | |||||
| indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size) | |||||
| # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset | |||||
| indices = indices.tolist() | |||||
| self.epoch += 1 | |||||
| # change to list type | |||||
| else: | |||||
| indices = list(range(self.dataset_size)) | |||||
| # add extra samples to make it evenly divisible | |||||
| indices += indices[:(self.total_size - len(indices))] | |||||
| assert len(indices) == self.total_size | |||||
| # subsample | |||||
| indices = indices[self.rank:self.total_size:self.num_replicas] | |||||
| assert len(indices) == self.num_samples | |||||
| return iter(indices) | |||||
| def __len__(self): | |||||
| return self.num_samples | |||||
| @@ -0,0 +1,204 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Parameter init.""" | |||||
| import math | |||||
| from functools import reduce | |||||
| import numpy as np | |||||
| from mindspore.common import initializer as init | |||||
| from mindspore.common.initializer import Initializer as MeInitializer | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| import mindspore.nn as nn | |||||
| from .util import load_backbone | |||||
| def calculate_gain(nonlinearity, param=None): | |||||
| r"""Return the recommended gain value for the given nonlinearity function. | |||||
| The values are as follows: | |||||
| ================= ==================================================== | |||||
| nonlinearity gain | |||||
| ================= ==================================================== | |||||
| Linear / Identity :math:`1` | |||||
| Conv{1,2,3}D :math:`1` | |||||
| Sigmoid :math:`1` | |||||
| Tanh :math:`\frac{5}{3}` | |||||
| ReLU :math:`\sqrt{2}` | |||||
| Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` | |||||
| ================= ==================================================== | |||||
| Args: | |||||
| nonlinearity: the non-linear function (`nn.functional` name) | |||||
| param: optional parameter for the non-linear function | |||||
| Examples: | |||||
| >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 | |||||
| """ | |||||
| linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] | |||||
| if nonlinearity in linear_fns or nonlinearity == 'sigmoid': | |||||
| return 1 | |||||
| if nonlinearity == 'tanh': | |||||
| return 5.0 / 3 | |||||
| if nonlinearity == 'relu': | |||||
| return math.sqrt(2.0) | |||||
| if nonlinearity == 'leaky_relu': | |||||
| if param is None: | |||||
| negative_slope = 0.01 | |||||
| elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): | |||||
| # True/False are instances of int, hence check above | |||||
| negative_slope = param | |||||
| else: | |||||
| raise ValueError("negative_slope {} not a valid number".format(param)) | |||||
| return math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||||
| raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||||
| def _assignment(arr, num): | |||||
| """Assign the value of 'num' and 'arr'.""" | |||||
| if arr.shape == (): | |||||
| arr = arr.reshape((1)) | |||||
| arr[:] = num | |||||
| arr = arr.reshape(()) | |||||
| else: | |||||
| if isinstance(num, np.ndarray): | |||||
| arr[:] = num[:] | |||||
| else: | |||||
| arr[:] = num | |||||
| return arr | |||||
| def _calculate_correct_fan(array, mode): | |||||
| mode = mode.lower() | |||||
| valid_modes = ['fan_in', 'fan_out'] | |||||
| if mode not in valid_modes: | |||||
| raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) | |||||
| fan_in, fan_out = _calculate_fan_in_and_fan_out(array) | |||||
| return fan_in if mode == 'fan_in' else fan_out | |||||
| def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): | |||||
| r"""Fills the input `Tensor` with values according to the method | |||||
| described in `Delving deep into rectifiers: Surpassing human-level | |||||
| performance on ImageNet classification` - He, K. et al. (2015), using a | |||||
| uniform distribution. The resulting tensor will have values sampled from | |||||
| :math:`\mathcal{U}(-\text{bound}, \text{bound})` where | |||||
| .. math:: | |||||
| \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} | |||||
| Also known as He initialization. | |||||
| Args: | |||||
| tensor: an n-dimensional `Tensor` | |||||
| a: the negative slope of the rectifier used after this layer (only | |||||
| used with ``'leaky_relu'``) | |||||
| mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` | |||||
| preserves the magnitude of the variance of the weights in the | |||||
| forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the | |||||
| backwards pass. | |||||
| nonlinearity: the non-linear function (`nn.functional` name), | |||||
| recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). | |||||
| Examples: | |||||
| >>> w = np.empty(3, 5) | |||||
| >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') | |||||
| """ | |||||
| fan = _calculate_correct_fan(arr, mode) | |||||
| gain = calculate_gain(nonlinearity, a) | |||||
| std = gain / math.sqrt(fan) | |||||
| bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation | |||||
| return np.random.uniform(-bound, bound, arr.shape) | |||||
| def _calculate_fan_in_and_fan_out(arr): | |||||
| """Calculate fan in and fan out.""" | |||||
| dimensions = len(arr.shape) | |||||
| if dimensions < 2: | |||||
| raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") | |||||
| num_input_fmaps = arr.shape[1] | |||||
| num_output_fmaps = arr.shape[0] | |||||
| receptive_field_size = 1 | |||||
| if dimensions > 2: | |||||
| receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:]) | |||||
| fan_in = num_input_fmaps * receptive_field_size | |||||
| fan_out = num_output_fmaps * receptive_field_size | |||||
| return fan_in, fan_out | |||||
| class KaimingUniform(MeInitializer): | |||||
| """Kaiming uniform initializer.""" | |||||
| def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): | |||||
| super(KaimingUniform, self).__init__() | |||||
| self.a = a | |||||
| self.mode = mode | |||||
| self.nonlinearity = nonlinearity | |||||
| def _initialize(self, arr): | |||||
| tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) | |||||
| _assignment(arr, tmp) | |||||
| def default_recurisive_init(custom_cell): | |||||
| """Initialize parameter.""" | |||||
| for _, cell in custom_cell.cells_and_names(): | |||||
| if isinstance(cell, nn.Conv2d): | |||||
| cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), | |||||
| cell.weight.shape, | |||||
| cell.weight.dtype)) | |||||
| if cell.bias is not None: | |||||
| fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) | |||||
| bound = 1 / math.sqrt(fan_in) | |||||
| cell.bias.set_data(init.initializer(init.Uniform(bound), | |||||
| cell.bias.shape, | |||||
| cell.bias.dtype)) | |||||
| elif isinstance(cell, nn.Dense): | |||||
| cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), | |||||
| cell.weight.shape, | |||||
| cell.weight.dtype)) | |||||
| if cell.bias is not None: | |||||
| fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight) | |||||
| bound = 1 / math.sqrt(fan_in) | |||||
| cell.bias.set_data(init.initializer(init.Uniform(bound), | |||||
| cell.bias.shape, | |||||
| cell.bias.dtype)) | |||||
| elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||||
| pass | |||||
| def load_yolov3_params(args, network): | |||||
| """Load yolov3 darknet parameter from checkpoint.""" | |||||
| if args.pretrained_backbone: | |||||
| network = load_backbone(network, args.pretrained_backbone, args) | |||||
| args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone)) | |||||
| else: | |||||
| args.logger.info('Not load pre-trained backbone, please be careful') | |||||
| if args.resume_yolov3: | |||||
| param_dict = load_checkpoint(args.resume_yolov3) | |||||
| param_dict_new = {} | |||||
| for key, values in param_dict.items(): | |||||
| if key.startswith('moments.'): | |||||
| continue | |||||
| elif key.startswith('yolo_network.'): | |||||
| param_dict_new[key[13:]] = values | |||||
| args.logger.info('in resume {}'.format(key)) | |||||
| else: | |||||
| param_dict_new[key] = values | |||||
| args.logger.info('in resume {}'.format(key)) | |||||
| args.logger.info('resume finished') | |||||
| load_param_into_net(network, param_dict_new) | |||||
| args.logger.info('load_model {} success'.format(args.resume_yolov3)) | |||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Custom Logger.""" | |||||
| import os | |||||
| import sys | |||||
| import logging | |||||
| from datetime import datetime | |||||
| class LOGGER(logging.Logger): | |||||
| """ | |||||
| Logger. | |||||
| Args: | |||||
| logger_name: String. Logger name. | |||||
| rank: Integer. Rank id. | |||||
| """ | |||||
| def __init__(self, logger_name, rank=0): | |||||
| super(LOGGER, self).__init__(logger_name) | |||||
| self.rank = rank | |||||
| if rank % 8 == 0: | |||||
| console = logging.StreamHandler(sys.stdout) | |||||
| console.setLevel(logging.INFO) | |||||
| formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') | |||||
| console.setFormatter(formatter) | |||||
| self.addHandler(console) | |||||
| def setup_logging_file(self, log_dir, rank=0): | |||||
| """Setup logging file.""" | |||||
| self.rank = rank | |||||
| if not os.path.exists(log_dir): | |||||
| os.makedirs(log_dir, exist_ok=True) | |||||
| log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) | |||||
| self.log_fn = os.path.join(log_dir, log_name) | |||||
| fh = logging.FileHandler(self.log_fn) | |||||
| fh.setLevel(logging.INFO) | |||||
| formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') | |||||
| fh.setFormatter(formatter) | |||||
| self.addHandler(fh) | |||||
| def info(self, msg, *args, **kwargs): | |||||
| if self.isEnabledFor(logging.INFO): | |||||
| self._log(logging.INFO, msg, args, **kwargs) | |||||
| def save_args(self, args): | |||||
| self.info('Args:') | |||||
| args_dict = vars(args) | |||||
| for key in args_dict.keys(): | |||||
| self.info('--> %s: %s', key, args_dict[key]) | |||||
| self.info('') | |||||
| def important_info(self, msg, *args, **kwargs): | |||||
| if self.isEnabledFor(logging.INFO) and self.rank == 0: | |||||
| line_width = 2 | |||||
| important_msg = '\n' | |||||
| important_msg += ('*'*70 + '\n')*line_width | |||||
| important_msg += ('*'*line_width + '\n')*2 | |||||
| important_msg += '*'*line_width + ' '*8 + msg + '\n' | |||||
| important_msg += ('*'*line_width + '\n')*2 | |||||
| important_msg += ('*'*70 + '\n')*line_width | |||||
| self.info(important_msg, *args, **kwargs) | |||||
| def get_logger(path, rank): | |||||
| """Get Logger.""" | |||||
| logger = LOGGER('yolov3_darknet53', rank) | |||||
| logger.setup_logging_file(path, rank) | |||||
| return logger | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """YOLOV3 loss.""" | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| class XYLoss(nn.Cell): | |||||
| """Loss for x and y.""" | |||||
| def __init__(self): | |||||
| super(XYLoss, self).__init__() | |||||
| self.cross_entropy = P.SigmoidCrossEntropyWithLogits() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| def construct(self, object_mask, box_loss_scale, predict_xy, true_xy): | |||||
| xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy) | |||||
| xy_loss = self.reduce_sum(xy_loss, ()) | |||||
| return xy_loss | |||||
| class WHLoss(nn.Cell): | |||||
| """Loss for w and h.""" | |||||
| def __init__(self): | |||||
| super(WHLoss, self).__init__() | |||||
| self.square = P.Square() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| def construct(self, object_mask, box_loss_scale, predict_wh, true_wh): | |||||
| wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh) | |||||
| wh_loss = self.reduce_sum(wh_loss, ()) | |||||
| return wh_loss | |||||
| class ConfidenceLoss(nn.Cell): | |||||
| """Loss for confidence.""" | |||||
| def __init__(self): | |||||
| super(ConfidenceLoss, self).__init__() | |||||
| self.cross_entropy = P.SigmoidCrossEntropyWithLogits() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| def construct(self, object_mask, predict_confidence, ignore_mask): | |||||
| confidence_loss = self.cross_entropy(predict_confidence, object_mask) | |||||
| confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask | |||||
| confidence_loss = self.reduce_sum(confidence_loss, ()) | |||||
| return confidence_loss | |||||
| class ClassLoss(nn.Cell): | |||||
| """Loss for classification.""" | |||||
| def __init__(self): | |||||
| super(ClassLoss, self).__init__() | |||||
| self.cross_entropy = P.SigmoidCrossEntropyWithLogits() | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| def construct(self, object_mask, predict_class, class_probs): | |||||
| class_loss = object_mask * self.cross_entropy(predict_class, class_probs) | |||||
| class_loss = self.reduce_sum(class_loss, ()) | |||||
| return class_loss | |||||
| @@ -0,0 +1,180 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Learning rate scheduler.""" | |||||
| import math | |||||
| from collections import Counter | |||||
| import numpy as np | |||||
| def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): | |||||
| """Linear learning rate.""" | |||||
| lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||||
| lr = float(init_lr) + lr_inc * current_step | |||||
| return lr | |||||
| def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): | |||||
| """Warmup step learning rate.""" | |||||
| base_lr = lr | |||||
| warmup_init_lr = 0 | |||||
| total_steps = int(max_epoch * steps_per_epoch) | |||||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||||
| milestones = lr_epochs | |||||
| milestones_steps = [] | |||||
| for milestone in milestones: | |||||
| milestones_step = milestone * steps_per_epoch | |||||
| milestones_steps.append(milestones_step) | |||||
| lr_each_step = [] | |||||
| lr = base_lr | |||||
| milestones_steps_counter = Counter(milestones_steps) | |||||
| for i in range(total_steps): | |||||
| if i < warmup_steps: | |||||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||||
| else: | |||||
| lr = lr * gamma**milestones_steps_counter[i] | |||||
| lr_each_step.append(lr) | |||||
| return np.array(lr_each_step).astype(np.float32) | |||||
| def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): | |||||
| return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) | |||||
| def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): | |||||
| lr_epochs = [] | |||||
| for i in range(1, max_epoch): | |||||
| if i % epoch_size == 0: | |||||
| lr_epochs.append(i) | |||||
| return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) | |||||
| def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): | |||||
| """Cosine annealing learning rate.""" | |||||
| base_lr = lr | |||||
| warmup_init_lr = 0 | |||||
| total_steps = int(max_epoch * steps_per_epoch) | |||||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||||
| lr_each_step = [] | |||||
| for i in range(total_steps): | |||||
| last_epoch = i // steps_per_epoch | |||||
| if i < warmup_steps: | |||||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||||
| else: | |||||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 | |||||
| lr_each_step.append(lr) | |||||
| return np.array(lr_each_step).astype(np.float32) | |||||
| def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): | |||||
| """Cosine annealing learning rate V2.""" | |||||
| base_lr = lr | |||||
| warmup_init_lr = 0 | |||||
| total_steps = int(max_epoch * steps_per_epoch) | |||||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||||
| last_lr = 0 | |||||
| last_epoch_V1 = 0 | |||||
| T_max_V2 = int(max_epoch*1/3) | |||||
| lr_each_step = [] | |||||
| for i in range(total_steps): | |||||
| last_epoch = i // steps_per_epoch | |||||
| if i < warmup_steps: | |||||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||||
| else: | |||||
| if i < total_steps*2/3: | |||||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 | |||||
| last_lr = lr | |||||
| last_epoch_V1 = last_epoch | |||||
| else: | |||||
| base_lr = last_lr | |||||
| last_epoch = last_epoch-last_epoch_V1 | |||||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2 | |||||
| lr_each_step.append(lr) | |||||
| return np.array(lr_each_step).astype(np.float32) | |||||
| def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): | |||||
| """Warmup cosine annealing learning rate.""" | |||||
| start_sample_epoch = 60 | |||||
| step_sample = 2 | |||||
| tobe_sampled_epoch = 60 | |||||
| end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch | |||||
| max_sampled_epoch = max_epoch+tobe_sampled_epoch | |||||
| T_max = max_sampled_epoch | |||||
| base_lr = lr | |||||
| warmup_init_lr = 0 | |||||
| total_steps = int(max_epoch * steps_per_epoch) | |||||
| total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) | |||||
| warmup_steps = int(warmup_epochs * steps_per_epoch) | |||||
| lr_each_step = [] | |||||
| for i in range(total_sampled_steps): | |||||
| last_epoch = i // steps_per_epoch | |||||
| if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): | |||||
| continue | |||||
| if i < warmup_steps: | |||||
| lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) | |||||
| else: | |||||
| lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2 | |||||
| lr_each_step.append(lr) | |||||
| assert total_steps == len(lr_each_step) | |||||
| return np.array(lr_each_step).astype(np.float32) | |||||
| def get_lr(args): | |||||
| """generate learning rate.""" | |||||
| if args.lr_scheduler == 'exponential': | |||||
| lr = warmup_step_lr(args.lr, | |||||
| args.lr_epochs, | |||||
| args.steps_per_epoch, | |||||
| args.warmup_epochs, | |||||
| args.max_epoch, | |||||
| gamma=args.lr_gamma, | |||||
| ) | |||||
| elif args.lr_scheduler == 'cosine_annealing': | |||||
| lr = warmup_cosine_annealing_lr(args.lr, | |||||
| args.steps_per_epoch, | |||||
| args.warmup_epochs, | |||||
| args.max_epoch, | |||||
| args.T_max, | |||||
| args.eta_min) | |||||
| elif args.lr_scheduler == 'cosine_annealing_V2': | |||||
| lr = warmup_cosine_annealing_lr_V2(args.lr, | |||||
| args.steps_per_epoch, | |||||
| args.warmup_epochs, | |||||
| args.max_epoch, | |||||
| args.T_max, | |||||
| args.eta_min) | |||||
| elif args.lr_scheduler == 'cosine_annealing_sample': | |||||
| lr = warmup_cosine_annealing_lr_sample(args.lr, | |||||
| args.steps_per_epoch, | |||||
| args.warmup_epochs, | |||||
| args.max_epoch, | |||||
| args.T_max, | |||||
| args.eta_min) | |||||
| else: | |||||
| raise NotImplementedError(args.lr_scheduler) | |||||
| return lr | |||||
| @@ -0,0 +1,593 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Preprocess dataset.""" | |||||
| import random | |||||
| import threading | |||||
| import copy | |||||
| import numpy as np | |||||
| from PIL import Image | |||||
| import cv2 | |||||
| def _rand(a=0., b=1.): | |||||
| return np.random.rand() * (b - a) + a | |||||
| def bbox_iou(bbox_a, bbox_b, offset=0): | |||||
| """Calculate Intersection-Over-Union(IOU) of two bounding boxes. | |||||
| Parameters | |||||
| ---------- | |||||
| bbox_a : numpy.ndarray | |||||
| An ndarray with shape :math:`(N, 4)`. | |||||
| bbox_b : numpy.ndarray | |||||
| An ndarray with shape :math:`(M, 4)`. | |||||
| offset : float or int, default is 0 | |||||
| The ``offset`` is used to control the whether the width(or height) is computed as | |||||
| (right - left + ``offset``). | |||||
| Note that the offset must be 0 for normalized bboxes, whose ranges are in ``[0, 1]``. | |||||
| Returns | |||||
| ------- | |||||
| numpy.ndarray | |||||
| An ndarray with shape :math:`(N, M)` indicates IOU between each pairs of | |||||
| bounding boxes in `bbox_a` and `bbox_b`. | |||||
| """ | |||||
| if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4: | |||||
| raise IndexError("Bounding boxes axis 1 must have at least length 4") | |||||
| tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) | |||||
| br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4]) | |||||
| area_i = np.prod(br - tl + offset, axis=2) * (tl < br).all(axis=2) | |||||
| area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1) | |||||
| area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1) | |||||
| return area_i / (area_a[:, None] + area_b - area_i) | |||||
| def statistic_normalize_img(img, statistic_norm): | |||||
| """Statistic normalize images.""" | |||||
| # img: RGB | |||||
| if isinstance(img, Image.Image): | |||||
| img = np.array(img) | |||||
| img = img/255. | |||||
| mean = np.array([0.485, 0.456, 0.406]) | |||||
| std = np.array([0.229, 0.224, 0.225]) | |||||
| if statistic_norm: | |||||
| img = (img - mean) / std | |||||
| return img | |||||
| def get_interp_method(interp, sizes=()): | |||||
| """ | |||||
| Get the interpolation method for resize functions. | |||||
| The major purpose of this function is to wrap a random interp method selection | |||||
| and a auto-estimation method. | |||||
| Note: | |||||
| When shrinking an image, it will generally look best with AREA-based | |||||
| interpolation, whereas, when enlarging an image, it will generally look best | |||||
| with Bicubic or Bilinear. | |||||
| Args: | |||||
| interp (int): Interpolation method for all resizing operations. | |||||
| - 0: Nearest Neighbors Interpolation. | |||||
| - 1: Bilinear interpolation. | |||||
| - 2: Bicubic interpolation over 4x4 pixel neighborhood. | |||||
| - 3: Nearest Neighbors. Originally it should be Area-based, as we cannot find Area-based, | |||||
| so we use NN instead. Area-based (resampling using pixel area relation). | |||||
| It may be a preferred method for image decimation, as it gives moire-free results. | |||||
| But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default). | |||||
| - 4: Lanczos interpolation over 8x8 pixel neighborhood. | |||||
| - 9: Cubic for enlarge, area for shrink, bilinear for others. | |||||
| - 10: Random select from interpolation method mentioned above. | |||||
| sizes (tuple): Format should like (old_height, old_width, new_height, new_width), | |||||
| if None provided, auto(9) will return Area(2) anyway. Default: () | |||||
| Returns: | |||||
| int, interp method from 0 to 4. | |||||
| """ | |||||
| if interp == 9: | |||||
| if sizes: | |||||
| assert len(sizes) == 4 | |||||
| oh, ow, nh, nw = sizes | |||||
| if nh > oh and nw > ow: | |||||
| return 2 | |||||
| if nh < oh and nw < ow: | |||||
| return 0 | |||||
| return 1 | |||||
| return 2 | |||||
| if interp == 10: | |||||
| return random.randint(0, 4) | |||||
| if interp not in (0, 1, 2, 3, 4): | |||||
| raise ValueError('Unknown interp method %d' % interp) | |||||
| return interp | |||||
| def pil_image_reshape(interp): | |||||
| """Reshape pil image.""" | |||||
| reshape_type = { | |||||
| 0: Image.NEAREST, | |||||
| 1: Image.BILINEAR, | |||||
| 2: Image.BICUBIC, | |||||
| 3: Image.NEAREST, | |||||
| 4: Image.LANCZOS, | |||||
| } | |||||
| return reshape_type[interp] | |||||
| def _preprocess_true_boxes(true_boxes, anchors, in_shape, num_classes, | |||||
| max_boxes, label_smooth, label_smooth_factor=0.1): | |||||
| """Preprocess annotation boxes.""" | |||||
| anchors = np.array(anchors) | |||||
| num_layers = anchors.shape[0] // 3 | |||||
| anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] | |||||
| true_boxes = np.array(true_boxes, dtype='float32') | |||||
| input_shape = np.array(in_shape, dtype='int32') | |||||
| boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. | |||||
| # trans to box center point | |||||
| boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] | |||||
| # input_shape is [h, w] | |||||
| true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] | |||||
| true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] | |||||
| # true_boxes [x, y, w, h] | |||||
| grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] | |||||
| # grid_shape [h, w] | |||||
| y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), | |||||
| 5 + num_classes), dtype='float32') for l in range(num_layers)] | |||||
| # y_true [gridy, gridx] | |||||
| anchors = np.expand_dims(anchors, 0) | |||||
| anchors_max = anchors / 2. | |||||
| anchors_min = -anchors_max | |||||
| valid_mask = boxes_wh[..., 0] > 0 | |||||
| wh = boxes_wh[valid_mask] | |||||
| if wh.size > 0: | |||||
| wh = np.expand_dims(wh, -2) | |||||
| boxes_max = wh / 2. | |||||
| boxes_min = -boxes_max | |||||
| intersect_min = np.maximum(boxes_min, anchors_min) | |||||
| intersect_max = np.minimum(boxes_max, anchors_max) | |||||
| intersect_wh = np.maximum(intersect_max - intersect_min, 0.) | |||||
| intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] | |||||
| box_area = wh[..., 0] * wh[..., 1] | |||||
| anchor_area = anchors[..., 0] * anchors[..., 1] | |||||
| iou = intersect_area / (box_area + anchor_area - intersect_area) | |||||
| best_anchor = np.argmax(iou, axis=-1) | |||||
| for t, n in enumerate(best_anchor): | |||||
| for l in range(num_layers): | |||||
| if n in anchor_mask[l]: | |||||
| i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') # grid_y | |||||
| j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') # grid_x | |||||
| k = anchor_mask[l].index(n) | |||||
| c = true_boxes[t, 4].astype('int32') | |||||
| y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] | |||||
| y_true[l][j, i, k, 4] = 1. | |||||
| # lable-smooth | |||||
| if label_smooth: | |||||
| sigma = label_smooth_factor/(num_classes-1) | |||||
| y_true[l][j, i, k, 5:] = sigma | |||||
| y_true[l][j, i, k, 5+c] = 1-label_smooth_factor | |||||
| else: | |||||
| y_true[l][j, i, k, 5 + c] = 1. | |||||
| # pad_gt_boxes for avoiding dynamic shape | |||||
| pad_gt_box0 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||||
| pad_gt_box1 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||||
| pad_gt_box2 = np.zeros(shape=[max_boxes, 4], dtype=np.float32) | |||||
| mask0 = np.reshape(y_true[0][..., 4:5], [-1]) | |||||
| gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) | |||||
| # gt_box [boxes, [x,y,w,h]] | |||||
| gt_box0 = gt_box0[mask0 == 1] | |||||
| # gt_box0: get all boxes which have object | |||||
| pad_gt_box0[:gt_box0.shape[0]] = gt_box0 | |||||
| # gt_box0.shape[0]: total number of boxes in gt_box0 | |||||
| # top N of pad_gt_box0 is real box, and after are pad by zero | |||||
| mask1 = np.reshape(y_true[1][..., 4:5], [-1]) | |||||
| gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) | |||||
| gt_box1 = gt_box1[mask1 == 1] | |||||
| pad_gt_box1[:gt_box1.shape[0]] = gt_box1 | |||||
| mask2 = np.reshape(y_true[2][..., 4:5], [-1]) | |||||
| gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) | |||||
| gt_box2 = gt_box2[mask2 == 1] | |||||
| pad_gt_box2[:gt_box2.shape[0]] = gt_box2 | |||||
| return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 | |||||
| def _reshape_data(image, image_size): | |||||
| """Reshape image.""" | |||||
| if not isinstance(image, Image.Image): | |||||
| image = Image.fromarray(image) | |||||
| ori_w, ori_h = image.size | |||||
| ori_image_shape = np.array([ori_w, ori_h], np.int32) | |||||
| # original image shape fir:H sec:W | |||||
| h, w = image_size | |||||
| interp = get_interp_method(interp=9, sizes=(ori_h, ori_w, h, w)) | |||||
| image = image.resize((w, h), pil_image_reshape(interp)) | |||||
| image_data = statistic_normalize_img(image, statistic_norm=True) | |||||
| if len(image_data.shape) == 2: | |||||
| image_data = np.expand_dims(image_data, axis=-1) | |||||
| image_data = np.concatenate([image_data, image_data, image_data], axis=-1) | |||||
| image_data = image_data.astype(np.float32) | |||||
| return image_data, ori_image_shape | |||||
| def color_distortion(img, hue, sat, val, device_num): | |||||
| """Color distortion.""" | |||||
| hue = _rand(-hue, hue) | |||||
| sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat) | |||||
| val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val) | |||||
| if device_num != 1: | |||||
| cv2.setNumThreads(1) | |||||
| x = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL) | |||||
| x = x / 255. | |||||
| x[..., 0] += hue | |||||
| x[..., 0][x[..., 0] > 1] -= 1 | |||||
| x[..., 0][x[..., 0] < 0] += 1 | |||||
| x[..., 1] *= sat | |||||
| x[..., 2] *= val | |||||
| x[x > 1] = 1 | |||||
| x[x < 0] = 0 | |||||
| x = x * 255. | |||||
| x = x.astype(np.uint8) | |||||
| image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB_FULL) | |||||
| return image_data | |||||
| def filp_pil_image(img): | |||||
| return img.transpose(Image.FLIP_LEFT_RIGHT) | |||||
| def convert_gray_to_color(img): | |||||
| if len(img.shape) == 2: | |||||
| img = np.expand_dims(img, axis=-1) | |||||
| img = np.concatenate([img, img, img], axis=-1) | |||||
| return img | |||||
| def _is_iou_satisfied_constraint(min_iou, max_iou, box, crop_box): | |||||
| iou = bbox_iou(box, crop_box) | |||||
| return min_iou <= iou.min() and max_iou >= iou.max() | |||||
| def _choose_candidate_by_constraints(max_trial, input_w, input_h, image_w, image_h, jitter, box, use_constraints): | |||||
| """Choose candidate by constraints.""" | |||||
| if use_constraints: | |||||
| constraints = ( | |||||
| (0.1, None), | |||||
| (0.3, None), | |||||
| (0.5, None), | |||||
| (0.7, None), | |||||
| (0.9, None), | |||||
| (None, 1), | |||||
| ) | |||||
| else: | |||||
| constraints = ( | |||||
| (None, None), | |||||
| ) | |||||
| # add default candidate | |||||
| candidates = [(0, 0, input_w, input_h)] | |||||
| for constraint in constraints: | |||||
| min_iou, max_iou = constraint | |||||
| min_iou = -np.inf if min_iou is None else min_iou | |||||
| max_iou = np.inf if max_iou is None else max_iou | |||||
| for _ in range(max_trial): | |||||
| # box_data should have at least one box | |||||
| new_ar = float(input_w) / float(input_h) * _rand(1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter) | |||||
| scale = _rand(0.25, 2) | |||||
| if new_ar < 1: | |||||
| nh = int(scale * input_h) | |||||
| nw = int(nh * new_ar) | |||||
| else: | |||||
| nw = int(scale * input_w) | |||||
| nh = int(nw / new_ar) | |||||
| dx = int(_rand(0, input_w - nw)) | |||||
| dy = int(_rand(0, input_h - nh)) | |||||
| if box.size > 0: | |||||
| t_box = copy.deepcopy(box) | |||||
| t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx | |||||
| t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy | |||||
| crop_box = np.array((0, 0, input_w, input_h)) | |||||
| if not _is_iou_satisfied_constraint(min_iou, max_iou, t_box, crop_box[np.newaxis]): | |||||
| continue | |||||
| else: | |||||
| candidates.append((dx, dy, nw, nh)) | |||||
| else: | |||||
| raise Exception("!!! annotation box is less than 1") | |||||
| return candidates | |||||
| def _correct_bbox_by_candidates(candidates, input_w, input_h, image_w, | |||||
| image_h, flip, box, box_data, allow_outside_center): | |||||
| """Calculate correct boxes.""" | |||||
| while candidates: | |||||
| if len(candidates) > 1: | |||||
| # ignore default candidate which do not crop | |||||
| candidate = candidates.pop(np.random.randint(1, len(candidates))) | |||||
| else: | |||||
| candidate = candidates.pop(np.random.randint(0, len(candidates))) | |||||
| dx, dy, nw, nh = candidate | |||||
| t_box = copy.deepcopy(box) | |||||
| t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(image_w) + dx | |||||
| t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(image_h) + dy | |||||
| if flip: | |||||
| t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]] | |||||
| if allow_outside_center: | |||||
| pass | |||||
| else: | |||||
| t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2])/2. >= 0., (t_box[:, 1] + t_box[:, 3])/2. >= 0.)] | |||||
| t_box = t_box[np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w, | |||||
| (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)] | |||||
| # recorrect x, y for case x,y < 0 reset to zero, after dx and dy, some box can smaller than zero | |||||
| t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 | |||||
| # recorrect w,h not higher than input size | |||||
| t_box[:, 2][t_box[:, 2] > input_w] = input_w | |||||
| t_box[:, 3][t_box[:, 3] > input_h] = input_h | |||||
| box_w = t_box[:, 2] - t_box[:, 0] | |||||
| box_h = t_box[:, 3] - t_box[:, 1] | |||||
| # discard invalid box: w or h smaller than 1 pixel | |||||
| t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] | |||||
| if t_box.shape[0] > 0: | |||||
| # break if number of find t_box | |||||
| box_data[: len(t_box)] = t_box | |||||
| return box_data, candidate | |||||
| raise Exception('all candidates can not satisfied re-correct bbox') | |||||
| def _data_aug(image, box, jitter, hue, sat, val, image_input_size, max_boxes, | |||||
| anchors, num_classes, max_trial=10, device_num=1): | |||||
| """Crop an image randomly with bounding box constraints. | |||||
| This data augmentation is used in training of | |||||
| Single Shot Multibox Detector [#]_. More details can be found in | |||||
| data augmentation section of the original paper. | |||||
| .. [#] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, | |||||
| Scott Reed, Cheng-Yang Fu, Alexander C. Berg. | |||||
| SSD: Single Shot MultiBox Detector. ECCV 2016.""" | |||||
| if not isinstance(image, Image.Image): | |||||
| image = Image.fromarray(image) | |||||
| image_w, image_h = image.size | |||||
| input_h, input_w = image_input_size | |||||
| np.random.shuffle(box) | |||||
| if len(box) > max_boxes: | |||||
| box = box[:max_boxes] | |||||
| flip = _rand() < .5 | |||||
| box_data = np.zeros((max_boxes, 5)) | |||||
| candidates = _choose_candidate_by_constraints(use_constraints=False, | |||||
| max_trial=max_trial, | |||||
| input_w=input_w, | |||||
| input_h=input_h, | |||||
| image_w=image_w, | |||||
| image_h=image_h, | |||||
| jitter=jitter, | |||||
| box=box) | |||||
| box_data, candidate = _correct_bbox_by_candidates(candidates=candidates, | |||||
| input_w=input_w, | |||||
| input_h=input_h, | |||||
| image_w=image_w, | |||||
| image_h=image_h, | |||||
| flip=flip, | |||||
| box=box, | |||||
| box_data=box_data, | |||||
| allow_outside_center=True) | |||||
| dx, dy, nw, nh = candidate | |||||
| interp = get_interp_method(interp=10) | |||||
| image = image.resize((nw, nh), pil_image_reshape(interp)) | |||||
| # place image, gray color as back graoud | |||||
| new_image = Image.new('RGB', (input_w, input_h), (128, 128, 128)) | |||||
| new_image.paste(image, (dx, dy)) | |||||
| image = new_image | |||||
| if flip: | |||||
| image = filp_pil_image(image) | |||||
| image = np.array(image) | |||||
| image = convert_gray_to_color(image) | |||||
| image_data = color_distortion(image, hue, sat, val, device_num) | |||||
| image_data = statistic_normalize_img(image_data, statistic_norm=True) | |||||
| image_data = image_data.astype(np.float32) | |||||
| return image_data, box_data | |||||
| def preprocess_fn(image, box, config, input_size, device_num): | |||||
| """Preprocess data function.""" | |||||
| config_anchors = config.anchor_scales | |||||
| anchors = np.array([list(x) for x in config_anchors]) | |||||
| max_boxes = config.max_box | |||||
| num_classes = config.num_classes | |||||
| jitter = config.jitter | |||||
| hue = config.hue | |||||
| sat = config.saturation | |||||
| val = config.value | |||||
| image, anno = _data_aug(image, box, jitter=jitter, hue=hue, sat=sat, val=val, | |||||
| image_input_size=input_size, max_boxes=max_boxes, | |||||
| num_classes=num_classes, anchors=anchors, device_num=device_num) | |||||
| return image, anno | |||||
| def reshape_fn(image, img_id, config): | |||||
| input_size = config.test_img_shape | |||||
| image, ori_image_shape = _reshape_data(image, image_size=input_size) | |||||
| return image, ori_image_shape, img_id | |||||
| class MultiScaleTrans: | |||||
| """Multi scale transform.""" | |||||
| def __init__(self, config, device_num): | |||||
| self.config = config | |||||
| self.seed = 0 | |||||
| self.size_list = [] | |||||
| self.resize_rate = config.resize_rate | |||||
| self.dataset_size = config.dataset_size | |||||
| self.size_dict = {} | |||||
| self.seed_num = int(1e6) | |||||
| self.seed_list = self.generate_seed_list(seed_num=self.seed_num) | |||||
| self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) | |||||
| self.device_num = device_num | |||||
| self.anchor_scales = config.anchor_scales | |||||
| self.num_classes = config.num_classes | |||||
| self.max_box = config.max_box | |||||
| self.label_smooth = config.label_smooth | |||||
| self.label_smooth_factor = config.label_smooth_factor | |||||
| def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): | |||||
| seed_list = [] | |||||
| random.seed(init_seed) | |||||
| for _ in range(seed_num): | |||||
| seed = random.randint(seed_range[0], seed_range[1]) | |||||
| seed_list.append(seed) | |||||
| return seed_list | |||||
| def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batchInfo): | |||||
| epoch_num = batchInfo.get_epoch_num() | |||||
| size_idx = int(batchInfo.get_batch_num() / self.resize_rate) | |||||
| seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] | |||||
| ret_imgs = [] | |||||
| ret_annos = [] | |||||
| bbox1 = [] | |||||
| bbox2 = [] | |||||
| bbox3 = [] | |||||
| gt1 = [] | |||||
| gt2 = [] | |||||
| gt3 = [] | |||||
| if self.size_dict.get(seed_key, None) is None: | |||||
| random.seed(seed_key) | |||||
| new_size = random.choice(self.config.multi_scale) | |||||
| self.size_dict[seed_key] = new_size | |||||
| seed = seed_key | |||||
| input_size = self.size_dict[seed] | |||||
| for img, anno in zip(imgs, annos): | |||||
| img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) | |||||
| ret_imgs.append(img.transpose(2, 0, 1).copy()) | |||||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||||
| _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2], | |||||
| num_classes=self.num_classes, max_boxes=self.max_box, | |||||
| label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor) | |||||
| bbox1.append(bbox_true_1) | |||||
| bbox2.append(bbox_true_2) | |||||
| bbox3.append(bbox_true_3) | |||||
| gt1.append(gt_box1) | |||||
| gt2.append(gt_box2) | |||||
| gt3.append(gt_box3) | |||||
| ret_annos.append(0) | |||||
| return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \ | |||||
| np.array(gt1), np.array(gt2), np.array(gt3) | |||||
| def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, | |||||
| batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3): | |||||
| """Preprocess true box for multi-thread.""" | |||||
| i = 0 | |||||
| for anno in annos: | |||||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||||
| _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, | |||||
| num_classes=config.num_classes, max_boxes=config.max_box, | |||||
| label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) | |||||
| batch_bbox_true_1[result_index + i] = bbox_true_1 | |||||
| batch_bbox_true_2[result_index + i] = bbox_true_2 | |||||
| batch_bbox_true_3[result_index + i] = bbox_true_3 | |||||
| batch_gt_box1[result_index + i] = gt_box1 | |||||
| batch_gt_box2[result_index + i] = gt_box2 | |||||
| batch_gt_box3[result_index + i] = gt_box3 | |||||
| i = i + 1 | |||||
| def batch_preprocess_true_box(annos, config, input_shape): | |||||
| """Preprocess true box with multi-thread.""" | |||||
| batch_bbox_true_1 = [] | |||||
| batch_bbox_true_2 = [] | |||||
| batch_bbox_true_3 = [] | |||||
| batch_gt_box1 = [] | |||||
| batch_gt_box2 = [] | |||||
| batch_gt_box3 = [] | |||||
| threads = [] | |||||
| step = 4 | |||||
| for index in range(0, len(annos), step): | |||||
| for _ in range(step): | |||||
| batch_bbox_true_1.append(None) | |||||
| batch_bbox_true_2.append(None) | |||||
| batch_bbox_true_3.append(None) | |||||
| batch_gt_box1.append(None) | |||||
| batch_gt_box2.append(None) | |||||
| batch_gt_box3.append(None) | |||||
| step_anno = annos[index: index + step] | |||||
| t = threading.Thread(target=thread_batch_preprocess_true_box, | |||||
| args=(step_anno, config, input_shape, index, batch_bbox_true_1, batch_bbox_true_2, | |||||
| batch_bbox_true_3, batch_gt_box1, batch_gt_box2, batch_gt_box3)) | |||||
| t.start() | |||||
| threads.append(t) | |||||
| for t in threads: | |||||
| t.join() | |||||
| return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ | |||||
| np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) | |||||
| def batch_preprocess_true_box_single(annos, config, input_shape): | |||||
| """Preprocess true boxes.""" | |||||
| batch_bbox_true_1 = [] | |||||
| batch_bbox_true_2 = [] | |||||
| batch_bbox_true_3 = [] | |||||
| batch_gt_box1 = [] | |||||
| batch_gt_box2 = [] | |||||
| batch_gt_box3 = [] | |||||
| for anno in annos: | |||||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||||
| _preprocess_true_boxes(true_boxes=anno, anchors=config.anchor_scales, in_shape=input_shape, | |||||
| num_classes=config.num_classes, max_boxes=config.max_box, | |||||
| label_smooth=config.label_smooth, label_smooth_factor=config.label_smooth_factor) | |||||
| batch_bbox_true_1.append(bbox_true_1) | |||||
| batch_bbox_true_2.append(bbox_true_2) | |||||
| batch_bbox_true_3.append(bbox_true_3) | |||||
| batch_gt_box1.append(gt_box1) | |||||
| batch_gt_box2.append(gt_box2) | |||||
| batch_gt_box3.append(gt_box3) | |||||
| return np.array(batch_bbox_true_1), np.array(batch_bbox_true_2), np.array(batch_bbox_true_3), \ | |||||
| np.array(batch_gt_box1), np.array(batch_gt_box2), np.array(batch_gt_box3) | |||||
| @@ -0,0 +1,187 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Util class or function.""" | |||||
| from mindspore.train.serialization import load_checkpoint | |||||
| import mindspore.nn as nn | |||||
| import mindspore.common.dtype as mstype | |||||
| from .yolo import YoloLossBlock | |||||
| class AverageMeter: | |||||
| """Computes and stores the average and current value""" | |||||
| def __init__(self, name, fmt=':f', tb_writer=None): | |||||
| self.name = name | |||||
| self.fmt = fmt | |||||
| self.reset() | |||||
| self.tb_writer = tb_writer | |||||
| self.cur_step = 1 | |||||
| self.val = 0 | |||||
| self.avg = 0 | |||||
| self.sum = 0 | |||||
| self.count = 0 | |||||
| def reset(self): | |||||
| self.val = 0 | |||||
| self.avg = 0 | |||||
| self.sum = 0 | |||||
| self.count = 0 | |||||
| def update(self, val, n=1): | |||||
| self.val = val | |||||
| self.sum += val * n | |||||
| self.count += n | |||||
| self.avg = self.sum / self.count | |||||
| if self.tb_writer is not None: | |||||
| self.tb_writer.add_scalar(self.name, self.val, self.cur_step) | |||||
| self.cur_step += 1 | |||||
| def __str__(self): | |||||
| fmtstr = '{name}:{avg' + self.fmt + '}' | |||||
| return fmtstr.format(**self.__dict__) | |||||
| def load_backbone(net, ckpt_path, args): | |||||
| """Load darknet53 backbone checkpoint.""" | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| yolo_backbone_prefix = 'feature_map.backbone' | |||||
| darknet_backbone_prefix = 'network.backbone' | |||||
| find_param = [] | |||||
| not_found_param = [] | |||||
| net.init_parameters_data() | |||||
| for name, cell in net.cells_and_names(): | |||||
| if name.startswith(yolo_backbone_prefix): | |||||
| name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix) | |||||
| if isinstance(cell, (nn.Conv2d, nn.Dense)): | |||||
| darknet_weight = '{}.weight'.format(name) | |||||
| darknet_bias = '{}.bias'.format(name) | |||||
| if darknet_weight in param_dict: | |||||
| cell.weight.set_data(param_dict[darknet_weight].data) | |||||
| find_param.append(darknet_weight) | |||||
| else: | |||||
| not_found_param.append(darknet_weight) | |||||
| if darknet_bias in param_dict: | |||||
| cell.bias.set_data(param_dict[darknet_bias].data) | |||||
| find_param.append(darknet_bias) | |||||
| else: | |||||
| not_found_param.append(darknet_bias) | |||||
| elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): | |||||
| darknet_moving_mean = '{}.moving_mean'.format(name) | |||||
| darknet_moving_variance = '{}.moving_variance'.format(name) | |||||
| darknet_gamma = '{}.gamma'.format(name) | |||||
| darknet_beta = '{}.beta'.format(name) | |||||
| if darknet_moving_mean in param_dict: | |||||
| cell.moving_mean.set_data(param_dict[darknet_moving_mean].data) | |||||
| find_param.append(darknet_moving_mean) | |||||
| else: | |||||
| not_found_param.append(darknet_moving_mean) | |||||
| if darknet_moving_variance in param_dict: | |||||
| cell.moving_variance.set_data(param_dict[darknet_moving_variance].data) | |||||
| find_param.append(darknet_moving_variance) | |||||
| else: | |||||
| not_found_param.append(darknet_moving_variance) | |||||
| if darknet_gamma in param_dict: | |||||
| cell.gamma.set_data(param_dict[darknet_gamma].data) | |||||
| find_param.append(darknet_gamma) | |||||
| else: | |||||
| not_found_param.append(darknet_gamma) | |||||
| if darknet_beta in param_dict: | |||||
| cell.beta.set_data(param_dict[darknet_beta].data) | |||||
| find_param.append(darknet_beta) | |||||
| else: | |||||
| not_found_param.append(darknet_beta) | |||||
| args.logger.info('================found_param {}========='.format(len(find_param))) | |||||
| args.logger.info(find_param) | |||||
| args.logger.info('================not_found_param {}========='.format(len(not_found_param))) | |||||
| args.logger.info(not_found_param) | |||||
| args.logger.info('=====load {} successfully ====='.format(ckpt_path)) | |||||
| return net | |||||
| def default_wd_filter(x): | |||||
| """default weight decay filter.""" | |||||
| parameter_name = x.name | |||||
| if parameter_name.endswith('.bias'): | |||||
| # all bias not using weight decay | |||||
| return False | |||||
| if parameter_name.endswith('.gamma'): | |||||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||||
| return False | |||||
| if parameter_name.endswith('.beta'): | |||||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||||
| return False | |||||
| return True | |||||
| def get_param_groups(network): | |||||
| """Param groups for optimizer.""" | |||||
| decay_params = [] | |||||
| no_decay_params = [] | |||||
| for x in network.trainable_params(): | |||||
| parameter_name = x.name | |||||
| if parameter_name.endswith('.bias'): | |||||
| # all bias not using weight decay | |||||
| no_decay_params.append(x) | |||||
| elif parameter_name.endswith('.gamma'): | |||||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||||
| no_decay_params.append(x) | |||||
| elif parameter_name.endswith('.beta'): | |||||
| # bn weight bias not using weight decay, be carefully for now x not include BN | |||||
| no_decay_params.append(x) | |||||
| else: | |||||
| decay_params.append(x) | |||||
| return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] | |||||
| class ShapeRecord: | |||||
| """Log image shape.""" | |||||
| def __init__(self): | |||||
| self.shape_record = { | |||||
| 320: 0, | |||||
| 352: 0, | |||||
| 384: 0, | |||||
| 416: 0, | |||||
| 448: 0, | |||||
| 480: 0, | |||||
| 512: 0, | |||||
| 544: 0, | |||||
| 576: 0, | |||||
| 608: 0, | |||||
| 'total': 0 | |||||
| } | |||||
| def set(self, shape): | |||||
| if len(shape) > 1: | |||||
| shape = shape[0] | |||||
| shape = int(shape) | |||||
| self.shape_record[shape] += 1 | |||||
| self.shape_record['total'] += 1 | |||||
| def show(self, logger): | |||||
| for key in self.shape_record: | |||||
| rate = self.shape_record[key] / float(self.shape_record['total']) | |||||
| logger.info('shape {}: {:.2f}%'.format(key, rate*100)) | |||||
| def keep_loss_fp32(network): | |||||
| """Keep loss of network with float32""" | |||||
| for _, cell in network.cells_and_names(): | |||||
| if isinstance(cell, (YoloLossBlock,)): | |||||
| cell.to_float(mstype.float32) | |||||
| @@ -0,0 +1,439 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """YOLOv3 based on DarkNet.""" | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.context import ParallelMode | |||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||||
| from mindspore.communication.management import get_group_size | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from src.darknet import DarkNet, ResidualBlock | |||||
| from src.config import ConfigYOLOV3DarkNet53 | |||||
| from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss | |||||
| def _conv_bn_relu(in_channel, | |||||
| out_channel, | |||||
| ksize, | |||||
| stride=1, | |||||
| padding=0, | |||||
| dilation=1, | |||||
| alpha=0.1, | |||||
| momentum=0.9, | |||||
| eps=1e-5, | |||||
| pad_mode="same"): | |||||
| """Get a conv2d batchnorm and relu layer""" | |||||
| return nn.SequentialCell( | |||||
| [nn.Conv2d(in_channel, | |||||
| out_channel, | |||||
| kernel_size=ksize, | |||||
| stride=stride, | |||||
| padding=padding, | |||||
| dilation=dilation, | |||||
| pad_mode=pad_mode), | |||||
| nn.BatchNorm2d(out_channel, momentum=momentum, eps=eps), | |||||
| nn.LeakyReLU(alpha)] | |||||
| ) | |||||
| class YoloBlock(nn.Cell): | |||||
| """ | |||||
| YoloBlock for YOLOv3. | |||||
| Args: | |||||
| in_channels: Integer. Input channel. | |||||
| out_chls: Interger. Middle channel. | |||||
| out_channels: Integer. Output channel. | |||||
| Returns: | |||||
| Tuple, tuple of output tensor,(f1,f2,f3). | |||||
| Examples: | |||||
| YoloBlock(1024, 512, 255) | |||||
| """ | |||||
| def __init__(self, in_channels, out_chls, out_channels): | |||||
| super(YoloBlock, self).__init__() | |||||
| out_chls_2 = out_chls*2 | |||||
| self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1) | |||||
| self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||||
| self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) | |||||
| self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||||
| self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) | |||||
| self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) | |||||
| self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True) | |||||
| def construct(self, x): | |||||
| c1 = self.conv0(x) | |||||
| c2 = self.conv1(c1) | |||||
| c3 = self.conv2(c2) | |||||
| c4 = self.conv3(c3) | |||||
| c5 = self.conv4(c4) | |||||
| c6 = self.conv5(c5) | |||||
| out = self.conv6(c6) | |||||
| return c5, out | |||||
| class YOLOv3(nn.Cell): | |||||
| """ | |||||
| YOLOv3 Network. | |||||
| Note: | |||||
| backbone = darknet53 | |||||
| Args: | |||||
| backbone_shape: List. Darknet output channels shape. | |||||
| backbone: Cell. Backbone Network. | |||||
| out_channel: Interger. Output channel. | |||||
| Returns: | |||||
| Tensor, output tensor. | |||||
| Examples: | |||||
| YOLOv3(backbone_shape=[64, 128, 256, 512, 1024] | |||||
| backbone=darknet53(), | |||||
| out_channel=255) | |||||
| """ | |||||
| def __init__(self, backbone_shape, backbone, out_channel): | |||||
| super(YOLOv3, self).__init__() | |||||
| self.out_channel = out_channel | |||||
| self.backbone = backbone | |||||
| self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel) | |||||
| self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1) | |||||
| self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3], | |||||
| out_chls=backbone_shape[-3], | |||||
| out_channels=out_channel) | |||||
| self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1) | |||||
| self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4], | |||||
| out_chls=backbone_shape[-4], | |||||
| out_channels=out_channel) | |||||
| self.concat = P.Concat(axis=1) | |||||
| def construct(self, x): | |||||
| # input_shape of x is (batch_size, 3, h, w) | |||||
| # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8) | |||||
| # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16) | |||||
| # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32) | |||||
| img_hight = P.Shape()(x)[2] | |||||
| img_width = P.Shape()(x)[3] | |||||
| feature_map1, feature_map2, feature_map3 = self.backbone(x) | |||||
| con1, big_object_output = self.backblock0(feature_map3) | |||||
| con1 = self.conv1(con1) | |||||
| ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1) | |||||
| con1 = self.concat((ups1, feature_map2)) | |||||
| con2, medium_object_output = self.backblock1(con1) | |||||
| con2 = self.conv2(con2) | |||||
| ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2) | |||||
| con3 = self.concat((ups2, feature_map1)) | |||||
| _, small_object_output = self.backblock2(con3) | |||||
| return big_object_output, medium_object_output, small_object_output | |||||
| class DetectionBlock(nn.Cell): | |||||
| """ | |||||
| YOLOv3 detection Network. It will finally output the detection result. | |||||
| Args: | |||||
| scale: Character. | |||||
| config: ConfigYOLOV3DarkNet53, Configuration instance. | |||||
| is_training: Bool, Whether train or not, default True. | |||||
| Returns: | |||||
| Tuple, tuple of output tensor,(f1,f2,f3). | |||||
| Examples: | |||||
| DetectionBlock(scale='l',stride=32) | |||||
| """ | |||||
| def __init__(self, scale, config=ConfigYOLOV3DarkNet53(), is_training=True): | |||||
| super(DetectionBlock, self).__init__() | |||||
| self.config = config | |||||
| if scale == 's': | |||||
| idx = (0, 1, 2) | |||||
| elif scale == 'm': | |||||
| idx = (3, 4, 5) | |||||
| elif scale == 'l': | |||||
| idx = (6, 7, 8) | |||||
| else: | |||||
| raise KeyError("Invalid scale value for DetectionBlock") | |||||
| self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) | |||||
| self.num_anchors_per_scale = 3 | |||||
| self.num_attrib = 4+1+self.config.num_classes | |||||
| self.lambda_coord = 1 | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| self.reshape = P.Reshape() | |||||
| self.tile = P.Tile() | |||||
| self.concat = P.Concat(axis=-1) | |||||
| self.conf_training = is_training | |||||
| def construct(self, x, input_shape): | |||||
| num_batch = P.Shape()(x)[0] | |||||
| grid_size = P.Shape()(x)[2:4] | |||||
| # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib] | |||||
| prediction = P.Reshape()(x, (num_batch, | |||||
| self.num_anchors_per_scale, | |||||
| self.num_attrib, | |||||
| grid_size[0], | |||||
| grid_size[1])) | |||||
| prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2)) | |||||
| range_x = range(grid_size[1]) | |||||
| range_y = range(grid_size[0]) | |||||
| grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32) | |||||
| grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32) | |||||
| # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid | |||||
| # [batch, gridx, gridy, 1, 1] | |||||
| grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1)) | |||||
| grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1)) | |||||
| # Shape is [grid_size[0], grid_size[1], 1, 2] | |||||
| grid = self.concat((grid_x, grid_y)) | |||||
| box_xy = prediction[:, :, :, :, :2] | |||||
| box_wh = prediction[:, :, :, :, 2:4] | |||||
| box_confidence = prediction[:, :, :, :, 4:5] | |||||
| box_probs = prediction[:, :, :, :, 5:] | |||||
| # gridsize1 is x | |||||
| # gridsize0 is y | |||||
| box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32) | |||||
| # box_wh is w->h | |||||
| box_wh = P.Exp()(box_wh) * self.anchors / input_shape | |||||
| box_confidence = self.sigmoid(box_confidence) | |||||
| box_probs = self.sigmoid(box_probs) | |||||
| if self.conf_training: | |||||
| return grid, prediction, box_xy, box_wh | |||||
| return self.concat((box_xy, box_wh, box_confidence, box_probs)) | |||||
| class Iou(nn.Cell): | |||||
| """Calculate the iou of boxes""" | |||||
| def __init__(self): | |||||
| super(Iou, self).__init__() | |||||
| self.min = P.Minimum() | |||||
| self.max = P.Maximum() | |||||
| def construct(self, box1, box2): | |||||
| # box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h] | |||||
| # box2: gt_box [batch, 1, 1, 1, maxbox, 4] | |||||
| # convert to topLeft and rightDown | |||||
| box1_xy = box1[:, :, :, :, :, :2] | |||||
| box1_wh = box1[:, :, :, :, :, 2:4] | |||||
| box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft | |||||
| box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown | |||||
| box2_xy = box2[:, :, :, :, :, :2] | |||||
| box2_wh = box2[:, :, :, :, :, 2:4] | |||||
| box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0) | |||||
| box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0) | |||||
| intersect_mins = self.max(box1_mins, box2_mins) | |||||
| intersect_maxs = self.min(box1_maxs, box2_maxs) | |||||
| intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0)) | |||||
| # P.squeeze: for effiecient slice | |||||
| intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \ | |||||
| P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2]) | |||||
| box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2]) | |||||
| box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2]) | |||||
| iou = intersect_area / (box1_area + box2_area - intersect_area) | |||||
| # iou : [batch, gx, gy, anchors, maxboxes] | |||||
| return iou | |||||
| class YoloLossBlock(nn.Cell): | |||||
| """ | |||||
| Loss block cell of YOLOV3 network. | |||||
| """ | |||||
| def __init__(self, scale, config=ConfigYOLOV3DarkNet53()): | |||||
| super(YoloLossBlock, self).__init__() | |||||
| self.config = config | |||||
| if scale == 's': | |||||
| # anchor mask | |||||
| idx = (0, 1, 2) | |||||
| elif scale == 'm': | |||||
| idx = (3, 4, 5) | |||||
| elif scale == 'l': | |||||
| idx = (6, 7, 8) | |||||
| else: | |||||
| raise KeyError("Invalid scale value for DetectionBlock") | |||||
| self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) | |||||
| self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32) | |||||
| self.concat = P.Concat(axis=-1) | |||||
| self.iou = Iou() | |||||
| self.reduce_max = P.ReduceMax(keep_dims=False) | |||||
| self.xy_loss = XYLoss() | |||||
| self.wh_loss = WHLoss() | |||||
| self.confidenceLoss = ConfidenceLoss() | |||||
| self.classLoss = ClassLoss() | |||||
| def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape): | |||||
| # prediction : origin output from yolo | |||||
| # pred_xy: (sigmoid(xy)+grid)/grid_size | |||||
| # pred_wh: (exp(wh)*anchors)/input_shape | |||||
| # y_true : after normalize | |||||
| # gt_box: [batch, maxboxes, xyhw] after normalize | |||||
| object_mask = y_true[:, :, :, :, 4:5] | |||||
| class_probs = y_true[:, :, :, :, 5:] | |||||
| grid_shape = P.Shape()(prediction)[1:3] | |||||
| grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32) | |||||
| pred_boxes = self.concat((pred_xy, pred_wh)) | |||||
| true_xy = y_true[:, :, :, :, :2] * grid_shape - grid | |||||
| true_wh = y_true[:, :, :, :, 2:4] | |||||
| true_wh = P.Select()(P.Equal()(true_wh, 0.0), | |||||
| P.Fill()(P.DType()(true_wh), | |||||
| P.Shape()(true_wh), 1.0), | |||||
| true_wh) | |||||
| true_wh = P.Log()(true_wh / self.anchors * input_shape) | |||||
| # 2-w*h for large picture, use small scale, since small obj need more precise | |||||
| box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4] | |||||
| gt_shape = P.Shape()(gt_box) | |||||
| gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2])) | |||||
| # add one more dimension for broadcast | |||||
| iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box) | |||||
| # gt_box is x,y,h,w after normalize | |||||
| # [batch, grid[0], grid[1], num_anchor, num_gt] | |||||
| best_iou = self.reduce_max(iou, -1) | |||||
| # [batch, grid[0], grid[1], num_anchor] | |||||
| # ignore_mask IOU too small | |||||
| ignore_mask = best_iou < self.ignore_threshold | |||||
| ignore_mask = P.Cast()(ignore_mask, ms.float32) | |||||
| ignore_mask = P.ExpandDims()(ignore_mask, -1) | |||||
| # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume. | |||||
| # so we turn off its gradient | |||||
| ignore_mask = F.stop_gradient(ignore_mask) | |||||
| xy_loss = self.xy_loss(object_mask, box_loss_scale, prediction[:, :, :, :, :2], true_xy) | |||||
| wh_loss = self.wh_loss(object_mask, box_loss_scale, prediction[:, :, :, :, 2:4], true_wh) | |||||
| confidence_loss = self.confidenceLoss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask) | |||||
| class_loss = self.classLoss(object_mask, prediction[:, :, :, :, 5:], class_probs) | |||||
| loss = xy_loss + wh_loss + confidence_loss + class_loss | |||||
| batch_size = P.Shape()(prediction)[0] | |||||
| return loss / batch_size | |||||
| class YOLOV3DarkNet53(nn.Cell): | |||||
| """ | |||||
| Darknet based YOLOV3 network. | |||||
| Args: | |||||
| is_training: Bool. Whether train or not. | |||||
| Returns: | |||||
| Cell, cell instance of Darknet based YOLOV3 neural network. | |||||
| Examples: | |||||
| YOLOV3DarkNet53(True) | |||||
| """ | |||||
| def __init__(self, is_training): | |||||
| super(YOLOV3DarkNet53, self).__init__() | |||||
| self.config = ConfigYOLOV3DarkNet53() | |||||
| # YOLOv3 network | |||||
| self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, | |||||
| self.config.backbone_input_shape, | |||||
| self.config.backbone_shape, | |||||
| detect=True), | |||||
| backbone_shape=self.config.backbone_shape, | |||||
| out_channel=self.config.out_channel) | |||||
| # prediction on the default anchor boxes | |||||
| self.detect_1 = DetectionBlock('l', is_training=is_training) | |||||
| self.detect_2 = DetectionBlock('m', is_training=is_training) | |||||
| self.detect_3 = DetectionBlock('s', is_training=is_training) | |||||
| def construct(self, x, input_shape): | |||||
| big_object_output, medium_object_output, small_object_output = self.feature_map(x) | |||||
| output_big = self.detect_1(big_object_output, input_shape) | |||||
| output_me = self.detect_2(medium_object_output, input_shape) | |||||
| output_small = self.detect_3(small_object_output, input_shape) | |||||
| # big is the final output which has smallest feature map | |||||
| return output_big, output_me, output_small | |||||
| class YoloWithLossCell(nn.Cell): | |||||
| """YOLOV3 loss.""" | |||||
| def __init__(self, network): | |||||
| super(YoloWithLossCell, self).__init__() | |||||
| self.yolo_network = network | |||||
| self.config = ConfigYOLOV3DarkNet53() | |||||
| self.loss_big = YoloLossBlock('l', self.config) | |||||
| self.loss_me = YoloLossBlock('m', self.config) | |||||
| self.loss_small = YoloLossBlock('s', self.config) | |||||
| def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): | |||||
| yolo_out = self.yolo_network(x, input_shape) | |||||
| loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) | |||||
| loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) | |||||
| loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) | |||||
| return loss_l + loss_m + loss_s | |||||
| class TrainingWrapper(nn.Cell): | |||||
| """Training wrapper.""" | |||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_grad() | |||||
| self.weights = optimizer.parameters | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||||
| self.sens = sens | |||||
| self.reducer_flag = False | |||||
| self.grad_reducer = None | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("gradients_mean") | |||||
| if auto_parallel_context().get_device_num_is_set(): | |||||
| degree = context.get_auto_parallel_context("device_num") | |||||
| else: | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| def construct(self, *args): | |||||
| weights = self.weights | |||||
| loss = self.network(*args) | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| grads = self.grad(self.network, weights)(*args, sens) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| @@ -0,0 +1,190 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """YOLOV3 dataset.""" | |||||
| import os | |||||
| import multiprocessing | |||||
| import cv2 | |||||
| from PIL import Image | |||||
| from pycocotools.coco import COCO | |||||
| import mindspore.dataset as de | |||||
| import mindspore.dataset.vision.c_transforms as CV | |||||
| from src.distributed_sampler import DistributedSampler | |||||
| from src.transforms import reshape_fn, MultiScaleTrans | |||||
| min_keypoints_per_image = 10 | |||||
| def _has_only_empty_bbox(anno): | |||||
| return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) | |||||
| def _count_visible_keypoints(anno): | |||||
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |||||
| def has_valid_annotation(anno): | |||||
| """Check annotation file.""" | |||||
| # if it's empty, there is no annotation | |||||
| if not anno: | |||||
| return False | |||||
| # if all boxes have close to zero area, there is no annotation | |||||
| if _has_only_empty_bbox(anno): | |||||
| return False | |||||
| # keypoints task have a slight different critera for considering | |||||
| # if an annotation is valid | |||||
| if "keypoints" not in anno[0]: | |||||
| return True | |||||
| # for keypoint detection tasks, only consider valid images those | |||||
| # containing at least min_keypoints_per_image | |||||
| if _count_visible_keypoints(anno) >= min_keypoints_per_image: | |||||
| return True | |||||
| return False | |||||
| class COCOYoloDataset: | |||||
| """YOLOV3 Dataset for COCO.""" | |||||
| def __init__(self, root, ann_file, remove_images_without_annotations=True, | |||||
| filter_crowd_anno=True, is_training=True): | |||||
| self.coco = COCO(ann_file) | |||||
| self.root = root | |||||
| self.img_ids = list(sorted(self.coco.imgs.keys())) | |||||
| self.filter_crowd_anno = filter_crowd_anno | |||||
| self.is_training = is_training | |||||
| # filter images without any annotations | |||||
| if remove_images_without_annotations: | |||||
| img_ids = [] | |||||
| for img_id in self.img_ids: | |||||
| ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||||
| anno = self.coco.loadAnns(ann_ids) | |||||
| if has_valid_annotation(anno): | |||||
| img_ids.append(img_id) | |||||
| self.img_ids = img_ids | |||||
| self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()} | |||||
| self.cat_ids_to_continuous_ids = { | |||||
| v: i for i, v in enumerate(self.coco.getCatIds()) | |||||
| } | |||||
| self.continuous_ids_cat_ids = { | |||||
| v: k for k, v in self.cat_ids_to_continuous_ids.items() | |||||
| } | |||||
| def __getitem__(self, index): | |||||
| """ | |||||
| Args: | |||||
| index (int): Index | |||||
| Returns: | |||||
| (img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints", | |||||
| generated by the image's annotation. img is a PIL image. | |||||
| """ | |||||
| coco = self.coco | |||||
| img_id = self.img_ids[index] | |||||
| img_path = coco.loadImgs(img_id)[0]["file_name"] | |||||
| img = Image.open(os.path.join(self.root, img_path)).convert("RGB") | |||||
| if not self.is_training: | |||||
| return img, img_id | |||||
| ann_ids = coco.getAnnIds(imgIds=img_id) | |||||
| target = coco.loadAnns(ann_ids) | |||||
| # filter crowd annotations | |||||
| if self.filter_crowd_anno: | |||||
| annos = [anno for anno in target if anno["iscrowd"] == 0] | |||||
| else: | |||||
| annos = [anno for anno in target] | |||||
| target = {} | |||||
| boxes = [anno["bbox"] for anno in annos] | |||||
| target["bboxes"] = boxes | |||||
| classes = [anno["category_id"] for anno in annos] | |||||
| classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes] | |||||
| target["labels"] = classes | |||||
| bboxes = target['bboxes'] | |||||
| labels = target['labels'] | |||||
| out_target = [] | |||||
| for bbox, label in zip(bboxes, labels): | |||||
| tmp = [] | |||||
| # convert to [x_min y_min x_max y_max] | |||||
| bbox = self._convetTopDown(bbox) | |||||
| tmp.extend(bbox) | |||||
| tmp.append(int(label)) | |||||
| # tmp [x_min y_min x_max y_max, label] | |||||
| out_target.append(tmp) | |||||
| return img, out_target, [], [], [], [], [], [] | |||||
| def __len__(self): | |||||
| return len(self.img_ids) | |||||
| def _convetTopDown(self, bbox): | |||||
| x_min = bbox[0] | |||||
| y_min = bbox[1] | |||||
| w = bbox[2] | |||||
| h = bbox[3] | |||||
| return [x_min, y_min, x_min+w, y_min+h] | |||||
| def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank, | |||||
| config=None, is_training=True, shuffle=True, num_samples=256): | |||||
| """Create dataset for YOLOV3.""" | |||||
| cv2.setNumThreads(0) | |||||
| if is_training: | |||||
| filter_crowd = True | |||||
| remove_empty_anno = True | |||||
| else: | |||||
| filter_crowd = False | |||||
| remove_empty_anno = False | |||||
| yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd, | |||||
| remove_images_without_annotations=remove_empty_anno, is_training=is_training) | |||||
| distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle) | |||||
| hwc_to_chw = CV.HWC2CHW() | |||||
| config.dataset_size = len(yolo_dataset) | |||||
| cores = multiprocessing.cpu_count() | |||||
| num_parallel_workers = int(cores / device_num) | |||||
| if is_training: | |||||
| multi_scale_trans = MultiScaleTrans(config, device_num) | |||||
| dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3", | |||||
| "gt_box1", "gt_box2", "gt_box3"] | |||||
| if device_num != 8: | |||||
| ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, | |||||
| num_parallel_workers=min(32, num_parallel_workers), | |||||
| sampler=distributed_sampler, num_samples=num_samples) | |||||
| ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names, | |||||
| num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True) | |||||
| else: | |||||
| ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler) | |||||
| ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names, | |||||
| num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True) | |||||
| else: | |||||
| ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], | |||||
| sampler=distributed_sampler, num_samples=num_samples) | |||||
| compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config)) | |||||
| ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"], | |||||
| output_columns=["image", "image_shape", "img_id"], | |||||
| column_order=["image", "image_shape", "img_id"], | |||||
| num_parallel_workers=8) | |||||
| ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8) | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(max_epoch) | |||||
| return ds, num_samples | |||||
| @@ -0,0 +1,211 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| ######################## train YOLOv3_DARKNET53 example ######################## | |||||
| train YOLOv3 and get network model files(.ckpt) : | |||||
| python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train | |||||
| If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path. | |||||
| Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path. | |||||
| """ | |||||
| import os | |||||
| import time | |||||
| import re | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.train.callback import Callback | |||||
| from mindspore.context import ParallelMode | |||||
| from mindspore.nn.optim.momentum import Momentum | |||||
| import mindspore as ms | |||||
| from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper | |||||
| from src.util import AverageMeter, get_param_groups | |||||
| from src.lr_scheduler import warmup_cosine_annealing_lr | |||||
| from src.yolo_dataset import create_yolo_dataset | |||||
| from src.initializer import default_recurisive_init | |||||
| from src.config import ConfigYOLOV3DarkNet53 | |||||
| np.random.seed(1) | |||||
| def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): | |||||
| """Set learning rate.""" | |||||
| lr_each_step = [] | |||||
| for i in range(global_step): | |||||
| if steps: | |||||
| lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step))) | |||||
| else: | |||||
| lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step))) | |||||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||||
| lr_each_step = lr_each_step[start_step:] | |||||
| return lr_each_step | |||||
| def init_net_param(network, init_value='ones'): | |||||
| """Init:wq the parameters in network.""" | |||||
| params = network.trainable_params() | |||||
| for p in params: | |||||
| if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: | |||||
| p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype)) | |||||
| class ModelCallback(Callback): | |||||
| def __init__(self): | |||||
| super(ModelCallback, self).__init__() | |||||
| self.loss_list = [] | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| self.loss_list.append(cb_params.net_outputs.asnumpy()) | |||||
| print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) | |||||
| class TimeMonitor(Callback): | |||||
| """Time Monitor.""" | |||||
| def __init__(self, data_size): | |||||
| super(TimeMonitor, self).__init__() | |||||
| self.data_size = data_size | |||||
| self.epoch_mseconds_list = [] | |||||
| self.per_step_mseconds_list = [] | |||||
| def epoch_begin(self, run_context): | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| self.epoch_mseconds_list.append(epoch_mseconds) | |||||
| self.per_step_mseconds_list.append(epoch_mseconds / self.data_size) | |||||
| DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2014/" | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_yolov3_darknet53(): | |||||
| devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 | |||||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | |||||
| device_target="Ascend", device_id=devid) | |||||
| rank = 0 | |||||
| device_num = 1 | |||||
| lr_init = 0.001 | |||||
| epoch_size = 3 | |||||
| batch_size = 32 | |||||
| loss_scale = 1024 | |||||
| mindrecord_dir = DATA_DIR | |||||
| # It will generate mindrecord file in args_opt.mindrecord_dir, | |||||
| # and the file name is yolo.mindrecord0, 1, ... file_num. | |||||
| if not os.path.isdir(mindrecord_dir): | |||||
| raise KeyError("mindrecord path is not exist.") | |||||
| data_root = os.path.join(mindrecord_dir, 'train2014') | |||||
| annFile = os.path.join(mindrecord_dir, 'annotations/instances_train2014.json') | |||||
| # print("yolov3 mindrecord is ", mindrecord_file) | |||||
| if not os.path.exists(annFile): | |||||
| print("instances_train2014 file is not exist.") | |||||
| assert False | |||||
| loss_meter = AverageMeter('loss') | |||||
| context.reset_auto_parallel_context() | |||||
| parallel_mode = ParallelMode.STAND_ALONE | |||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1) | |||||
| network = YOLOV3DarkNet53(is_training=True) | |||||
| # default is kaiming-normal | |||||
| default_recurisive_init(network) | |||||
| network = YoloWithLossCell(network) | |||||
| print('finish get network') | |||||
| config = ConfigYOLOV3DarkNet53() | |||||
| label_smooth = 0 | |||||
| label_smooth_factor = 0.1 | |||||
| config.label_smooth = label_smooth | |||||
| config.label_smooth_factor = label_smooth_factor | |||||
| # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. | |||||
| print("Create dataset begin!") | |||||
| training_shape = [int(416), int(416)] | |||||
| config.multi_scale = [training_shape] | |||||
| num_samples = 256 | |||||
| ds, data_size = create_yolo_dataset(image_dir=data_root, anno_path=annFile, is_training=True, | |||||
| batch_size=batch_size, max_epoch=epoch_size, | |||||
| device_num=device_num, rank=rank, config=config, num_samples=num_samples) | |||||
| print("Create dataset done!") | |||||
| per_batch_size = batch_size | |||||
| group_size = 1 | |||||
| print("data_size:", data_size) | |||||
| steps_per_epoch = int(data_size / per_batch_size / group_size) | |||||
| print("steps_per_epoch:", steps_per_epoch) | |||||
| warmup_epochs = 0. | |||||
| max_epoch = epoch_size | |||||
| T_max = 1 | |||||
| eta_min = 0 | |||||
| lr = warmup_cosine_annealing_lr(lr_init, | |||||
| steps_per_epoch, | |||||
| warmup_epochs, | |||||
| max_epoch, | |||||
| T_max, | |||||
| eta_min) | |||||
| opt = Momentum(params=get_param_groups(network), | |||||
| learning_rate=Tensor(lr), | |||||
| momentum=0.9, | |||||
| weight_decay=0.0005, | |||||
| loss_scale=loss_scale) | |||||
| network = TrainingWrapper(network, opt) | |||||
| network.set_train() | |||||
| old_progress = -1 | |||||
| t_end = time.time() | |||||
| data_loader = ds.create_dict_iterator(output_numpy=True) | |||||
| train_starttime = time.time() | |||||
| time_used_per_epoch = 0 | |||||
| print("time:", time.time()) | |||||
| for i, data in enumerate(data_loader): | |||||
| images = data["image"] | |||||
| input_shape = images.shape[2:4] | |||||
| print('iter[{}], shape{}'.format(i, input_shape[0])) | |||||
| images = Tensor.from_numpy(images) | |||||
| batch_y_true_0 = Tensor.from_numpy(data['bbox1']) | |||||
| batch_y_true_1 = Tensor.from_numpy(data['bbox2']) | |||||
| batch_y_true_2 = Tensor.from_numpy(data['bbox3']) | |||||
| batch_gt_box0 = Tensor.from_numpy(data['gt_box1']) | |||||
| batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) | |||||
| batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) | |||||
| input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) | |||||
| loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, | |||||
| batch_gt_box2, input_shape) | |||||
| loss_meter.update(loss.asnumpy()) | |||||
| if (i + 1) % steps_per_epoch == 0: | |||||
| time_used = time.time() - t_end | |||||
| epoch = int(i / steps_per_epoch) | |||||
| fps = per_batch_size * (i - old_progress) * group_size / time_used | |||||
| if rank == 0: | |||||
| print( | |||||
| 'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}, time_used:{}'.format(epoch, | |||||
| i, loss_meter, fps, lr[i], | |||||
| time_used)) | |||||
| t_end = time.time() | |||||
| loss_meter.reset() | |||||
| old_progress = i | |||||
| time_used_per_epoch = time_used | |||||
| train_endtime = time.time() - train_starttime | |||||
| print('train_time_used:{}'.format(train_endtime)) | |||||
| expect_loss_value = 3210.0 | |||||
| loss_value = re.findall(r"\d+\.?\d*", str(loss_meter)) | |||||
| print('loss_value:{}'.format(loss_value[0])) | |||||
| assert float(loss_value[0]) < expect_loss_value | |||||
| export_time_used = 20.0 | |||||
| print('time_used_per_epoch:{}'.format(time_used_per_epoch)) | |||||
| assert time_used_per_epoch < export_time_used | |||||
| print('==========test case passed===========') | |||||