Merge pull request !107 from zhaoting/add-YOLOv3-infer-scipt-and-change-dataset-to-MindRecordtags/v0.2.0-alpha
| @@ -26,6 +26,7 @@ class ConfigYOLOV3ResNet18: | |||||
| img_shape = [352, 640] | img_shape = [352, 640] | ||||
| feature_shape = [32, 3, 352, 640] | feature_shape = [32, 3, 352, 640] | ||||
| num_classes = 80 | num_classes = 80 | ||||
| nms_max_num = 50 | |||||
| backbone_input_shape = [64, 64, 128, 256] | backbone_input_shape = [64, 64, 128, 256] | ||||
| backbone_shape = [64, 128, 256, 512] | backbone_shape = [64, 128, 256, 512] | ||||
| @@ -33,6 +34,8 @@ class ConfigYOLOV3ResNet18: | |||||
| backbone_stride = [1, 2, 2, 2] | backbone_stride = [1, 2, 2, 2] | ||||
| ignore_threshold = 0.5 | ignore_threshold = 0.5 | ||||
| obj_threshold = 0.3 | |||||
| nms_threshold = 0.4 | |||||
| anchor_scales = [(10, 13), | anchor_scales = [(10, 13), | ||||
| (16, 30), | (16, 30), | ||||
| @@ -16,16 +16,14 @@ | |||||
| """YOLOv3 dataset""" | """YOLOv3 dataset""" | ||||
| from __future__ import division | from __future__ import division | ||||
| import abc | |||||
| import io | |||||
| import os | import os | ||||
| import math | |||||
| import json | |||||
| import numpy as np | import numpy as np | ||||
| from PIL import Image | from PIL import Image | ||||
| from matplotlib.colors import rgb_to_hsv, hsv_to_rgb | from matplotlib.colors import rgb_to_hsv, hsv_to_rgb | ||||
| import mindspore.dataset as de | import mindspore.dataset as de | ||||
| from mindspore.mindrecord import FileWriter | |||||
| import mindspore.dataset.transforms.vision.py_transforms as P | import mindspore.dataset.transforms.vision.py_transforms as P | ||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| from config import ConfigYOLOV3ResNet18 | from config import ConfigYOLOV3ResNet18 | ||||
| iter_cnt = 0 | iter_cnt = 0 | ||||
| @@ -114,6 +112,29 @@ def preprocess_fn(image, box, is_training): | |||||
| return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 | return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 | ||||
def _infer_data(img_data, input_shape, box):
    """Resize an image with unchanged aspect ratio and pad it for inference.

    Args:
        img_data: Image object exposing ``size`` (unpacked as width, height)
            and ``resize``; the resized result is converted with ``np.array``.
        input_shape: Target ``(height, width)`` of the network input.
        box: Annotation data, returned untouched.

    Returns:
        A tuple ``(image, original_shape, box)`` where ``image`` is a float32
        array of shape ``(1, 3, input_h, input_w)`` divided by 255 and
        ``original_shape`` is ``np.array([h, w], np.float32)`` of the source
        image.
    """
    src_w, src_h = img_data.size
    target_h, target_w = input_shape

    # Use the smaller ratio so the whole image fits inside the target canvas.
    ratio = min(float(target_w) / float(src_w), float(target_h) / float(src_h))
    resized_w = int(src_w * ratio)
    resized_h = int(src_h * ratio)

    resized = np.array(img_data.resize((resized_w, resized_h), Image.BICUBIC))
    if resized.ndim == 2:
        # Single-channel image: replicate it into three identical channels.
        resized = np.stack([resized, resized, resized], axis=-1)

    # Canvas pre-filled with 128 (before the division by 255 below).
    canvas = np.full((target_h, target_w, 3), 128, dtype=np.float32)
    top = int((target_h - resized_h) / 2)
    left = int((target_w - resized_w) / 2)
    canvas[top:(top + resized_h), left:(left + resized_w), :] = resized
    canvas /= 255.

    # HWC -> CHW, then add a leading batch axis.
    batched = np.expand_dims(np.transpose(canvas, (2, 0, 1)), 0)
    return batched, np.array([src_h, src_w], np.float32), box
| def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)): | def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)): | ||||
| """Data augmentation function.""" | """Data augmentation function.""" | ||||
| if not isinstance(image, Image.Image): | if not isinstance(image, Image.Image): | ||||
| @@ -124,32 +145,7 @@ def preprocess_fn(image, box, is_training): | |||||
| h, w = image_size | h, w = image_size | ||||
| if not is_training: | if not is_training: | ||||
| image = image.resize((w, h), Image.BICUBIC) | |||||
| image_data = np.array(image) / 255. | |||||
| if len(image_data.shape) == 2: | |||||
| image_data = np.expand_dims(image_data, axis=-1) | |||||
| image_data = np.concatenate([image_data, image_data, image_data], axis=-1) | |||||
| image_data = image_data.astype(np.float32) | |||||
| # correct boxes | |||||
| box_data = np.zeros((max_boxes, 5)) | |||||
| if len(box) >= 1: | |||||
| np.random.shuffle(box) | |||||
| if len(box) > max_boxes: | |||||
| box = box[:max_boxes] | |||||
| # xmin ymin xmax ymax | |||||
| box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw) | |||||
| box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih) | |||||
| box_data[:len(box)] = box | |||||
| else: | |||||
| image_data, box_data = None, None | |||||
| # preprocess bounding boxes | |||||
| bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ | |||||
| _preprocess_true_boxes(box_data, anchors, image_size) | |||||
| return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \ | |||||
| ori_image_shape, gt_box1, gt_box2, gt_box3 | |||||
| return _infer_data(image, image_size, box) | |||||
| flip = _rand() < .5 | flip = _rand() < .5 | ||||
| # correct boxes | # correct boxes | ||||
| @@ -235,12 +231,16 @@ def preprocess_fn(image, box, is_training): | |||||
| return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \ | return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \ | ||||
| ori_image_shape, gt_box1, gt_box2, gt_box3 | ori_image_shape, gt_box1, gt_box2, gt_box3 | ||||
| images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training) | |||||
| return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3 | |||||
| if is_training: | |||||
| images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training) | |||||
| return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3 | |||||
| images, shape, anno = _data_aug(image, box, is_training) | |||||
| return images, shape, anno | |||||
| def anno_parser(annos_str): | def anno_parser(annos_str): | ||||
| """Annotation parser.""" | |||||
| """Parse annotation from string to list.""" | |||||
| annos = [] | annos = [] | ||||
| for anno_str in annos_str: | for anno_str in annos_str: | ||||
| anno = list(map(int, anno_str.strip().split(','))) | anno = list(map(int, anno_str.strip().split(','))) | ||||
| @@ -248,142 +248,71 @@ def anno_parser(annos_str): | |||||
| return annos | return annos | ||||
| def expand_path(path): | |||||
| """Get file list from path.""" | |||||
| files = [] | |||||
| if os.path.isdir(path): | |||||
| for file in os.listdir(path): | |||||
| if os.path.isfile(os.path.join(path, file)): | |||||
| files.append(file) | |||||
| else: | |||||
| def filter_valid_data(image_dir, anno_path): | |||||
| """Filter valid image file, which both in image_dir and anno_path.""" | |||||
| image_files = [] | |||||
| image_anno_dict = {} | |||||
| if not os.path.isdir(image_dir): | |||||
| raise RuntimeError("Path given is not valid.") | raise RuntimeError("Path given is not valid.") | ||||
| return files | |||||
def read_image(img_path):
    """Load the file at ``img_path`` with PIL and return it as a numpy array."""
    with open(img_path, "rb") as image_file:
        raw_bytes = image_file.read()
    # Decode from an in-memory buffer so the file handle is already closed.
    decoded = Image.open(io.BytesIO(raw_bytes))
    return np.array(decoded)
