Merge pull request !158 from guozhijian/fix_block_reader_hungtags/v0.2.0-alpha
| @@ -785,6 +785,8 @@ vector<std::string> ShardReader::GetAllColumns() { | |||||
| MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | ||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | const std::vector<std::shared_ptr<ShardOperator>> &operators) { | ||||
| vector<std::string> columns = GetAllColumns(); | |||||
| CheckIfColumnInIndex(columns); | |||||
| for (const auto &rg : row_group_summary) { | for (const auto &rg : row_group_summary) { | ||||
| auto shard_id = std::get<0>(rg); | auto shard_id = std::get<0>(rg); | ||||
| auto group_id = std::get<1>(rg); | auto group_id = std::get<1>(rg); | ||||
| @@ -143,6 +143,7 @@ class FileWriter: | |||||
| ParamTypeError: If index field is invalid. | ParamTypeError: If index field is invalid. | ||||
| MRMDefineIndexError: If index field is not primitive type. | MRMDefineIndexError: If index field is not primitive type. | ||||
| MRMAddIndexError: If failed to add index field. | MRMAddIndexError: If failed to add index field. | ||||
| MRMGetMetaError: If the schema is not set or get meta failed. | |||||
| """ | """ | ||||
| if not index_fields or not isinstance(index_fields, list): | if not index_fields or not isinstance(index_fields, list): | ||||
| raise ParamTypeError('index_fields', 'list') | raise ParamTypeError('index_fields', 'list') | ||||
| @@ -24,7 +24,7 @@ from mindspore import log as logger | |||||
| from .cifar100 import Cifar100 | from .cifar100 import Cifar100 | ||||
| from ..common.exceptions import PathNotExistsError | from ..common.exceptions import PathNotExistsError | ||||
| from ..filewriter import FileWriter | from ..filewriter import FileWriter | ||||
| from ..shardutils import check_filename | |||||
| from ..shardutils import check_filename, SUCCESS | |||||
| try: | try: | ||||
| cv2 = import_module("cv2") | cv2 = import_module("cv2") | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| @@ -98,8 +98,11 @@ class Cifar100ToMR: | |||||
| data_list = _construct_raw_data(images, fine_labels, coarse_labels) | data_list = _construct_raw_data(images, fine_labels, coarse_labels) | ||||
| test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels) | test_data_list = _construct_raw_data(test_images, test_fine_labels, test_coarse_labels) | ||||
| _generate_mindrecord(self.destination, data_list, fields, "img_train") | |||||
| _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") | |||||
| if _generate_mindrecord(self.destination, data_list, fields, "img_train") != SUCCESS: | |||||
| return FAILED | |||||
| if _generate_mindrecord(self.destination + "_test", test_data_list, fields, "img_test") != SUCCESS: | |||||
| return FAILED | |||||
| return SUCCESS | |||||
| def _construct_raw_data(images, fine_labels, coarse_labels): | def _construct_raw_data(images, fine_labels, coarse_labels): | ||||
| """ | """ | ||||
| @@ -47,7 +47,9 @@ def add_and_remove_cv_file(): | |||||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | ||||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | writer = FileWriter(CV_FILE_NAME, FILES_NUM) | ||||
| data = get_data(CV_DIR_NAME) | data = get_data(CV_DIR_NAME) | ||||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, | |||||
| cv_schema_json = {"id": {"type": "int32"}, | |||||
| "file_name": {"type": "string"}, | |||||
| "label": {"type": "int32"}, | |||||
| "data": {"type": "bytes"}} | "data": {"type": "bytes"}} | ||||
| writer.add_schema(cv_schema_json, "img_schema") | writer.add_schema(cv_schema_json, "img_schema") | ||||
| writer.add_index(["file_name", "label"]) | writer.add_index(["file_name", "label"]) | ||||
| @@ -226,6 +228,24 @@ def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file): | |||||
| num_iter += 1 | num_iter += 1 | ||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file): | |||||
| """tutorial for cv minddataset.""" | |||||
| columns_list = ["id", "data", "label"] | |||||
| num_readers = 4 | |||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, | |||||
| block_reader=True) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| repeat_num = 2 | |||||
| data_set = data_set.repeat(repeat_num) | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- block reader repeat tow {} -----------------".format(num_iter)) | |||||
| logger.info("-------------- item[id]: {} ----------------------------".format(item["id"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| num_iter += 1 | |||||
| assert num_iter == 20 | |||||
| def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | ||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| @@ -359,13 +379,14 @@ def get_data(dir_name): | |||||
| lines = file_reader.readlines() | lines = file_reader.readlines() | ||||
| data_list = [] | data_list = [] | ||||
| for line in lines: | |||||
| for i, line in enumerate(lines): | |||||
| try: | try: | ||||
| filename, label = line.split(",") | filename, label = line.split(",") | ||||
| label = label.strip("\n") | label = label.strip("\n") | ||||
| with open(os.path.join(img_dir, filename), "rb") as file_reader: | with open(os.path.join(img_dir, filename), "rb") as file_reader: | ||||
| img = file_reader.read() | img = file_reader.read() | ||||
| data_json = {"file_name": filename, | |||||
| data_json = {"id": i, | |||||
| "file_name": filename, | |||||
| "data": img, | "data": img, | ||||
| "label": int(label)} | "label": int(label)} | ||||
| data_list.append(data_json) | data_list.append(data_json) | ||||
| @@ -18,6 +18,7 @@ import pytest | |||||
| from mindspore.mindrecord import Cifar100ToMR | from mindspore.mindrecord import Cifar100ToMR | ||||
| from mindspore.mindrecord import FileReader | from mindspore.mindrecord import FileReader | ||||
| from mindspore.mindrecord import MRMOpenError | from mindspore.mindrecord import MRMOpenError | ||||
| from mindspore.mindrecord import SUCCESS | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | ||||
| @@ -26,7 +27,8 @@ MINDRECORD_FILE = "./cifar100.mindrecord" | |||||
| def test_cifar100_to_mindrecord_without_index_fields(): | def test_cifar100_to_mindrecord_without_index_fields(): | ||||
| """test transform cifar100 dataset to mindrecord without index fields.""" | """test transform cifar100 dataset to mindrecord without index fields.""" | ||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE) | ||||
| cifar100_transformer.transform() | |||||
| ret = cifar100_transformer.transform() | |||||
| assert ret == SUCCESS, "Failed to tranform from cifar100 to mindrecord" | |||||
| assert os.path.exists(MINDRECORD_FILE) | assert os.path.exists(MINDRECORD_FILE) | ||||
| assert os.path.exists(MINDRECORD_FILE + "_test") | assert os.path.exists(MINDRECORD_FILE + "_test") | ||||
| read() | read() | ||||
| @@ -16,7 +16,7 @@ | |||||
| import os | import os | ||||
| import pytest | import pytest | ||||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage | from mindspore.mindrecord import FileWriter, FileReader, MindPage | ||||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError | |||||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from utils import get_data | from utils import get_data | ||||
| @@ -280,3 +280,9 @@ def test_cv_file_writer_shard_num_greater_than_1000(): | |||||
| with pytest.raises(ParamValueError) as err: | with pytest.raises(ParamValueError) as err: | ||||
| FileWriter(CV_FILE_NAME, 1001) | FileWriter(CV_FILE_NAME, 1001) | ||||
| assert 'Shard number should between' in str(err.value) | assert 'Shard number should between' in str(err.value) | ||||
| def test_add_index_without_add_schema(): | |||||
| with pytest.raises(MRMGetMetaError) as err: | |||||
| fw = FileWriter(CV_FILE_NAME) | |||||
| fw.add_index(["label"]) | |||||
| assert 'Failed to get meta info' in str(err.value) | |||||