
fix(mge/data): process mnist without generating new files

GitOrigin-RevId: 44a697c3fe
tags/v0.3.2
Megvii Engine Team, 5 years ago
commit da52256891
8 changed files with 80 additions and 129 deletions

1. python_module/megengine/data/dataset/vision/cifar.py (+5 -5)
2. python_module/megengine/data/dataset/vision/cityscapes.py (+4 -4)
3. python_module/megengine/data/dataset/vision/coco.py (+15 -11)
4. python_module/megengine/data/dataset/vision/folder.py (+2 -1)
5. python_module/megengine/data/dataset/vision/imagenet.py (+14 -12)
6. python_module/megengine/data/dataset/vision/mnist.py (+33 -89)
7. python_module/megengine/data/dataset/vision/utils.py (+2 -2)
8. python_module/megengine/data/dataset/vision/voc.py (+5 -5)

python_module/megengine/data/dataset/vision/cifar.py (+5 -5)

@@ -78,7 +78,7 @@ class CIFAR10(VisionDataset):
         else:
             raise ValueError(
                 "dir does not contain target file\
-%s,please set download=True"
+%s, please set download=True"
                 % (self.target_file)
             )

@@ -108,7 +108,7 @@ class CIFAR10(VisionDataset):

     def untar(self, file_path, dirs):
         assert file_path.endswith(".tar.gz")
-        logger.debug("untar file %s to %s" % (file_path, dirs))
+        logger.debug("untar file %s to %s", file_path, dirs)
         t = tarfile.open(file_path)
         t.extractall(path=dirs)

@@ -117,13 +117,13 @@ class CIFAR10(VisionDataset):
         label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
             data.extend(list(batch_data[..., [2, 1, 0]]))
             label.extend(dic[b"labels"])
-        label = np.array(label)
+        label = np.array(label, dtype=np.int32)
         return (data, label)

     def process(self):

@@ -153,7 +153,7 @@ class CIFAR100(CIFAR10):
         coarse_label = []
         for filename in filenames:
             path = os.path.join(self.root, self.raw_file_dir, filename)
-            logger.debug("unpickle file %s" % path)
+            logger.debug("unpickle file %s", path)
             with open(path, "rb") as fo:
                 dic = pickle.load(fo, encoding="bytes")
             batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))

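A pattern repeated across this commit: logging calls drop eager %-formatting in favor of passing arguments through to the logger, and label arrays get an explicit int32 dtype. A minimal standalone sketch of the logging change (plain stdlib, not MegEngine code):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("example")

    path = "data_batch_1"
    # Before: the message string is built even though DEBUG is disabled.
    logger.debug("unpickle file %s" % path)
    # After: logging interpolates the arguments only if the record is
    # actually emitted, so disabled levels cost almost nothing.
    logger.debug("unpickle file %s", path)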

python_module/megengine/data/dataset/vision/cityscapes.py (+4 -4)

@@ -71,7 +71,7 @@ class Cityscapes(VisionDataset):
             elif k == "mask":
                 mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 mask = self._trans_mask(mask)
-                mask = mask[:, :, None]
+                mask = mask[:, :, np.newaxis]
                 target.append(mask)
             elif k == "info":
                 if image is None:

@@ -109,9 +109,9 @@ class Cityscapes(VisionDataset):
             33,
         ]
         label = np.ones(mask.shape) * 255
-        for i in range(len(trans_labels)):
-            label[mask == trans_labels[i]] = i
-        return label.astype("uint8")
+        for i, tl in enumerate(trans_labels):
+            label[mask == tl] = i
+        return label.astype(np.uint8)

     def _get_target_suffix(self, mode, target_type):
         if target_type == "instance":

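The mask remapping loop is rewritten with enumerate and numpy-native dtypes. A standalone sketch, with the label list truncated for illustration:

    import numpy as np

    trans_labels = [7, 8, 11, 12, 13]               # truncated for illustration
    mask = np.array([[7, 0], [13, 11]], dtype=np.uint8)

    label = np.ones(mask.shape) * 255               # 255 marks "ignore" pixels
    for i, tl in enumerate(trans_labels):           # enumerate replaces range(len(...))
        label[mask == tl] = i
    label = label.astype(np.uint8)                  # np.uint8 replaces the "uint8" string
    label = label[:, :, np.newaxis]                 # HxW -> HxWx1, as in the diff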

python_module/megengine/data/dataset/vision/coco.py (+15 -11)

@@ -139,7 +139,7 @@ class COCO(VisionDataset):
                 target.append(image)
             elif k == "boxes":
                 boxes = [obj["bbox"] for obj in anno]
-                boxes = np.array(boxes).reshape(-1, 4)
+                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                 # transfer boxes from xywh to xyxy
                 boxes[:, 2:] += boxes[:, :2]
                 target.append(boxes)

@@ -148,17 +148,21 @@ class COCO(VisionDataset):
                 boxes_category = [
                     self.json_category_id_to_contiguous_id[c] for c in boxes_category
                 ]
-                boxes_category = np.array(boxes_category)
+                boxes_category = np.array(boxes_category, dtype=np.int32)
                 target.append(boxes_category)
-            # TODO: need to check
-            # elif k == "keypoints":
-            #     keypoints = [obj["keypoints"] for obj in anno]
-            #     keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
-            #     target.append(keypoints)
-            # elif k == "polygons":
-            #     polygons = [obj["segmentation"] for obj in anno]
-            #     polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
-            #     target.append(polygons)
+            elif k == "keypoints":
+                keypoints = [obj["keypoints"] for obj in anno]
+                keypoints = np.array(keypoints, dtype=np.float32).reshape(
+                    -1, len(self.keypoint_names), 3
+                )
+                target.append(keypoints)
+            elif k == "polygons":
+                polygons = [obj["segmentation"] for obj in anno]
+                polygons = [
+                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
+                    for ps in polygons
+                ]
+                target.append(polygons)
             elif k == "info":
                 info = self.imgs[img_id]
                 info = [info["height"], info["width"], info["file_name"]]

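All array targets now get explicit dtypes (float32 for boxes, keypoints, and polygons; int32 for category ids), and the formerly commented-out keypoints/polygons branches become live code. A standalone sketch of the box path, using hypothetical annotation records:

    import numpy as np

    # hypothetical annotation records, for illustration only
    anno = [{"bbox": [10.0, 20.0, 30.0, 40.0]}, {"bbox": [0.0, 0.0, 5.0, 5.0]}]

    boxes = [obj["bbox"] for obj in anno]
    boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
    # transfer boxes from xywh to xyxy: (x2, y2) = (x1 + w, y1 + h)
    boxes[:, 2:] += boxes[:, :2]
    # each row is now (x1, y1, x2, y2) in float32

Pinning dtypes at load time spares downstream transforms from guessing whether boxes arrived as int64 or float64.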

python_module/megengine/data/dataset/vision/folder.py (+2 -1)

@@ -19,6 +19,7 @@ import os
 from typing import Dict, List, Tuple

 import cv2
+import numpy as np

 from .meta_vision import VisionDataset
 from .utils import is_img

@@ -78,7 +79,7 @@ class ImageFolder(VisionDataset):
     def collect_class(self) -> Dict:
         classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
         classes.sort()
-        return {classes[i]: i for i in range(len(classes))}
+        return {classes[i]: np.int32(i) for i in range(len(classes))}

     def __getitem__(self, index: int) -> Tuple:
         path, label = self.samples[index]

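With numpy newly imported, collect_class hands out np.int32 labels rather than Python ints, matching the int32 labels the other datasets now produce. A sketch of the same idea as a free function (the signature is adapted here for illustration):

    import os
    import numpy as np

    def collect_class(root: str) -> dict:
        # map each class directory name to an np.int32 label
        classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        return {name: np.int32(i) for i, name in enumerate(classes)}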

python_module/megengine/data/dataset/vision/imagenet.py (+14 -12)

@@ -93,7 +93,7 @@ class ImageNet(ImageFolder):
         self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

         if not os.path.exists(self.devkit_dir):
-            logger.warning("devkit directory %s does not exists" % self.devkit_dir)
+            logger.warning("devkit directory %s does not exists", self.devkit_dir)
             self._prepare_devkit()

         self.train = train

@@ -105,8 +105,8 @@ class ImageNet(ImageFolder):

         if not os.path.exists(self.target_folder):
             logger.warning(
-                "expected image folder %s does not exist, try to load from raw file"
-                % self.target_folder
+                "expected image folder %s does not exist, try to load from raw file",
+                self.target_folder,
             )
             if not self.check_raw_file():
                 raise FileNotFoundError(

@@ -117,8 +117,10 @@ class ImageNet(ImageFolder):
                 raise RuntimeError(
                     "extracting raw file shouldn't be done in distributed mode, use single process instead"
                 )
-            else:
-                self._prepare_train() if train else self._prepare_val()
+            elif train:
+                self._prepare_train()
+            else:
+                self._prepare_val()

         super().__init__(self.target_folder, **kwargs)

@@ -145,12 +147,12 @@ class ImageNet(ImageFolder):
         try:
             return load(os.path.join(self.devkit_dir, "meta.pkl"))
         except FileNotFoundError:
-            import scipy.io as sio
+            import scipy.io

             meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
             if not os.path.exists(meta_path):
                 raise FileNotFoundError("meta file %s does not exist" % meta_path)
-            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
+            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
             nums_children = list(zip(*meta))[4]
             meta = [
                 meta[idx]

@@ -159,8 +161,8 @@ class ImageNet(ImageFolder):
         ]
         idcs, wnids, classes = list(zip(*meta))[:3]
         classes = [tuple(clss.split(", ")) for clss in classes]
-        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        idx_to_wnid = dict(zip(idcs, wnids))
+        wnid_to_classes = dict(zip(wnids, classes))
         logger.info(
             "saving cached meta file to %s",
             os.path.join(self.devkit_dir, "meta.pkl"),

@@ -208,11 +210,11 @@ class ImageNet(ImageFolder):
         assert not self.train
         raw_filename, checksum = self.raw_file_meta["val"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum valid tar file {} ..".format(raw_file))
+        logger.info("checksum valid tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)
-        logger.info("extract valid tar file.. this may take 10-20 minutes")
+        logger.info("extract valid tar file... this may take 10-20 minutes")
         untar(os.path.join(self.root, raw_file), self.target_folder)
         self._organize_val_data()

@@ -220,7 +222,7 @@ class ImageNet(ImageFolder):
         assert self.train
         raw_filename, checksum = self.raw_file_meta["train"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum train tar file {} ..".format(raw_file))
+        logger.info("checksum train tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)

@@ -238,7 +240,7 @@ class ImageNet(ImageFolder):
     def _prepare_devkit(self):
         raw_filename, checksum = self.raw_file_meta["devkit"]
         raw_file = os.path.join(self.root, raw_filename)
-        logger.info("checksum devkit tar file {} ..".format(raw_file))
+        logger.info("checksum devkit tar file %s ...", raw_file)
         assert (
             calculate_md5(raw_file) == checksum
         ), "checksum mismatch, {} may be damaged".format(raw_file)

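Beyond the logging cleanups, the substantive fix here replaces a conditional expression that was being evaluated purely for its side effects with explicit branching. A self-contained sketch with stand-in functions (the names are placeholders, not the ImageNet API):

    def prepare_train():
        print("preparing the train split")

    def prepare_val():
        print("preparing the val split")

    train = True
    # Before: a conditional expression used as a statement; legal, but it
    # reads as if the resulting value mattered.
    prepare_train() if train else prepare_val()
    # After (as in the diff): plain control flow.
    if train:
        prepare_train()
    else:
        prepare_val()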

python_module/megengine/data/dataset/vision/mnist.py (+33 -89)

@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import gzip
 import os
-import pickle
 import struct
 from typing import Tuple

@@ -48,14 +47,6 @@ class MNIST(VisionDataset):
     """
     md5 for checking raw files
     """
-    train_file = "train.pkl"
-    """
-    default pickle file name of training set and its meta data
-    """
-    test_file = "test.pkl"
-    """
-    default pickle file name of test set and its meta data
-    """

     def __init__(
         self,

@@ -65,30 +56,11 @@ class MNIST(VisionDataset):
         timeout: int = 500,
     ):
         r"""
-        initialization:
-
-        1. check root path and target file (train or test)
-        2. check target file exists
-
-            * if exists:
-
-                * load pickle file as meta-data and data in MNIST dataset
-
-            * else:
-
-                * if download:
-
-                    a. load all raw datas (both train and test set) by url
-                    b. process raw data ( idx3/idx1 -> dict (meta-data) ,numpy.array (data) )
-                    c. save meta-data and data as pickle file
-                    d. load pickle file as meta-data and data in MNIST dataset
-
         :param root: path for mnist dataset downloading or loading, if ``None``,
             set ``root`` to the ``_default_root``
         :param train: if ``True``, loading trainingset, else loading test set
-        :param download: after checking the target files existence, if target files do not
-            exists and download sets to ``True``, download raw files and process,
-            then load, otherwise raise ValueError, default is True
+        :param download: if raw files do not exists and download sets to ``True``,
+            download raw files and process, otherwise raise ValueError, default is True

         """
         super().__init__(root, order=("image", "image_category"))

@@ -105,29 +77,15 @@ class MNIST(VisionDataset):
         if not os.path.exists(self.root):
             raise ValueError("dir %s does not exist" % self.root)

-        # choose the target pickle file
-        if train:
-            self.target_file = os.path.join(self.root, self.train_file)
-        else:
-            self.target_file = os.path.join(self.root, self.test_file)
-
-        # check existence of target pickle file, if exists load the
-        # pickle file no matter what download is set
-        if os.path.exists(self.target_file):
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        elif self._check_raw_files():
-            self.process()
-            self._meta_data, self.arrays = self._load_file(self.target_file)
-        else:
-            if download:
-                self.download()
-                self._meta_data, self.arrays = self._load_file(self.target_file)
-            else:
-                raise ValueError(
-                    "dir does not contain target file\
-%s,please set download=True"
-                    % (self.target_file)
-                )
+        if self._check_raw_files():
+            self.process(train)
+        elif download:
+            self.download()
+            self.process(train)
+        else:
+            raise ValueError(
+                "root does not contain valid raw files, please set download=True"
+            )

     def __getitem__(self, index: int) -> Tuple:
         return tuple(array[index] for array in self.arrays)

@@ -143,10 +101,6 @@ class MNIST(VisionDataset):
     def meta(self):
         return self._meta_data

-    def _load_file(self, target_file):
-        with open(target_file, "rb") as f:
-            return pickle.load(f)
-
     def _check_raw_files(self):
         return all(
             [

@@ -159,45 +113,35 @@ class MNIST(VisionDataset):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
             load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
-        self.process()

-    def process(self):
+    def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
-        logger.info("process raw data ...")
-        meta_data_images_train, images_train = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[0])
-        )
-        meta_data_labels_train, labels_train = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[1])
-        )
-        meta_data_images_test, images_test = parse_idx3(
-            os.path.join(self.root, self.raw_file_name[2])
-        )
-        meta_data_labels_test, labels_test = parse_idx1(
-            os.path.join(self.root, self.raw_file_name[3])
-        )
-
-        meta_data_train = {
-            "images": meta_data_images_train,
-            "labels": meta_data_labels_train,
-        }
-        meta_data_test = {
-            "images": meta_data_images_test,
-            "labels": meta_data_labels_test,
-        }
-
-        dataset_train = (images_train, labels_train)
-        dataset_test = (images_test, labels_test)
-
-        # save both training set and test set as pickle files
-        with open(os.path.join(self.root, self.train_file), "wb") as f:
-            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
-        with open(os.path.join(self.root, self.test_file), "wb") as f:
-            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)
+        logger.info("process the raw files of %s set...", "train" if train else "test")
+        if train:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[0])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[1])
+            )
+        else:
+            meta_data_images, images = parse_idx3(
+                os.path.join(self.root, self.raw_file_name[2])
+            )
+            meta_data_labels, labels = parse_idx1(
+                os.path.join(self.root, self.raw_file_name[3])
+            )

+        self._meta_data = {
+            "images": meta_data_images,
+            "labels": meta_data_labels,
+        }
+        self.arrays = (images, labels.astype(np.int32))


 def parse_idx3(idx3_file):
     # parse idx3 file to meta data and data in numpy array (images)
-    logger.debug("parse idx3 file %s ..." % idx3_file)
+    logger.debug("parse idx3 file %s ...", idx3_file)
     assert idx3_file.endswith(".gz")
     with gzip.open(idx3_file, "rb") as f:
         bin_data = f.read()

@@ -223,7 +167,7 @@ def parse_idx3(idx3_file):

 def parse_idx1(idx1_file):
     # parse idx1 file to meta data and data in numpy array (labels)
-    logger.debug("parse idx1 file %s ..." % idx1_file)
+    logger.debug("parse idx1 file %s ...", idx1_file)
     assert idx1_file.endswith(".gz")
     with gzip.open(idx1_file, "rb") as f:
         bin_data = f.read()

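This is the headline change: MNIST no longer parses all four idx files and pickles train.pkl/test.pkl next to the raw data; it parses only the requested split in memory each time the dataset is constructed. The idx parsing it leans on follows the standard MNIST layout (a big-endian header, then raw bytes). A minimal standalone sketch of the label-file case, assuming the usual idx1 format; the real parse_idx1 in mnist.py may differ in detail:

    import gzip
    import struct
    import numpy as np

    def parse_idx1(idx1_file):
        # idx1 layout: big-endian (magic, count) header, then raw uint8 labels
        with gzip.open(idx1_file, "rb") as f:
            bin_data = f.read()
        magic, num_items = struct.unpack_from(">II", bin_data, 0)
        assert magic == 2049, "not an idx1 label file"
        labels = np.frombuffer(bin_data, dtype=np.uint8, offset=8, count=num_items)
        meta = {"magic": magic, "num_items": num_items}
        return meta, labels

The trade-off is a re-parse on every construction, in exchange for never writing derived files into the dataset root, which is what the commit title promises.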

python_module/megengine/data/dataset/vision/utils.py (+2 -2)

@@ -32,7 +32,7 @@ def load_raw_data_from_url(
 ):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
-        "load_raw_data_from_url: downloading to or using cached %s ..." % cached_file
+        "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
     )
     if not os.path.exists(cached_file):
         if is_distributed():

@@ -45,7 +45,7 @@ def load_raw_data_from_url(
     else:
         md5 = calculate_md5(cached_file)
         if target_md5 == md5:
-            logger.debug("%s exists with correct md5: %s" % (filename, target_md5))
+            logger.debug("%s exists with correct md5: %s", filename, target_md5)
         else:
             os.remove(cached_file)
             raise RuntimeError("{} exists but fail to match md5".format(filename))

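The calculate_md5 helper these checks (and the checksum asserts in imagenet.py) rely on is not shown in the diff; presumably it streams the file through hashlib. A sketch of such a helper, an assumption rather than the utils.py source:

    import hashlib

    def calculate_md5(path, chunk_size=1 << 20):
        # stream the file in 1 MiB chunks so large tarballs never load into RAM
        md5 = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                md5.update(chunk)
        return md5.hexdigest()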

python_module/megengine/data/dataset/vision/voc.py (+5 -5)

@@ -77,13 +77,13 @@ class PascalVOC(VisionDataset):
                 if "aug" in self.image_set:
                     mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                 else:
-                    mask = np.array(cv2.imread(self.masks[index], cv2.IMREAD_COLOR))
+                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                 mask = self._trans_mask(mask)
                 mask = mask[:, :, np.newaxis]
                 target.append(mask)
-            # elif k == "boxes":
-            #     boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
-            #     target.append(boxes)
+            elif k == "boxes":
+                boxes = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
+                target.append(boxes)
             elif k == "info":
                 if image is None:
                     image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)

@@ -104,7 +104,7 @@ class PascalVOC(VisionDataset):
             label[
                 (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
             ] = i
-        return label.astype("uint8")
+        return label.astype(np.uint8)

     def parse_voc_xml(self, node):
         voc_dict = {}

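PascalVOC now keeps cv2's BGR array directly when reading color masks (the redundant np.array wrapper is gone) and re-enables the XML box parsing branch. The _trans_mask loop above matches palette colors channel by channel against the BGR image; a standalone sketch with an illustrative palette, not the real VOC colors:

    import numpy as np

    palette = [(0, 0, 0), (128, 0, 0), (0, 128, 0)]   # illustrative RGB palette
    mask = np.zeros((2, 2, 3), dtype=np.uint8)        # cv2.IMREAD_COLOR yields BGR
    mask[0, 0] = (0, 0, 128)                          # BGR pixel for RGB (128, 0, 0)

    label = np.full(mask.shape[:2], 255, dtype=np.uint8)  # 255 marks "ignore"
    for i, (r, g, b) in enumerate(palette):
        label[(mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)] = i
    # label[0, 0] is now 1; all-zero pixels matched palette[0] and became 0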
