GitOrigin-RevId: 44a697c3fe
@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
             else:
                 raise ValueError(
                     "dir does not contain target file\
-%s,please set download=True"
+%s, please set download=True"
                     % (self.target_file)
                 )
@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):
     def untar(self, file_path, dirs):
         assert file_path.endswith(".tar.gz")
-        logger.debug("untar file %s to %s" % (file_path, dirs))
+        logger.debug("untar file %s to %s", file_path, dirs)
         t = tarfile.open(file_path)
         t.extractall(path=dirs)
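Note: the logging changes throughout this commit switch from eager %-formatting to passing the arguments to the logger, which defers string interpolation until a handler actually emits the record. A minimal standalone illustration with placeholder values:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # old style: the message is formatted even though DEBUG is filtered out
    logger.debug("untar file %s to %s" % ("data.tar.gz", "/tmp/out"))
    # new style: formatting is skipped entirely when the record is filtered
    logger.debug("untar file %s to %s", "data.tar.gz", "/tmp/out")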
@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
         label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
                 batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
                 data.extend(list(batch_data[..., [2, 1, 0]]))
                 label.extend(dic[b"labels"])
-        label = np.array(label)
+        label = np.array(label, dtype=np.int32)
         return (data, label)

     def process(self):
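Note: without an explicit dtype, np.array on a list of Python ints picks the platform default (typically int64 on 64-bit systems), so the added dtype=np.int32 pins the label dtype. Quick check:

    import numpy as np

    label = np.array([3, 8, 5])
    print(label.dtype)                          # usually int64
    label = np.array([3, 8, 5], dtype=np.int32)
    print(label.dtype)                          # int32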
@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
         coarse_label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
                 batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
             elif k == "mask":
                 mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 mask = self._trans_mask(mask)
-                mask = mask[:, :, None]
+                mask = mask[:, :, np.newaxis]
                 target.append(mask)
             elif k == "info":
                 if image is None:
@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
             33,
         ]
         label = np.ones(mask.shape) * 255
-        for i in range(len(trans_labels)):
-            label[mask == trans_labels[i]] = i
-        return label.astype("uint8")
+        for i, tl in enumerate(trans_labels):
+            label[mask == tl] = i
+        return label.astype(np.uint8)

     def _get_target_suffix(self, mode, target_type):
         if target_type == "instance":
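Note: the rewritten loop is behaviorally identical; enumerate simply yields index and value together. A small standalone check with a shortened stand-in for the id list above:

    import numpy as np

    trans_labels = [7, 8, 11]               # stand-in for the full Cityscapes id list
    mask = np.array([[7, 8], [11, 0]])
    label = np.ones(mask.shape) * 255
    for i, tl in enumerate(trans_labels):
        label[mask == tl] = i
    print(label.astype(np.uint8))           # [[0 1] [2 255]]; unmapped ids stay 255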
@@ -139,7 +139,7 @@ class COCO(VisionDataset):
                 target.append(image)
             elif k == "boxes":
                 boxes = [obj["bbox"] for obj in anno]
-                boxes = np.array(boxes).reshape(-1, 4)
+                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                 # transfer boxes from xywh to xyxy
                 boxes[:, 2:] += boxes[:, :2]
                 target.append(boxes)
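Note: COCO stores boxes as [x, y, width, height]; the in-place addition in the context lines turns them into [x1, y1, x2, y2]. A tiny check:

    import numpy as np

    boxes = np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32)  # x, y, w, h
    boxes[:, 2:] += boxes[:, :2]
    print(boxes)                            # [[10. 20. 40. 60.]] -> x1, y1, x2, y2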
@@ -148,17 +148,21 @@ class COCO(VisionDataset):
                 boxes_category = [
                     self.json_category_id_to_contiguous_id[c] for c in boxes_category
                 ]
-                boxes_category = np.array(boxes_category)
+                boxes_category = np.array(boxes_category, dtype=np.int32)
                 target.append(boxes_category)
-            # TODO: need to check
-            # elif k == "keypoints":
-            #     keypoints = [obj["keypoints"] for obj in anno]
-            #     keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
-            #     target.append(keypoints)
-            # elif k == "polygons":
-            #     polygons = [obj["segmentation"] for obj in anno]
-            #     polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
-            #     target.append(polygons)
+            elif k == "keypoints":
+                keypoints = [obj["keypoints"] for obj in anno]
+                keypoints = np.array(keypoints, dtype=np.float32).reshape(
+                    -1, len(self.keypoint_names), 3
+                )
+                target.append(keypoints)
+            elif k == "polygons":
+                polygons = [obj["segmentation"] for obj in anno]
+                polygons = [
+                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
+                    for ps in polygons
+                ]
+                target.append(polygons)
             elif k == "info":
                 info = self.imgs[img_id]
                 info = [info["height"], info["width"], info["file_name"]]
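Note: COCO keypoint annotations are flat lists of (x, y, visibility) triplets, so reshaping to (-1, len(self.keypoint_names), 3) gives one row of triplets per keypoint. A sketch with a made-up two-keypoint skeleton (the names are hypothetical, not from this dataset class):

    import numpy as np

    keypoint_names = ("nose", "neck")               # hypothetical skeleton
    anno = [{"keypoints": [5, 9, 2, 7, 11, 1]}]     # x, y, v for each keypoint
    keypoints = [obj["keypoints"] for obj in anno]
    keypoints = np.array(keypoints, dtype=np.float32).reshape(-1, len(keypoint_names), 3)
    print(keypoints.shape)                          # (1, 2, 3)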
@@ -19,6 +19,7 @@ import os
 from typing import Dict, List, Tuple

 import cv2
+import numpy as np

 from .meta_vision import VisionDataset
 from .utils import is_img
@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
     def collect_class(self) -> Dict:
         classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
         classes.sort()
-        return {classes[i]: i for i in range(len(classes))}
+        return {classes[i]: np.int32(i) for i in range(len(classes))}

     def __getitem__(self, index: int) -> Tuple:
         path, label = self.samples[index]
@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
         self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

         if not os.path.exists(self.devkit_dir):
-            logger.warning("devkit directory %s does not exists" % self.devkit_dir)
+            logger.warning("devkit directory %s does not exists", self.devkit_dir)
             self._prepare_devkit()

         self.train = train
@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):
         if not os.path.exists(self.target_folder):
             logger.warning(
-                "expected image folder %s does not exist, try to load from raw file"
-                % self.target_folder
+                "expected image folder %s does not exist, try to load from raw file",
+                self.target_folder,
             )
             if not self.check_raw_file():
                 raise FileNotFoundError(
@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
                 raise RuntimeError(
                     "extracting raw file shouldn't be done in distributed mode, use single process instead"
                 )
+            elif train:
+                self._prepare_train()
             else:
-                self._prepare_train() if train else self._prepare_val()
+                self._prepare_val()

         super().__init__(self.target_folder, **kwargs)
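Note: the old one-liner used a conditional expression purely for its side effect, which is easy to misread; the new explicit branch does the same thing. A standalone illustration with placeholder functions:

    def prepare_train():
        print("prepare train")

    def prepare_val():
        print("prepare val")

    train = True
    prepare_train() if train else prepare_val()   # old: the expression's value is discarded
    if train:                                     # new: the intent is explicit
        prepare_train()
    else:
        prepare_val()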
@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
         try:
             return load(os.path.join(self.devkit_dir, "meta.pkl"))
         except FileNotFoundError:
-            import scipy.io as sio
+            import scipy.io

             meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
             if not os.path.exists(meta_path):
                 raise FileNotFoundError("meta file %s does not exist" % meta_path)
-            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
+            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
             nums_children = list(zip(*meta))[4]
             meta = [
                 meta[idx]
@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
             ]
             idcs, wnids, classes = list(zip(*meta))[:3]
             classes = [tuple(clss.split(", ")) for clss in classes]
-            idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-            wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+            idx_to_wnid = dict(zip(idcs, wnids))
+            wnid_to_classes = dict(zip(wnids, classes))
             logger.info(
                 "saving cached meta file to %s",
                 os.path.join(self.devkit_dir, "meta.pkl"),
@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
         assert not self.train

         raw_filename, checksum = self.raw_file_meta["val"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum valid tar file {} ..".format(raw_file))
+        logger.info("checksum valid tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
-        logger.info("extract valid tar file.. this may take 10-20 minutes")
+        logger.info("extract valid tar file... this may take 10-20 minutes")
         untar(os.path.join(self.root, raw_file), self.target_folder)
         self._organize_val_data()
@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
         assert self.train
         raw_filename, checksum = self.raw_file_meta["train"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum train tar file {} ..".format(raw_file))
+        logger.info("checksum train tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
     def _prepare_devkit(self):
         raw_filename, checksum = self.raw_file_meta["devkit"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum devkit tar file {} ..".format(raw_file))
+        logger.info("checksum devkit tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import gzip
 import os
-import pickle
 import struct
 from typing import Tuple
@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
     """
     md5 for checking raw files
     """
-    train_file = "train.pkl"
-    """
-    default pickle file name of training set and its meta data
-    """
-    test_file = "test.pkl"
-    """
-    default pickle file name of test set and its meta data
-    """

     def __init__(
         self,
@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
         timeout: int = 500,
     ):
         r"""
-        initialization:
-
-        1. check root path and target file (train or test)
-        2. check target file exists
-
-            * if exists:
-                * load pickle file as meta-data and data in MNIST dataset
-            * else:
-                * if download:
-                    a. load all raw datas (both train and test set) by url
-                    b. process raw data ( idx3/idx1 -> dict (meta-data) ,numpy.array (data) )
-                    c. save meta-data and data as pickle file
-                    d. load pickle file as meta-data and data in MNIST dataset
-
         :param root: path for mnist dataset downloading or loading, if ``None``,
             set ``root`` to the ``_default_root``
         :param train: if ``True``, loading trainingset, else loading test set
-        :param download: after checking the target files existence, if target files do not
-            exists and download sets to ``True``, download raw files and process,
-            then load, otherwise raise ValueError, default is True
+        :param download: if raw files do not exists and download sets to ``True``,
+            download raw files and process, otherwise raise ValueError, default is True
         """
         super().__init__(root, order=("image", "image_category"))
@@ -105,29 +77,15 @@ class MNIST(VisionDataset):
         if not os.path.exists(self.root):
             raise ValueError("dir %s does not exist" % self.root)

-        # choose the target pickle file
-        if train:
-            self.target_file = os.path.join(self.root, self.train_file)
+        if self._check_raw_files():
+            self.process(train)
+        elif download:
+            self.download()
+            self.process(train)
         else:
-            self.target_file = os.path.join(self.root, self.test_file)
-
-        # check existence of target pickle file, if exists load the
-        # pickle file no matter what download is set
-        if os.path.exists(self.target_file):
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        elif self._check_raw_files():
-            self.process()
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        else:
-            if download:
-                self.download()
-                self._meta_data, self.arrays = self._load_file(self.target_file)
-            else:
-                raise ValueError(
-                    "dir does not contain target file\
-%s,please set download=True"
-                    % (self.target_file)
-                )
+            raise ValueError(
+                "root does not contain valid raw files, please set download=True"
+            )

     def __getitem__(self, index: int) -> Tuple:
         return tuple(array[index] for array in self.arrays)
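Note: with the pickle cache dropped, construction now goes straight from the raw .gz files: valid raw files are processed in memory, otherwise download=True fetches them first, and anything else raises. A hedged usage sketch (assuming the usual megengine.data.dataset import path; the root path is a placeholder):

    from megengine.data.dataset import MNIST

    # fetches the four raw .gz files into ./mnist if they are missing, then parses
    # them directly into numpy arrays; no .pkl cache is written any more
    train_set = MNIST(root="./mnist", train=True, download=True)
    image, label = train_set[0]
    print(image.shape, label.dtype)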
@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
     def meta(self):
         return self._meta_data

-    def _load_file(self, target_file):
-        with open(target_file, "rb") as f:
-            return pickle.load(f)
-
     def _check_raw_files(self):
         return all(
             [
@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
             load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
-        self.process()

-    def process(self):
+    def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
-        logger.info("process raw data ...")
-        meta_data_images_train, images_train = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[0])
-        )
-        meta_data_labels_train, labels_train = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[1])
-        )
-        meta_data_images_test, images_test = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[2])
-        )
-        meta_data_labels_test, labels_test = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[3])
-        )
-        meta_data_train = {
-            "images": meta_data_images_train,
-            "labels": meta_data_labels_train,
-        }
-        meta_data_test = {
-            "images": meta_data_images_test,
-            "labels": meta_data_labels_test,
+        logger.info("process the raw files of %s set...", "train" if train else "test")
+        if train:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[0])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[1])
+            )
+        else:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[2])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[3])
+            )
+        self._meta_data = {
+            "images": meta_data_images,
+            "labels": meta_data_labels,
         }
-        dataset_train = (images_train, labels_train)
-        dataset_test = (images_test, labels_test)
-        # save both training set and test set as pickle files
-        with open(os.path.join(self.root, self.train_file), "wb") as f:
-            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
-        with open(os.path.join(self.root, self.test_file), "wb") as f:
-            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)
+        self.arrays = (images, labels.astype(np.int32))


 def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
-    logger.debug("parse idx3 file %s ..." % idx3_file)
+    logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
     with gzip.open(idx3_file, "rb") as f:
         bin_data = f.read()
@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):
 def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
-    logger.debug("parse idx1 file %s ..." % idx1_file)
+    logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
     with gzip.open(idx1_file, "rb") as f:
         bin_data = f.read()
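Note: parse_idx3/parse_idx1 read the standard MNIST idx format, whose header is a sequence of big-endian unsigned 32-bit fields (magic and item count, plus rows and cols for idx3) followed by raw unsigned bytes. A minimal sketch of the idx3 header parse, assuming the gz file has already been decompressed into bin_data; the module's own implementation may differ in details:

    import struct
    import numpy as np

    def parse_idx3_sketch(bin_data: bytes):
        # magic number, image count, rows, cols: all big-endian uint32
        magic, num, rows, cols = struct.unpack_from(">IIII", bin_data, 0)
        assert magic == 0x00000803, "not an idx3 image file"
        offset = struct.calcsize(">IIII")
        images = np.frombuffer(bin_data, dtype=np.uint8, offset=offset)
        return {"num": num, "rows": rows, "cols": cols}, images.reshape(num, rows, cols)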
@@ -32,7 +32,7 @@ def load_raw_data_from_url(
 ):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
-        "load_raw_data_from_url: downloading to or using cached %s ..." % cached_file
+        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
     )
     if not os.path.exists(cached_file):
         if is_distributed():
@@ -45,7 +45,7 @@ def load_raw_data_from_url(
     else:
         md5 = calculate_md5(cached_file)
         if target_md5 == md5:
-            logger.debug("%s exists with correct md5: %s" % (filename, target_md5))
+            logger.debug("%s exists with correct md5: %s", filename, target_md5)
         else:
             os.remove(cached_file)
             raise RuntimeError("{} exists but fail to match md5".format(filename))
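Note: calculate_md5 itself is not part of this diff; a typical chunked implementation looks like the sketch below (the name and chunk size are placeholders, not the project's actual helper):

    import hashlib

    def calculate_md5_sketch(path, chunk_size=1024 * 1024):
        md5 = hashlib.md5()
        with open(path, "rb") as f:
            # hash in fixed-size chunks so large tar files never sit fully in memory
            for chunk in iter(lambda: f.read(chunk_size), b""):
                md5.update(chunk)
        return md5.hexdigest()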
@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
                 if "aug" in self.image_set:
                     mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 else:
-                    mask = np.array(cv2.imread(self.masks[index], cv2.IMREAD_COLOR))
+                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                 mask = self._trans_mask(mask)
                 mask = mask[:, :, np.newaxis]
                 target.append(mask)
-            # elif k == "boxes":
-            #     boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
-            #     target.append(boxes)
+            elif k == "boxes":
+                boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
+                target.append(boxes)
             elif k == "info":
                 if image is None:
                     image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
             label[
                 (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
             ] = i
-        return label.astype("uint8")
+        return label.astype(np.uint8)

     def parse_voc_xml(self, node):
         voc_dict = {}
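Note: this last change is purely stylistic; the string alias and the numpy type resolve to the same dtype:

    import numpy as np

    label = np.zeros((2, 2))
    assert label.astype("uint8").dtype == label.astype(np.uint8).dtype  # both uint8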