class BaseDataset():
    """Random-access dataset base used as a GeneratorDataset source.

    Subclasses fill ``self.samples`` (image file names) and
    ``self.image_anno_dict`` (file name -> annotations) in ``_load_samples``.
    """

    def __init__(self, image_dir, anno_path):
        self.samples = []
        self.image_anno_dict = {}
        self.cur_index = 0
        self.image_dir = image_dir
        self.anno_path = anno_path
        # Populate the containers initialised above.
        self._load_samples()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self._next_data(self.samples[item], self.image_dir, self.image_anno_dict)

    @staticmethod
    def _next_data(sample, image_dir, image_anno_dict):
        """Return ``[image_array, annotation_array]`` for one sample name."""
        pixels = read_image(os.path.join(image_dir, sample))
        return [np.array(pixels), np.array(image_anno_dict[sample])]

    # NOTE: the class does not use an ABC metaclass, so this decorator does not
    # prevent instantiation of the base class.
    @abc.abstractmethod
    def _load_samples(self):
        """Base load samples."""
class YoloDataset(BaseDataset):
    """Dataset of images that appear both on disk and in the annotation file."""

    def _load_samples(self):
        """Collect the valid sample names; raise if there are none."""
        candidates = expand_path(self.image_dir)
        self.samples = self._filter_valid_data(self.anno_path, candidates)
        self.dataset_size = len(self.samples)
        if self.dataset_size == 0:
            raise RuntimeError("Valid dataset is none!")

    def _filter_valid_data(self, anno_path, image_files_raw):
        """Return the sorted file names present in both inputs.

        Each annotation line is split on single spaces: the first field is the
        image path (only its last ``/`` component is kept as the key), the
        remaining fields are handed to ``anno_parser``. Parsed annotations are
        stored in ``self.image_anno_dict``.
        """
        print("Start filter valid data.")
        with open(anno_path, "rb") as anno_file:
            raw_lines = anno_file.readlines()

        raw_annos = {}
        for raw_line in raw_lines:
            fields = str(raw_line.decode("utf-8")).split(' ')
            raw_annos[fields[0].split("/")[-1]] = fields[1:]

        matched = []
        for name in set(raw_annos.keys()) & set(image_files_raw):
            matched.append(name)
            self.image_anno_dict[name] = anno_parser(raw_annos[name])
        matched.sort()
        print("Filter valid data done!")
        return matched
