Browse Source

!94 enhance: reduce execution time for mindrecord test case

Merge pull request !94 from yanzhenxiang2020/fix_mindrecord_ut_long_time
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
d245792842
7 changed files with 20 additions and 20 deletions
  1. +9
    -9
      mindspore/mindrecord/tools/mnist_to_mr.py
  2. BIN
      tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz
  3. BIN
      tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz
  4. BIN
      tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz
  5. BIN
      tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz
  6. +5
    -5
      tests/ut/python/mindrecord/test_mindrecord_base.py
  7. +6
    -6
      tests/ut/python/mindrecord/test_mnist_to_mr.py

+ 9
- 9
mindspore/mindrecord/tools/mnist_to_mr.py View File

@@ -77,20 +77,20 @@ class MnistToMR:

self.mnist_schema_json = {"label": {"type": "int64"}, "data": {"type": "bytes"}}

def _extract_images(self, filename, num_images):
def _extract_images(self, filename):
"""Extract the images into a 4D tensor [image index, y, x, channels]."""
with gzip.open(filename) as bytestream:
bytestream.read(16)
buf = bytestream.read(self.image_size * self.image_size * num_images * self.num_channels)
buf = bytestream.read()
data = np.frombuffer(buf, dtype=np.uint8)
data = data.reshape(num_images, self.image_size, self.image_size, self.num_channels)
data = data.reshape(-1, self.image_size, self.image_size, self.num_channels)
return data

def _extract_labels(self, filename, num_images):
def _extract_labels(self, filename):
"""Extract the labels into a vector of int64 label IDs."""
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_images)
buf = bytestream.read()
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
return labels

@@ -101,8 +101,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
train_data = self._extract_images(self.train_data_filename_, 60000)
train_labels = self._extract_labels(self.train_labels_filename_, 60000)
train_data = self._extract_images(self.train_data_filename_)
train_labels = self._extract_labels(self.train_labels_filename_)
for data, label in zip(train_data, train_labels):
_, img = cv2.imencode(".jpeg", data)
yield {"label": int(label), "data": img.tobytes()}
@@ -114,8 +114,8 @@ class MnistToMR:
Yields:
data (dict of list): mnist data list which contains dict.
"""
test_data = self._extract_images(self.test_data_filename_, 10000)
test_labels = self._extract_labels(self.test_labels_filename_, 10000)
test_data = self._extract_images(self.test_data_filename_)
test_labels = self._extract_labels(self.test_labels_filename_)
for data, label in zip(test_data, test_labels):
_, img = cv2.imencode(".jpeg", data)
yield {"label": int(label), "data": img.tobytes()}


BIN
tests/ut/data/mindrecord/testMnistData/t10k-images-idx3-ubyte.gz View File


BIN
tests/ut/data/mindrecord/testMnistData/t10k-labels-idx1-ubyte.gz View File


BIN
tests/ut/data/mindrecord/testMnistData/train-images-idx3-ubyte.gz View File


BIN
tests/ut/data/mindrecord/testMnistData/train-labels-idx1-ubyte.gz View File


+ 5
- 5
tests/ut/python/mindrecord/test_mindrecord_base.py View File

@@ -203,9 +203,9 @@ def test_nlp_page_reader_tutorial():
os.remove("{}".format(x))
os.remove("{}.db".format(x))

def test_cv_file_writer_shard_num_1000():
"""test file writer when shard num equals 1000."""
writer = FileWriter(CV_FILE_NAME, 1000)
def test_cv_file_writer_shard_num_10():
"""test file writer when shard num equals 10."""
writer = FileWriter(CV_FILE_NAME, 10)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
@@ -214,8 +214,8 @@ def test_cv_file_writer_shard_num_1000():
writer.write_raw_data(data)
writer.commit()

paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(3, '0'))
for x in range(1000)]
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(10)]
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))


+ 6
- 6
tests/ut/python/mindrecord/test_mnist_to_mr.py View File

@@ -37,7 +37,7 @@ def read(train_name, test_name):
count = count + 1
if count == 1:
logger.info("data: {}".format(x))
assert count == 60000
assert count == 20
reader.close()

count = 0
@@ -47,7 +47,7 @@ def read(train_name, test_name):
count = count + 1
if count == 1:
logger.info("data: {}".format(x))
assert count == 10000
assert count == 10
reader.close()


@@ -102,10 +102,10 @@ def test_mnist_to_mindrecord_compare_data():
't10k-images-idx3-ubyte.gz')
test_labels_filename_ = os.path.join(MNIST_DIR,
't10k-labels-idx1-ubyte.gz')
train_data = _extract_images(train_data_filename_, 60000)
train_labels = _extract_labels(train_labels_filename_, 60000)
test_data = _extract_images(test_data_filename_, 10000)
test_labels = _extract_labels(test_labels_filename_, 10000)
train_data = _extract_images(train_data_filename_, 20)
train_labels = _extract_labels(train_labels_filename_, 20)
test_data = _extract_images(test_data_filename_, 10)
test_labels = _extract_labels(test_labels_filename_, 10)

reader = FileReader(train_name)
for x, data, label in zip(reader.get_next(), train_data, train_labels):


Loading…
Cancel
Save