Merge pull request !2239 from liyong126/csv_to_mindrecord (tags/v0.5.0-beta)
| @@ -29,10 +29,11 @@ from .common.exceptions import * | |||
| from .shardutils import SUCCESS, FAILED | |||
| from .tools.cifar10_to_mr import Cifar10ToMR | |||
| from .tools.cifar100_to_mr import Cifar100ToMR | |||
| from .tools.csv_to_mr import CsvToMR | |||
| from .tools.imagenet_to_mr import ImageNetToMR | |||
| from .tools.mnist_to_mr import MnistToMR | |||
| from .tools.tfrecord_to_mr import TFRecordToMR | |||
| __all__ = ['FileWriter', 'FileReader', 'MindPage', | |||
| 'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR', | |||
| 'Cifar10ToMR', 'Cifar100ToMR', 'CsvToMR', 'ImageNetToMR', 'MnistToMR', 'TFRecordToMR', | |||
| 'SUCCESS', 'FAILED'] | |||
| @@ -0,0 +1,168 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Csv format convert tool for MindRecord. | |||
| """ | |||
| from importlib import import_module | |||
| import os | |||
| from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename | |||
| try: | |||
| pd = import_module("pandas") | |||
| except ModuleNotFoundError: | |||
| pd = None | |||
| __all__ = ['CsvToMR'] | |||
class CsvToMR:
    """
    Class is for transformation from csv to MindRecord.

    Args:
        source (str): the file path of csv.
        destination (str): the MindRecord file path to transform into.
        columns_list(list[str], optional): List of columns to be read(default=None).
        partition_number (int, optional): partition size (default=1).

    Raises:
        ValueError: If source, destination, columns_list or partition_number is invalid.
        RuntimeError: If a column in columns_list does not exist in the csv file
            (raised when transform() is called).
        Exception: If the pandas module is not found.
    """

    # pandas dtype name -> MindRecord schema type; any other dtype is stored as "string".
    _DTYPE_TO_FIELD_TYPE = {"int64": "int64", "float64": "float64", "bool": "int32"}

    def __init__(self, source, destination, columns_list=None, partition_number=1):
        if not pd:
            raise Exception("Module pandas is not found, please use pip install it.")

        # Validation order is source, columns_list, destination, partition_number;
        # it decides which error the caller sees first.
        if not isinstance(source, str):
            raise ValueError("The parameter source must be str.")
        check_filename(source)
        self.source = source

        self._check_columns(columns_list, "columns_list")
        self.columns_list = columns_list

        if not isinstance(destination, str):
            raise ValueError("The parameter destination must be str.")
        check_filename(destination)
        self.destination = destination

        # None is not an int, so it is rejected by the same check and message.
        if not isinstance(partition_number, int):
            raise ValueError("The parameter partition_number must be int")
        self.partition_number = partition_number

        self.writer = FileWriter(self.destination, self.partition_number)

    def _check_columns(self, columns, columns_name):
        """
        Validate that a non-empty columns argument is a list of str.

        Falsy values (None, empty list, ...) are accepted and mean "no selection".

        Raises:
            ValueError: If columns is not a list or contains a non-str item.
        """
        if not columns:
            return
        if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns):
            raise ValueError("The parameter {} must be list of str.".format(columns_name))

    def _get_schema(self, df):
        """
        Construct schema from df columns

        When no columns were requested, self.columns_list is set to all columns of df.

        Raises:
            RuntimeError: If a requested column is missing from df or the schema is empty.
        """
        if self.columns_list:
            for col in self.columns_list:
                if col not in df.columns:
                    raise RuntimeError(
                        "The parameter columns_list is illegal, column {} does not exist.".format(col))
        else:
            # Keep a plain list: a pandas Index has no unambiguous truth value, so
            # storing it would break the `if self.columns_list` check on a later call.
            self.columns_list = list(df.columns)

        schema = {}
        for col in self.columns_list:
            field_type = self._DTYPE_TO_FIELD_TYPE.get(str(df[col].dtype), "string")
            schema[col] = {"type": field_type}

        if not schema:
            raise RuntimeError("Failed to generate schema from csv file.")
        return schema

    def _get_row_of_csv(self, df):
        """
        Get row data from csv file.

        Yields one dict per row, keyed by the names in self.columns_list. Values of
        bool columns are converted to int to match the "int32" schema type.
        """
        columns = list(self.columns_list)
        # Column dtypes do not change while iterating, so work this out once.
        bool_columns = {col for col in columns if str(df[col].dtype) == 'bool'}
        # itertuples keeps each column's own dtype; iterrows would build one Series per
        # row and upcast int64 values to float in an all-numeric frame.
        for values in df[columns].itertuples(index=False, name=None):
            yield {col: int(value) if col in bool_columns else value
                   for col, value in zip(columns, values)}

    def transform(self):
        """
        Executes transformation from csv to MindRecord.

        Returns:
            SUCCESS/FAILED, whether successfully written into MindRecord.

        Raises:
            IOError: If the csv file does not exist.
            RuntimeError: If a column in columns_list does not exist in the csv file.
        """
        if not os.path.exists(self.source):
            raise IOError("Csv file {} do not exist.".format(self.source))

        # NOTE(review): this only changes how pandas prints DataFrames and is
        # process-wide; kept so existing callers observe the same pandas state.
        pd.set_option('display.max_columns', None)
        df = pd.read_csv(self.source)

        csv_schema = self._get_schema(df)
        logger.info("transformed MindRecord schema is: {}".format(csv_schema))

        # set the header size
        self.writer.set_header_size(1 << 24)
        # set the page size
        self.writer.set_page_size(1 << 26)
        # create the schema
        self.writer.add_schema(csv_schema, "csv_schema")
        # add the index
        self.writer.add_index(list(self.columns_list))

        batch_size = 256
        transform_count = 0
        data_list = []
        for row in self._get_row_of_csv(df):
            data_list.append(row)
            if len(data_list) == batch_size:
                self.writer.write_raw_data(data_list)
                transform_count += len(data_list)
                logger.info("transformed {} record...".format(transform_count))
                data_list = []
        # Flush the final, partially filled batch (nothing to do on an exact multiple).
        if data_list:
            self.writer.write_raw_data(data_list)
            transform_count += len(data_list)
            logger.info("transformed {} record...".format(transform_count))

        ret = self.writer.commit()
        return ret
| @@ -115,10 +115,8 @@ class TFRecordToMR: | |||
| "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} | |||
| bytes_fields (list): the bytes fields which are in feature_dict. | |||
| Rasies: | |||
| ValueError: the following condition will cause ValueError, 1) parameter TFRecord is not string, 2) parameter | |||
| MindRecord is not string, 3) feature_dict is not FixedLenFeature, 4) parameter bytes_field is not list(str) | |||
| or not in feature_dict. | |||
| Raises: | |||
| ValueError: If parameter is invalid. | |||
| Exception: when tensorflow module not found or version is not correct. | |||
| """ | |||
| def __init__(self, source, destination, feature_dict, bytes_fields=None): | |||
| @@ -0,0 +1,7 @@ | |||
| Age,EmployNumber,Name,Sales,Over18 | |||
| 21, 10023,john, 123.45,True | |||
| 41, 10223,tom, 12111,True | |||
| 51, 10231,bob, 8779.0,True | |||
| 86, 10053,alice, 7777,True | |||
| 26, 1053,carol, 12345.8,False | |||
| @@ -0,0 +1,143 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test csv to mindrecord tool""" | |||
| import os | |||
| from importlib import import_module | |||
| import pytest | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import FileReader | |||
| from mindspore.mindrecord import CsvToMR | |||
| try: | |||
| pd = import_module('pandas') | |||
| except ModuleNotFoundError: | |||
| pd = None | |||
| CSV_FILE = "../data/mindrecord/testCsv/data.csv" | |||
| MINDRECORD_FILE = "../data/mindrecord/testCsv/csv.mindrecord" | |||
| PARTITION_NUMBER = 4 | |||
@pytest.fixture(name="remove_mindrecord_file")
def fixture_remove():
    """Delete the MindRecord output files before and after the test using this fixture."""

    def cleanup():
        # The un-numbered output first, then one numbered output per partition;
        # every data file is paired with a ".db" file.
        targets = [MINDRECORD_FILE]
        targets.extend(MINDRECORD_FILE + str(shard) for shard in range(PARTITION_NUMBER))
        for base in targets:
            for path in (base, base + ".db"):
                if os.path.exists(path):
                    os.remove(path)

    cleanup()
    yield "yield_fixture_data"
    cleanup()
def read(filename, columns, row_num):
    """
    Read a MindRecord file back and compare every record with the source csv.

    Args:
        filename (str): MindRecord file handed to FileReader.
        columns (list[str]): columns each record is expected to contain, and nothing else.
        row_num (int): number of records the reader is expected to yield.

    Raises:
        Exception: If the pandas module is not found.
        AssertionError: If a value, the column count or the record count differs.
    """
    if not pd:
        raise Exception("Module pandas is not found, please use pip install it.")
    df = pd.read_csv(CSV_FILE)
    reader = FileReader(filename)
    count = 0
    try:
        # count is 1-based, so the matching csv row is count - 1.
        for count, x in enumerate(reader.get_next(), start=1):
            for col in columns:
                assert x[col] == df[col].iloc[count - 1]
            assert len(x) == len(columns)
            if count == 1:
                logger.info("data: {}".format(x))
        assert count == row_num
    finally:
        # Close the reader even when one of the assertions above fails.
        reader.close()
def test_csv_to_mindrecord(remove_mindrecord_file):
    """Convert every csv column into PARTITION_NUMBER files, then read the data back."""
    converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=PARTITION_NUMBER)
    converter.transform()

    # Each partition must have produced a data file and its ".db" file.
    for shard in range(PARTITION_NUMBER):
        shard_file = MINDRECORD_FILE + str(shard)
        assert os.path.exists(shard_file)
        assert os.path.exists(shard_file + ".db")

    all_columns = ["Age", "EmployNumber", "Name", "Sales", "Over18"]
    read(MINDRECORD_FILE + "0", all_columns, 5)
def test_csv_to_mindrecord_with_columns(remove_mindrecord_file):
    """Convert only the 'Age' and 'Sales' columns, then read the data back."""
    selected = ['Age', 'Sales']
    converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=selected,
                        partition_number=PARTITION_NUMBER)
    converter.transform()

    # Each partition must have produced a data file and its ".db" file.
    for shard in range(PARTITION_NUMBER):
        shard_file = MINDRECORD_FILE + str(shard)
        assert os.path.exists(shard_file)
        assert os.path.exists(shard_file + ".db")

    read(MINDRECORD_FILE + "0", ["Age", "Sales"], 5)
def test_csv_to_mindrecord_with_no_exist_columns(remove_mindrecord_file):
    """Requesting the column 'ssales' must fail with the 'does not exist' error."""
    expected_error = "The parameter columns_list is illegal, column ssales does not exist."
    with pytest.raises(Exception, match=expected_error):
        converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, columns_list=['Age', 'ssales'],
                            partition_number=PARTITION_NUMBER)
        converter.transform()
def test_csv_partition_number_with_illegal_columns(remove_mindrecord_file):
    """
    test transform csv to mindrecord
    when columns_list contains an item that is not a str.
    """
    expected_error = "The parameter columns_list must be list of str."
    with pytest.raises(Exception, match=expected_error):
        converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, ["Sales", 2])
        converter.transform()
def test_csv_to_mindrecord_default_partition_number(remove_mindrecord_file):
    """
    test transform csv to mindrecord
    when partition number is left at its default.
    """
    converter = CsvToMR(CSV_FILE, MINDRECORD_FILE)
    converter.transform()

    # With the default partition number the output path carries no numeric suffix.
    for produced in (MINDRECORD_FILE, MINDRECORD_FILE + ".db"):
        assert os.path.exists(produced)

    all_columns = ["Age", "EmployNumber", "Name", "Sales", "Over18"]
    read(MINDRECORD_FILE, all_columns, 5)
def test_csv_partition_number_0(remove_mindrecord_file):
    """
    test transform csv to mindrecord
    when partition number is 0: an "Invalid parameter value" error is expected.
    """
    expected_error = "Invalid parameter value"
    with pytest.raises(Exception, match=expected_error):
        converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, 0)
        converter.transform()
def test_csv_to_mindrecord_partition_number_none(remove_mindrecord_file):
    """
    test transform csv to mindrecord
    when partition number is None: the "must be int" error is expected.
    """
    expected_error = "The parameter partition_number must be int"
    with pytest.raises(Exception, match=expected_error):
        converter = CsvToMR(CSV_FILE, MINDRECORD_FILE, None, None)
        converter.transform()
def test_csv_to_mindrecord_illegal_filename(remove_mindrecord_file):
    """
    test transform csv to mindrecord
    when the destination file name contains an illegal character.
    """
    bad_destination = "not_*ok"
    expected_error = "File name should not contains"
    with pytest.raises(Exception, match=expected_error):
        converter = CsvToMR(CSV_FILE, bad_destination)
        converter.transform()