class DistributedSampler():
    """Index sampler that splits a dataset across replicas for YOLOv3.

    Every replica receives ``num_samples`` indices, where ``num_samples`` is
    the larger of ``batch_size`` and ``ceil(dataset_size / num_replicas)``.
    The index list is padded by repeating indices from the start so that it
    divides evenly, then each replica takes every ``num_replicas``-th index
    starting at its rank.

    Args:
        dataset_size: Number of samples in the dataset.
        batch_size: Lower bound for the number of indices per replica.
        num_replicas: Number of replicas; ``None`` means 1.
        rank: Rank of this replica; ``None`` means 0. Taken modulo
            ``num_replicas``.
        shuffle: Whether to permute the indices, seeded by the current epoch.
    """

    def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            num_replicas = 1
        if rank is None:
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank % num_replicas
        self.epoch = 0
        self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch.
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
        else:
            indices = list(range(self.dataset_size))

        # Add extra samples to make it evenly divisible. The shortfall can be
        # larger than the dataset itself (small dataset, large batch_size), so
        # repeat the whole index list as often as needed instead of slicing
        # it once.
        if len(indices) < self.total_size:
            if not indices:
                raise ValueError("Cannot pad an empty dataset to the required number of samples.")
            repeats = -(-self.total_size // len(indices))  # ceiling division
            indices = (indices * repeats)[:self.total_size]

        # Subsample: total_size is a multiple of num_replicas and rank is
        # below num_replicas, so this slice holds exactly num_samples indices.
        return iter(indices[self.rank:self.total_size:self.num_replicas])

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used as the shuffle seed."""
        self.epoch = epoch
| def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0, | |||||
| if not os.path.isfile(anno_path): | |||||
| raise RuntimeError("Annotation file is not valid.") | |||||
| with open(anno_path, "rb") as f: | |||||
| lines = f.readlines() | |||||
| for line in lines: | |||||
| line_str = line.decode("utf-8").strip() | |||||
| line_split = str(line_str).split(' ') | |||||
| file_name = line_split[0] | |||||
| if os.path.isfile(os.path.join(image_dir, file_name)): | |||||
| image_anno_dict[file_name] = anno_parser(line_split[1:]) | |||||
| image_files.append(file_name) | |||||
| return image_files, image_anno_dict | |||||
def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
    """Write the valid images and their annotations into MindRecord files.

    Args:
        image_dir: Directory that the image names are joined to.
        anno_path: Annotation file passed to ``filter_valid_data``.
        mindrecord_dir: Directory that ``prefix`` is joined to for the output.
        prefix: File name prefix handed to ``FileWriter``.
        file_num: Second argument handed to ``FileWriter``.
    """
    writer = FileWriter(os.path.join(mindrecord_dir, prefix), file_num)
    valid_names, annos_by_name = filter_valid_data(image_dir, anno_path)

    # Images are stored as their raw file bytes, annotations as rows of 5 ints.
    schema = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int64", "shape": [-1, 5]},
    }
    writer.add_schema(schema, "yolo_json")

    for name in valid_names:
        with open(os.path.join(image_dir, name), 'rb') as image_file:
            encoded = image_file.read()
        writer.write_raw_data([{"image": encoded, "annotation": np.array(annos_by_name[name])}])
    writer.commit()
| def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0, | |||||
| is_training=True, num_parallel_workers=8): | is_training=True, num_parallel_workers=8): | ||||
| """Creatr YOLOv3 dataset with GeneratorDataset.""" | |||||
| yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path) | |||||
| distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank) | |||||
| ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler) | |||||
| ds.set_dataset_size(len(distributed_sampler)) | |||||
| """Create YOLOv3 dataset with MindDataset.""" | |||||
| ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank, | |||||
| num_parallel_workers=num_parallel_workers, shuffle=is_training) | |||||
| decode = C.Decode() | |||||
| ds = ds.map(input_columns=["image"], operations=decode) | |||||
| compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | ||||
| hwc_to_chw = P.HWC2CHW() | |||||
| ds = ds.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], | |||||
| columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], | |||||
| operations=compose_map_func, num_parallel_workers=num_parallel_workers) | |||||
| ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) | |||||
| ds = ds.shuffle(buffer_size=256) | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_num) | |||||
| if is_training: | |||||
| hwc_to_chw = P.HWC2CHW() | |||||
| ds = ds.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], | |||||
| columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"], | |||||
| operations=compose_map_func, num_parallel_workers=num_parallel_workers) | |||||
| ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers) | |||||
| ds = ds.shuffle(buffer_size=256) | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_num) | |||||
| else: | |||||
| ds = ds.map(input_columns=["image", "annotation"], | |||||
| output_columns=["image", "image_shape", "annotation"], | |||||
| columns_order=["image", "image_shape", "annotation"], | |||||
| operations=compose_map_func, num_parallel_workers=num_parallel_workers) | |||||
| return ds | return ds | ||||
| @@ -0,0 +1,107 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Evaluation for yolo_v3""" | |||||
| import os | |||||
| import argparse | |||||
| import time | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval | |||||
| from dataset import create_yolo_dataset, data_to_mindrecord_byte_image | |||||
| from config import ConfigYOLOV3ResNet18 | |||||
| from util import metrics | |||||
def yolo_eval(dataset_path, ckpt_path):
    """Evaluate YOLOv3 on a dataset and print per-class precision and recall.

    Args:
        dataset_path: Path handed to ``create_yolo_dataset`` with
            ``is_training=False``.
        ckpt_path: Checkpoint file whose parameters are loaded into the network.
    """
    ds = create_yolo_dataset(dataset_path, is_training=False)
    config = ConfigYOLOV3ResNet18()
    net = yolov3_resnet18(config)
    eval_net = YoloWithEval(net, config)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)

    # Evaluation mode is set once; nothing in the loop below changes it.
    eval_net.set_train(False)
    total = ds.get_dataset_size()
    start = time.time()
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for step, data in enumerate(ds.create_dict_iterator(), start=1):
        img_np = data['image']
        image_shape = data['image_shape']
        annotation = data['annotation']

        output = eval_net(Tensor(img_np), Tensor(image_shape))
        # Convert each output once per iteration rather than once per sample.
        boxes = output[0].asnumpy()
        box_scores = output[1].asnumpy()
        # NOTE(review): the iteration's whole annotation array is attached to
        # every prediction; presumably each iteration carries a single image -
        # confirm against the dataset pipeline.
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": boxes[batch_idx],
                              "box_scores": box_scores[batch_idx],
                              "annotation": annotation})

        percent = round(step / total * 100, 2)
        print(' %s [%d/%d]' % (str(percent) + '%', step, total), end='\r')
    print(' %s [%d/%d] cost %d ms' % (str(100.0) + '%', total, total, int((time.time() - start) * 1000)), end='\n')

    precisions, recalls = metrics(pred_data)
    print("\n========================================\n")
    for class_idx in range(config.num_classes):
        print("class {} precision is {:.2f}%, recall is {:.2f}%".format(
            class_idx, precisions[class_idx] * 100, recalls[class_idx] * 100))
if __name__ == '__main__':
    # Command-line entry point: build the MindRecord files if they are missing,
    # then run the evaluation on them.
    parser = argparse.ArgumentParser(description='Yolov3 evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_eval",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it will generate mindrecord file "
                             "by image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use "
                             "mindrecord_dir rather than image_dir and anno_path. Default is ./Mindrecord_eval")
    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
                                                                  "the absolute image path is joined by the image_dir "
                                                                  "and the relative path in anno_path.")
    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
                        enable_auto_mixed_precision=False)

    # It will generate mindrecord file in args_opt.mindrecord_dir,
    # and the file name is yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(args_opt.mindrecord_dir):
        os.makedirs(args_opt.mindrecord_dir)

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not (os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path)):
            # Without a MindRecord file and without valid inputs to build one
            # there is nothing to evaluate, so stop with a clear error instead
            # of running the evaluation on a file that does not exist.
            raise ValueError("image_dir or anno_path does not exist")
        print("Create Mindrecord")
        data_to_mindrecord_byte_image(args_opt.image_dir,
                                      args_opt.anno_path,
                                      args_opt.mindrecord_dir,
                                      prefix=prefix,
                                      file_num=8)
        print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))

    print("Start Eval!")
    yolo_eval(mindrecord_file, args_opt.ckpt_path)
