| @@ -16,6 +16,10 @@ | |||
| This module is to write data into mindrecord. | |||
| """ | |||
| import os | |||
| import sys | |||
| import threading | |||
| import traceback | |||
| import numpy as np | |||
| import mindspore._c_mindrecord as ms | |||
| from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError | |||
| @@ -41,6 +45,23 @@ VALUE_TYPE_MAP = {"int": ["int32", "int64"], "float": ["float32", "float64"], "s | |||
| VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"] | |||
| VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"] | |||
| class ExceptionThread(threading.Thread): | |||
| """ class to pass exception""" | |||
| def __init__(self, *args, **kwargs): | |||
| threading.Thread.__init__(self, *args, **kwargs) | |||
| self.res = SUCCESS | |||
| self.exitcode = 0 | |||
| self.exception = None | |||
| self.exc_traceback = '' | |||
| def run(self): | |||
| try: | |||
| if self._target: | |||
| self.res = self._target(*self._args, **self._kwargs) | |||
| except Exception as e: # pylint: disable=W0703 | |||
| self.exitcode = 1 | |||
| self.exception = e | |||
| self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info())) | |||
| def check_filename(path): | |||
| """ | |||
| @@ -24,7 +24,8 @@ from mindspore import log as logger | |||
| from .cifar100 import Cifar100 | |||
| from ..common.exceptions import PathNotExistsError | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, SUCCESS | |||
| from ..shardutils import check_filename, ExceptionThread, SUCCESS | |||
| try: | |||
| cv2 = import_module("cv2") | |||
| except ModuleNotFoundError: | |||
| @@ -65,7 +66,7 @@ class Cifar100ToMR: | |||
| self.destination = destination | |||
| self.writer = None | |||
| def transform(self, fields=None): | |||
| def run(self, fields=None): | |||
| """ | |||
| Executes transformation from cifar100 to MindRecord. | |||
| @@ -104,6 +105,15 @@ class Cifar100ToMR: | |||
| return FAILED | |||
| return SUCCESS | |||
| def transform(self, fields=None): | |||
| t = ExceptionThread(target=self.run, kwargs={'fields': fields}) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| def _construct_raw_data(images, fine_labels, coarse_labels): | |||
| """ | |||
| Construct raw data from cifar100 data. | |||
| @@ -24,7 +24,7 @@ from mindspore import log as logger | |||
| from .cifar10 import Cifar10 | |||
| from ..common.exceptions import PathNotExistsError | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, SUCCESS, FAILED | |||
| from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED | |||
| try: | |||
| cv2 = import_module("cv2") | |||
| except ModuleNotFoundError: | |||
| @@ -65,7 +65,7 @@ class Cifar10ToMR: | |||
| self.destination = destination | |||
| self.writer = None | |||
| def transform(self, fields=None): | |||
| def run(self, fields=None): | |||
| """ | |||
| Executes transformation from cifar10 to MindRecord. | |||
| @@ -100,6 +100,15 @@ class Cifar10ToMR: | |||
| return FAILED | |||
| return SUCCESS | |||
| def transform(self, fields=None): | |||
| t = ExceptionThread(target=self.run, kwargs={'fields': fields}) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| def _construct_raw_data(images, labels): | |||
| """ | |||
| Construct raw data from cifar10 data. | |||
| @@ -20,7 +20,7 @@ import os | |||
| from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename | |||
| from ..shardutils import check_filename, ExceptionThread | |||
| try: | |||
| pd = import_module("pandas") | |||
| @@ -116,7 +116,7 @@ class CsvToMR: | |||
| row[col] = r[col] | |||
| yield row | |||
| def transform(self): | |||
| def run(self): | |||
| """ | |||
| Executes transformation from csv to MindRecord. | |||
| @@ -166,3 +166,12 @@ class CsvToMR: | |||
| ret = self.writer.commit() | |||
| return ret | |||
| def transform(self): | |||
| t = ExceptionThread(target=self.run) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| @@ -21,7 +21,7 @@ import time | |||
| from mindspore import log as logger | |||
| from ..common.exceptions import PathNotExistsError | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename | |||
| from ..shardutils import check_filename, ExceptionThread | |||
| __all__ = ['ImageNetToMR'] | |||
| @@ -118,7 +118,7 @@ class ImageNetToMR: | |||
| data["image"] = image_bytes | |||
| yield data | |||
| def transform(self): | |||
| def run(self): | |||
| """ | |||
| Executes transformation from imagenet to MindRecord. | |||
| @@ -170,3 +170,12 @@ class ImageNetToMR: | |||
| logger.info("--------------------------------------------") | |||
| return ret | |||
| def transform(self): | |||
| t = ExceptionThread(target=self.run) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| @@ -23,7 +23,7 @@ import numpy as np | |||
| from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, SUCCESS, FAILED | |||
| from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED | |||
| try: | |||
| cv2 = import_module("cv2") | |||
| @@ -217,7 +217,7 @@ class MnistToMR: | |||
| return ret | |||
| def transform(self): | |||
| def run(self): | |||
| """ | |||
| Executes transformation from Mnist to MindRecord. | |||
| @@ -233,3 +233,12 @@ class MnistToMR: | |||
| return FAILED | |||
| return SUCCESS | |||
| def transform(self): | |||
| t = ExceptionThread(target=self.run) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| @@ -21,7 +21,7 @@ import numpy as np | |||
| from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename | |||
| from ..shardutils import check_filename, ExceptionThread | |||
| try: | |||
| tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| @@ -235,7 +235,7 @@ class TFRecordToMR: | |||
| except tf.errors.InvalidArgumentError: | |||
| raise ValueError("TFRecord feature_dict parameter error.") | |||
| def transform(self): | |||
| def run(self): | |||
| """ | |||
| Executes transform from TFRecord to MindRecord. | |||
| @@ -267,3 +267,12 @@ class TFRecordToMR: | |||
| logger.info("Transformed {} records...".format(transform_count)) | |||
| break | |||
| return writer.commit() | |||
| def transform(self): | |||
| t = ExceptionThread(target=self.run) | |||
| t.daemon = True | |||
| t.start() | |||
| t.join() | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||