| @@ -14,17 +14,26 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | |||||
| echo "Please run the scipt as: " | echo "Please run the scipt as: " | ||||
| echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH" | |||||
| echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json" | |||||
| echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" | |||||
| echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH" | |||||
| echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json" | |||||
| echo "It is better to use absolute path." | |||||
| echo "==============================================================================================================" | |||||
| export RANK_SIZE=$1 | |||||
| EPOCH_SIZE=$2 | EPOCH_SIZE=$2 | ||||
| IMAGE_DIR=$3 | |||||
| ANNO_PATH=$4 | |||||
| export MINDSPORE_HCCL_CONFIG_PATH=$5 | |||||
| MINDRECORD_DIR=$3 | |||||
| IMAGE_DIR=$4 | |||||
| ANNO_PATH=$5 | |||||
| # Before start distribute train, first create mindrecord files. | |||||
| python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \ | |||||
| --anno_path=$ANNO_PATH | |||||
| echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt" | |||||
| export MINDSPORE_HCCL_CONFIG_PATH=$6 | |||||
| export RANK_SIZE=$1 | |||||
| for((i=0;i<RANK_SIZE;i++)) | for((i=0;i<RANK_SIZE;i++)) | ||||
| do | do | ||||
| @@ -40,6 +49,7 @@ do | |||||
| --distribute=1 \ | --distribute=1 \ | ||||
| --device_num=$RANK_SIZE \ | --device_num=$RANK_SIZE \ | ||||
| --device_id=$DEVICE_ID \ | --device_id=$DEVICE_ID \ | ||||
| --mindrecord_dir=$MINDRECORD_DIR \ | |||||
| --image_dir=$IMAGE_DIR \ | --image_dir=$IMAGE_DIR \ | ||||
| --epoch_size=$EPOCH_SIZE \ | --epoch_size=$EPOCH_SIZE \ | ||||
| --anno_path=$ANNO_PATH > log.txt 2>&1 & | --anno_path=$ANNO_PATH > log.txt 2>&1 & | ||||
| @@ -0,0 +1,23 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Print the usage banner, then run the evaluation with the five positional
# arguments: DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH.
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
echo "=============================================================================================================="

# Quote every expansion so paths containing spaces are passed as one argument.
python eval.py --device_id="$1" --ckpt_path="$2" --mindrecord_dir="$3" --image_dir="$4" --anno_path="$5"
| @@ -14,8 +14,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | |||||
| echo "Please run the scipt as: " | echo "Please run the scipt as: " | ||||
| echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH" | |||||
| echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt" | |||||
| echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH" | |||||
| echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt" | |||||
| echo "==============================================================================================================" | |||||
| python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4 | |||||
| python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5 | |||||
| @@ -16,26 +16,30 @@ | |||||
| """ | """ | ||||
| ######################## train YOLOv3 example ######################## | ######################## train YOLOv3 example ######################## | ||||
| train YOLOv3 and get network model files(.ckpt) : | train YOLOv3 and get network model files(.ckpt) : | ||||
| python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt | |||||
| python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train | |||||
| If the mindrecord_dir is empty, it will generate mindrecord file by image_dir and anno_path. | |||||
| Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path. | |||||
| """ | """ | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context, Tensor | from mindspore import context, Tensor | ||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.communication.management import init | from mindspore.communication.management import init | ||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor | from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor | ||||
| from mindspore.train import Model, ParallelMode | from mindspore.train import Model, ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper | from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper | ||||
| from dataset import create_yolo_dataset | |||||
| from dataset import create_yolo_dataset, data_to_mindrecord_byte_image | |||||
| from config import ConfigYOLOV3ResNet18 | from config import ConfigYOLOV3ResNet18 | ||||
| def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): | def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): | ||||
| """Set learning rate""" | |||||
| """Set learning rate.""" | |||||
| lr_each_step = [] | lr_each_step = [] | ||||
| lr = learning_rate | lr = learning_rate | ||||
| for i in range(global_step): | for i in range(global_step): | ||||
| @@ -57,7 +61,9 @@ def init_net_param(net, init='ones'): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser(description="YOLOv3") | |||||
| parser = argparse.ArgumentParser(description="YOLOv3 train") | |||||
| parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " | |||||
| "Mindrecord, default is false.") | |||||
| parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") | parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") | ||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | ||||
| parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") | parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") | ||||
| @@ -67,12 +73,19 @@ if __name__ == '__main__': | |||||
| parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") | parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") | ||||
| parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") | parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") | ||||
| parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") | parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") | ||||
| parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.") | |||||
| parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.") | |||||
| parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", | |||||
| help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by" | |||||
| "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir " | |||||
| "rather than image_dir and anno_path. Default is ./Mindrecord_train") | |||||
| parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, " | |||||
| "the absolute image path is joined by the image_dir " | |||||
| "and the relative path in anno_path") | |||||
| parser.add_argument("--anno_path", type=str, default="", help="Annotation path.") | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | ||||
| context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True) | |||||
| context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True, | |||||
| enable_auto_mixed_precision=False) | |||||
| if args_opt.distribute: | if args_opt.distribute: | ||||
| device_num = args_opt.device_num | device_num = args_opt.device_num | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| @@ -80,36 +93,65 @@ if __name__ == '__main__': | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, | ||||
| device_num=device_num) | device_num=device_num) | ||||
| init() | init() | ||||
| rank = args_opt.device_id | |||||
| rank = args_opt.device_id % device_num | |||||
| else: | else: | ||||
| context.set_context(enable_hccl=False) | context.set_context(enable_hccl=False) | ||||
| rank = 0 | rank = 0 | ||||
| device_num = 1 | device_num = 1 | ||||
| loss_scale = float(args_opt.loss_scale) | |||||
| dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size, | |||||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||||
| dataset_size = dataset.get_dataset_size() | |||||
| net = yolov3_resnet18(ConfigYOLOV3ResNet18()) | |||||
| net = YoloWithLossCell(net, ConfigYOLOV3ResNet18()) | |||||
| init_net_param(net, "XavierUniform") | |||||
| # checkpoint | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config) | |||||
| if args_opt.checkpoint_path != "": | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size, | |||||
| decay_step=1000, decay_rate=0.95)) | |||||
| opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) | |||||
| net = TrainingWrapper(net, opt, loss_scale) | |||||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||||
| model = Model(net) | |||||
| dataset_sink_mode = False | |||||
| if args_opt.mode == "graph": | |||||
| dataset_sink_mode = True | |||||
| print("Start train YOLOv3.") | |||||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||||
| print("Start create dataset!") | |||||
| # It will generate mindrecord file in args_opt.mindrecord_dir, | |||||
| # and the file name is yolo.mindrecord0, 1, ... file_num. | |||||
| if not os.path.isdir(args_opt.mindrecord_dir): | |||||
| os.makedirs(args_opt.mindrecord_dir) | |||||
| prefix = "yolo.mindrecord" | |||||
| mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0") | |||||
| if not os.path.exists(mindrecord_file): | |||||
| if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path): | |||||
| print("Create Mindrecord.") | |||||
| data_to_mindrecord_byte_image(args_opt.image_dir, | |||||
| args_opt.anno_path, | |||||
| args_opt.mindrecord_dir, | |||||
| prefix=prefix, | |||||
| file_num=8) | |||||
| print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir)) | |||||
| else: | |||||
| print("image_dir or anno_path not exits.") | |||||
| if not args_opt.only_create_dataset: | |||||
| loss_scale = float(args_opt.loss_scale) | |||||
| # When creating the MindDataset, use the first mindrecord file, such as yolo.mindrecord0. | |||||
| dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size, | |||||
| batch_size=args_opt.batch_size, device_num=device_num, rank=rank) | |||||
| dataset_size = dataset.get_dataset_size() | |||||
| print("Create dataset done!") | |||||
| net = yolov3_resnet18(ConfigYOLOV3ResNet18()) | |||||
| net = YoloWithLossCell(net, ConfigYOLOV3ResNet18()) | |||||
| init_net_param(net, "XavierUniform") | |||||
| # checkpoint | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config) | |||||
| lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size, | |||||
| decay_step=1000, decay_rate=0.95)) | |||||
| opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) | |||||
| net = TrainingWrapper(net, opt, loss_scale) | |||||
| if args_opt.checkpoint_path != "": | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||||
| model = Model(net) | |||||
| dataset_sink_mode = False | |||||
| if args_opt.mode == "graph": | |||||
| print("In graph mode, one epoch return a loss.") | |||||
| dataset_sink_mode = True | |||||
| print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.") | |||||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | |||||
| @@ -0,0 +1,146 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """metrics utils""" | |||||
| import numpy as np | |||||
| from config import ConfigYOLOV3ResNet18 | |||||
def calc_iou(bbox_pred, bbox_ground):
    """Calculate iou of predicted bbox and ground truth.

    Both boxes are read as ``(x_min, y_min, x_max, y_max)`` from their first
    four items. Any further items are ignored, so an annotation that carries
    a class id at index 4 can be passed directly.

    Args:
        bbox_pred: Indexable with at least four numeric items.
        bbox_ground: Indexable with at least four numeric items.

    Returns:
        Intersection area divided by union area as a floating point value,
        or ``0.0`` when the boxes do not overlap with positive area
        (boxes that merely touch along an edge count as non-overlapping).
    """
    # Overlap extent per axis: nearest far edge minus farthest near edge.
    inter_w = min(bbox_pred[2], bbox_ground[2]) - max(bbox_pred[0], bbox_ground[0])
    inter_h = min(bbox_pred[3], bbox_ground[3]) - max(bbox_pred[1], bbox_ground[1])
    if inter_w <= 0 or inter_h <= 0:
        # Return a float here too so callers never see an int/float mix.
        return 0.0

    inter = inter_w * inter_h
    area_pred = (bbox_pred[2] - bbox_pred[0]) * (bbox_pred[3] - bbox_pred[1])
    area_ground = (bbox_ground[2] - bbox_ground[0]) * (bbox_ground[3] - bbox_ground[1])
    # A positive overlap on both axes implies both boxes have positive area,
    # so the union below is strictly positive.
    return inter * 1. / (area_pred + area_ground - inter)
def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes.

    Boxes are visited from highest to lowest score. Each visited box is kept
    and every remaining box whose overlap ratio with it exceeds ``thres`` is
    dropped. Box extents are measured with a ``+ 1`` offset on each axis.

    Args:
        all_boxes: 2-D numpy array; columns 0/1 hold the minimum corner and
            columns 2/3 the maximum corner of each box.
        all_scores: 1-D numpy array with one score per row of ``all_boxes``.
        thres: Overlap ratio above which a lower-scored box is suppressed.
        max_boxes: Maximum number of boxes to keep.

    Returns:
        list of indices into ``all_boxes`` (numpy integers), ordered by
        descending score. Empty when there are no boxes or when
        ``max_boxes`` is not positive.
    """
    # Without this guard the first box would be kept before the limit is
    # ever checked, returning one box even for max_boxes == 0.
    if max_boxes <= 0:
        return []

    x1 = all_boxes[:, 0]
    y1 = all_boxes[:, 1]
    x2 = all_boxes[:, 2]
    y2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        if len(keep) >= max_boxes:
            break

        # Candidates still alive, excluding the box just kept.
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # Only candidates that overlap the kept box weakly enough survive.
        order = rest[ovr <= thres]
    return keep
def metrics(pred_data):
    """Calculate precision and recall of predicted bboxes.

    For every sample, scores below ``config.obj_threshold`` are masked out,
    NMS is run per class, and each surviving box is counted as correct when
    a ground-truth annotation of the same class has IoU >= 0.5 with it.

    Args:
        pred_data: Iterable of dicts with the keys
            ``"annotation"`` (iterable of ground-truth items; the first four
            entries are passed to ``calc_iou`` as the box and index 4 is used
            as the class index),
            ``"boxes"`` (array that is reshaped to ``(-1, 4)``) and
            ``"box_scores"`` (2-D array indexed as ``[:, class]``).

    Returns:
        Tuple ``(precision, recall)`` of numpy arrays with one entry per
        class (``config.num_classes`` entries each).
    """
    config = ConfigYOLOV3ResNet18()
    num_classes = config.num_classes
    max_boxes = config.nms_max_num

    # Counters start at a small epsilon so the final divisions never hit 0.
    count_corrects = [1e-6 for _ in range(num_classes)]
    count_grounds = [1e-6 for _ in range(num_classes)]
    count_preds = [1e-6 for _ in range(num_classes)]

    for sample in pred_data:
        gt_anno = sample["annotation"]
        box_scores = sample['box_scores']
        # The flattened box view does not depend on the class, build it once.
        flat_boxes = np.reshape(sample['boxes'], [-1, 4])
        mask = box_scores >= config.obj_threshold

        boxes_ = []
        classes_ = []
        for c in range(num_classes):
            class_mask = np.reshape(mask[:, c], [-1])
            class_boxes = flat_boxes[class_mask]
            class_box_scores = np.reshape(box_scores[:, c], [-1])[class_mask]
            nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
            class_boxes = class_boxes[nms_index]
            class_box_scores = class_box_scores[nms_index]
            boxes_.append(class_boxes)
            classes_.append(np.ones_like(class_box_scores, 'int32') * c)

        boxes = np.concatenate(boxes_, axis=0)
        classes = np.concatenate(classes_, axis=0)

        # metric
        count_correct = [1e-6 for _ in range(num_classes)]
        count_ground = [1e-6 for _ in range(num_classes)]
        count_pred = [1e-6 for _ in range(num_classes)]

        for anno in gt_anno:
            count_ground[anno[4]] += 1

        for box_index, box in enumerate(boxes):
            # Swap each coordinate pair before comparing with the annotation.
            # presumably boxes are (y_min, x_min, y_max, x_max) while
            # annotations are x-first - verify against the producer.
            bbox_pred = [box[1], box[0], box[3], box[2]]
            count_pred[classes[box_index]] += 1

            # NOTE(review): a ground-truth box is not marked as used after a
            # match, so several predictions may be credited for the same
            # annotation - confirm whether that is intended.
            for anno in gt_anno:
                class_ground = anno[4]
                if classes[box_index] == class_ground:
                    iou = calc_iou(bbox_pred, anno)
                    if iou >= 0.5:
                        count_correct[class_ground] += 1
                        break

        count_corrects = [count_corrects[k] + count_correct[k] for k in range(num_classes)]
        count_preds = [count_preds[k] + count_pred[k] for k in range(num_classes)]
        count_grounds = [count_grounds[k] + count_ground[k] for k in range(num_classes)]

    precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
    recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
    return precision, recall
| @@ -34,6 +34,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"tensor_add", "add"}, | {"tensor_add", "add"}, | ||||
| {"reduce_mean", "reduce_mean_d"}, | {"reduce_mean", "reduce_mean_d"}, | ||||
| {"reduce_max", "reduce_max_d"}, | {"reduce_max", "reduce_max_d"}, | ||||
| {"reduce_min", "reduce_min_d"}, | |||||
| {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, | {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, | ||||
| {"conv2d_backprop_input", "conv2d_backprop_input_d"}, | {"conv2d_backprop_input", "conv2d_backprop_input_d"}, | ||||
| {"top_kv2", "top_k"}, | {"top_kv2", "top_k"}, | ||||
| @@ -15,6 +15,7 @@ | |||||
| """YOLOv3 based on ResNet18.""" | """YOLOv3 based on ResNet18.""" | ||||
| import numpy as np | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context, Tensor | from mindspore import context, Tensor | ||||
| @@ -31,19 +32,14 @@ def weight_variable(): | |||||
| return TruncatedNormal(0.02) | return TruncatedNormal(0.02) | ||||
| class _conv_with_pad(nn.Cell): | |||||
| class _conv2d(nn.Cell): | |||||
| """Create Conv2D with padding.""" | """Create Conv2D with padding.""" | ||||
| def __init__(self, in_channels, out_channels, kernel_size, stride=1): | def __init__(self, in_channels, out_channels, kernel_size, stride=1): | ||||
| super(_conv_with_pad, self).__init__() | |||||
| total_pad = kernel_size - 1 | |||||
| pad_begin = total_pad // 2 | |||||
| pad_end = total_pad - pad_begin | |||||
| self.pad = P.Pad(((0, 0), (0, 0), (pad_begin, pad_end), (pad_begin, pad_end))) | |||||
| super(_conv2d, self).__init__() | |||||
| self.conv = nn.Conv2d(in_channels, out_channels, | self.conv = nn.Conv2d(in_channels, out_channels, | ||||
| kernel_size=kernel_size, stride=stride, padding=0, pad_mode='valid', | |||||
| kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same', | |||||
| weight_init=weight_variable()) | weight_init=weight_variable()) | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.pad(x) | |||||
| x = self.conv(x) | x = self.conv(x) | ||||
| return x | return x | ||||
| @@ -101,15 +97,15 @@ class BasicBlock(nn.Cell): | |||||
| momentum=0.99): | momentum=0.99): | ||||
| super(BasicBlock, self).__init__() | super(BasicBlock, self).__init__() | ||||
| self.conv1 = _conv_with_pad(in_channels, out_channels, 3, stride=stride) | |||||
| self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride) | |||||
| self.bn1 = _fused_bn(out_channels, momentum=momentum) | self.bn1 = _fused_bn(out_channels, momentum=momentum) | ||||
| self.conv2 = _conv_with_pad(out_channels, out_channels, 3) | |||||
| self.conv2 = _conv2d(out_channels, out_channels, 3) | |||||
| self.bn2 = _fused_bn(out_channels, momentum=momentum) | self.bn2 = _fused_bn(out_channels, momentum=momentum) | ||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| self.down_sample_layer = None | self.down_sample_layer = None | ||||
| self.downsample = (in_channels != out_channels) | self.downsample = (in_channels != out_channels) | ||||
| if self.downsample: | if self.downsample: | ||||
| self.down_sample_layer = _conv_with_pad(in_channels, out_channels, 1, stride=stride) | |||||
| self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride) | |||||
| self.add = P.TensorAdd() | self.add = P.TensorAdd() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -166,7 +162,7 @@ class ResNet(nn.Cell): | |||||
| raise ValueError("the length of " | raise ValueError("the length of " | ||||
| "layer_num, inchannel, outchannel list must be 4!") | "layer_num, inchannel, outchannel list must be 4!") | ||||
| self.conv1 = _conv_with_pad(3, 64, 7, stride=2) | |||||
| self.conv1 = _conv2d(3, 64, 7, stride=2) | |||||
| self.bn1 = _fused_bn(64) | self.bn1 = _fused_bn(64) | ||||
| self.relu = P.ReLU() | self.relu = P.ReLU() | ||||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') | ||||
| @@ -452,7 +448,7 @@ class DetectionBlock(nn.Cell): | |||||
| if self.training: | if self.training: | ||||
| return grid, prediction, box_xy, box_wh | return grid, prediction, box_xy, box_wh | ||||
| return self.concat((box_xy, box_wh, box_confidence, box_probs)) | |||||
| return box_xy, box_wh, box_confidence, box_probs | |||||
| class Iou(nn.Cell): | class Iou(nn.Cell): | ||||
| @@ -675,3 +671,78 @@ class TrainingWrapper(nn.Cell): | |||||
| # apply grad reducer on grads | # apply grad reducer on grads | ||||
| grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
| return F.depend(loss, self.optimizer(grads)) | return F.depend(loss, self.optimizer(grads)) | ||||
class YoloBoxScores(nn.Cell):
    """
    Calculate the boxes of the original picture size and the score of each box.

    Args:
        config (Class): YOLOv3 config. ``config.img_shape`` and
            ``config.num_classes`` are read.

    Inputs:
        - **box_xy** (Tensor) - 5-D tensor whose last axis holds (x, y) box centers.
        - **box_wh** (Tensor) - 5-D tensor whose last axis holds (w, h) box sizes.
        - **box_confidence** (Tensor) - Objectness score of each box.
        - **box_probs** (Tensor) - Class probabilities of each box.
        - **image_shape** (Tensor) - Size of the original picture. Presumably
          ordered (height, width) like ``config.img_shape`` - confirm with the caller.

    Returns:
        Tensor, the boxes of the original picture size, reshaped to
        (batch_size, -1, 4) with the last axis ordered (y_min, x_min, y_max, x_max).
        Tensor, the score of each box, reshaped to (batch_size, -1, num_classes).
    """
    def __init__(self, config):
        super(YoloBoxScores, self).__init__()
        self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
        self.num_classes = config.num_classes
        # Build the primitives once here, like the other cells in this file,
        # instead of instantiating them inside construct.
        self.concat = P.Concat(-1)
        self.round = P.Round()
        self.reduce_min = P.ReduceMin()
        self.tile = P.Tile()

    def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
        batch_size = F.shape(box_xy)[0]

        # Reorder the last axis from (x, y) / (w, h) to (y, x) / (h, w).
        x = box_xy[:, :, :, :, 0:1]
        y = box_xy[:, :, :, :, 1:2]
        box_yx = self.concat((y, x))
        w = box_wh[:, :, :, :, 0:1]
        h = box_wh[:, :, :, :, 1:2]
        box_hw = self.concat((h, w))

        # Size of the picture after scaling both sides by the same (smaller)
        # ratio; the rest of the network input is treated as centered padding
        # whose offset is removed before rescaling.
        new_shape = self.round(image_shape * self.reduce_min(self.input_shape / image_shape))
        offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
        scale = self.input_shape / new_shape
        box_yx = (box_yx - offset) * scale
        box_hw = box_hw * scale

        # Center/size -> corner representation.
        box_min = box_yx - box_hw / 2.0
        box_max = box_yx + box_hw / 2.0
        boxes = self.concat((box_min[:, :, :, :, 0:1],
                             box_min[:, :, :, :, 1:2],
                             box_max[:, :, :, :, 0:1],
                             box_max[:, :, :, :, 1:2]))

        # Repeat image_shape so it lines up with the four box coordinates.
        image_scale = self.tile(image_shape, (1, 2))
        boxes = boxes * image_scale
        boxes = F.reshape(boxes, (batch_size, -1, 4))

        boxes_scores = box_confidence * box_probs
        boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
        return boxes, boxes_scores
class YoloWithEval(nn.Cell):
    """
    Encapsulation class of YOLOv3 evaluation.

    Runs the wrapped network, converts each of its three outputs to boxes and
    scores with a dedicated YoloBoxScores cell, and concatenates the results
    along axis 1.

    Args:
        network (Cell): The training network. Note that loss function and optimizer must not be added.
        config (Class): YOLOv3 config.

    Inputs:
        - **x** (Tensor) - Input passed unchanged to the wrapped network.
        - **image_shape** (Tensor) - The original picture size, forwarded to
          every YoloBoxScores cell and returned as-is.

    Returns:
        Tensor, the boxes of the original picture size.
        Tensor, the score of each box.
        Tensor, the original picture size.
    """
    def __init__(self, network, config):
        super(YoloWithEval, self).__init__()
        self.yolo_network = network
        # One box/score converter per network output.
        self.box_score_0 = YoloBoxScores(config)
        self.box_score_1 = YoloBoxScores(config)
        self.box_score_2 = YoloBoxScores(config)

    def construct(self, x, image_shape):
        yolo_output = self.yolo_network(x)
        # Each yolo_output[k] is unpacked into the four leading arguments of
        # YoloBoxScores.construct (box_xy, box_wh, box_confidence, box_probs).
        boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
        boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
        boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
        # Merge the three outputs along the box axis.
        boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
        boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
        return boxes, boxes_scores, image_shape
| @@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore._checkparam import ParamValidator as validator | from mindspore._checkparam import ParamValidator as validator | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from .optimizer import Optimizer, grad_scale | |||||
| from mindspore.common import Tensor | |||||
| from .optimizer import Optimizer, grad_scale, apply_decay | |||||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | ||||
| centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | ||||
| @@ -118,6 +119,9 @@ class RMSProp(Optimizer): | |||||
| use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False. | use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False. | ||||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False | ||||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | loss_scale (float): A floating point value for the loss scale. Default: 1.0. | ||||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | |||||
| Inputs: | Inputs: | ||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | ||||
| @@ -132,7 +136,8 @@ class RMSProp(Optimizer): | |||||
| >>> model = Model(net, loss, opt) | >>> model = Model(net, loss, opt) | ||||
| """ | """ | ||||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | ||||
| use_locking=False, centered=False, loss_scale=1.0): | |||||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | |||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||||
| super(RMSProp, self).__init__(learning_rate, params) | super(RMSProp, self).__init__(learning_rate, params) | ||||
| if isinstance(momentum, float) and momentum < 0.0: | if isinstance(momentum, float) and momentum < 0.0: | ||||
| @@ -159,6 +164,7 @@ class RMSProp(Optimizer): | |||||
| self.assignadd = P.AssignAdd() | self.assignadd = P.AssignAdd() | ||||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | ||||
| self.axis = 0 | self.axis = 0 | ||||
| self.one = Tensor(1, mstype.int32) | |||||
| self.momentum = momentum | self.momentum = momentum | ||||
| @@ -167,10 +173,14 @@ class RMSProp(Optimizer): | |||||
| self.hyper_map = C.HyperMap() | self.hyper_map = C.HyperMap() | ||||
| self.decay = decay | self.decay = decay | ||||
| self.decay_tf = tuple(decay_filter(x) for x in self.parameters) | |||||
| self.reciprocal_scale = 1.0 / loss_scale | self.reciprocal_scale = 1.0 / loss_scale | ||||
| self.weight_decay = weight_decay * loss_scale | |||||
| def construct(self, gradients): | def construct(self, gradients): | ||||
| params = self.parameters | params = self.parameters | ||||
| if self.weight_decay > 0: | |||||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients) | |||||
| if self.reciprocal_scale != 1.0: | if self.reciprocal_scale != 1.0: | ||||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | ||||
| if self.dynamic_lr: | if self.dynamic_lr: | ||||
| @@ -85,7 +85,9 @@ from .logical_and import _logical_and_tbe | |||||
| from .logical_not import _logical_not_tbe | from .logical_not import _logical_not_tbe | ||||
| from .logical_or import _logical_or_tbe | from .logical_or import _logical_or_tbe | ||||
| from .reduce_max import _reduce_max_tbe | from .reduce_max import _reduce_max_tbe | ||||
| from .reduce_min import _reduce_min_tbe | |||||
| from .reduce_sum import _reduce_sum_tbe | from .reduce_sum import _reduce_sum_tbe | ||||
| from .round import _round_tbe | |||||
| from .tanh import _tanh_tbe | from .tanh import _tanh_tbe | ||||
| from .tanh_grad import _tanh_grad_tbe | from .tanh_grad import _tanh_grad_tbe | ||||
| from .softmax import _softmax_tbe | from .softmax import _softmax_tbe | ||||
| @@ -0,0 +1,76 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ReduceMin op""" | |||||
| from mindspore.ops.op_info_register import op_info_register | |||||
# The JSON below describes the ReduceMin op for the TBE backend: kernel and
# binary names, the required attributes, and the dtype/format combinations
# accepted for the single input and output (entries pair up by position).
@op_info_register("""{
    "op_name": "ReduceMin",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "reduce_min_d.so",
    "compute_cost": 10,
    "kernel_name": "reduce_min_d",
    "partial_flag": true,
    "attr": [
        {
            "name": "axis",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        },
        {
            "name": "keep_dims",
            "param_type": "required",
            "type": "bool",
            "value": "all"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
            ],
            "format": [
                "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
            ],
            "name": "x",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
            ],
            "format": [
                "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def _reduce_min_tbe():
    """ReduceMin TBE register.

    The whole op description (kernel ``reduce_min_d``, the required ``axis``
    and ``keep_dims`` attributes, one input ``x`` and one output ``y``) is
    carried by the JSON passed to ``op_info_register``. This function only
    serves as the decorator target, so its body intentionally does nothing.
    """
    return
| @@ -0,0 +1,65 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Round op""" | |||||
| from mindspore.ops.op_info_register import op_info_register | |||||
# The JSON below describes the Round op for the TBE backend: kernel and
# binary names, no attributes, and the dtype/format combinations accepted
# for the single input and output (entries pair up by position).
@op_info_register("""{
    "op_name": "Round",
    "imply_type": "TBE",
    "fusion_type": "ELEMWISE",
    "async_flag": false,
    "binfile_name": "round.so",
    "compute_cost": 10,
    "kernel_name": "round",
    "partial_flag": true,
    "attr": [

    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float16", "float", "float", "float"
            ],
            "format": [
                "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
            ],
            "name": "x",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float16", "float", "float", "float"
            ],
            "format": [
                "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def _round_tbe():
    """Round TBE register.

    The whole op description (kernel ``round``, one input ``x`` and one
    output ``y``, no attributes) is carried by the JSON passed to
    ``op_info_register``. This function only serves as the decorator target,
    so its body intentionally does nothing.
    """
    return