Merge pull request !186 from guozhijian/fix_mindpage_error (tag: v0.2.0-alpha)
| @@ -33,6 +33,7 @@ | |||
| #include <map> | |||
| #include <random> | |||
| #include <set> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| @@ -117,6 +118,12 @@ const char kPoint = '.'; | |||
| // field type used by check schema validation | |||
| const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"}; | |||
| // can be searched field list | |||
| const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"}; | |||
| // number field list | |||
| const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; | |||
| /// \brief split a string using a character | |||
| /// \param[in] field target string | |||
| /// \param[in] separator a character for spliting | |||
| @@ -42,11 +42,11 @@ class ShardIndexGenerator { | |||
| ~ShardIndexGenerator() {} | |||
| /// \brief fetch value in json by field path | |||
| /// \param[in] field_path | |||
| /// \param[in] schema | |||
| /// \return the vector of value | |||
| static std::vector<std::string> GetField(const std::string &field_path, json schema); | |||
| /// \brief fetch value in json by field name | |||
| /// \param[in] field | |||
| /// \param[in] input | |||
| /// \return pair<MSRStatus, value> | |||
| std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input); | |||
| /// \brief fetch field type in schema n by field path | |||
| /// \param[in] field_path | |||
| @@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe | |||
| MSRStatus ShardIndexGenerator::Build() { | |||
| ShardHeader header = ShardHeader(); | |||
| if (header.Build(file_path_) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Build shard schema failed"; | |||
| MS_LOG(ERROR) << "Build shard schema failed."; | |||
| return FAILED; | |||
| } | |||
| shard_header_ = header; | |||
| @@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() { | |||
| return SUCCESS; | |||
| } | |||
| std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) { | |||
| std::vector<std::string> field_name = StringSplit(field_path, kPoint); | |||
| std::vector<std::string> res; | |||
| if (schema.empty()) { | |||
| res.emplace_back("null"); | |||
| return res; | |||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) { | |||
| if (field.empty()) { | |||
| MS_LOG(ERROR) << "The input field is None."; | |||
| return {FAILED, ""}; | |||
| } | |||
| for (uint64_t i = 0; i < field_name.size(); i++) { | |||
| // Check if field is part of an array of objects | |||
| auto &child = schema.at(field_name[i]); | |||
| if (child.is_array() && !child.empty() && child[0].is_object()) { | |||
| schema = schema[field_name[i]]; | |||
| std::string new_field_path; | |||
| for (uint64_t j = i + 1; j < field_name.size(); j++) { | |||
| if (j > i + 1) new_field_path += '.'; | |||
| new_field_path += field_name[j]; | |||
| } | |||
| // Return multiple field data since multiple objects in array | |||
| for (auto &single_schema : schema) { | |||
| auto child_res = GetField(new_field_path, single_schema); | |||
| res.insert(res.end(), child_res.begin(), child_res.end()); | |||
| } | |||
| return res; | |||
| if (input.empty()) { | |||
| MS_LOG(ERROR) << "The input json is None."; | |||
| return {FAILED, ""}; | |||
| } | |||
| // parameter input does not contain the field | |||
| if (input.find(field) == input.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; | |||
| return {FAILED, ""}; | |||
| } | |||
| // schema does not contain the field | |||
| auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"]; | |||
| if (schema.find(field) == schema.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | |||
| return {FAILED, ""}; | |||
| } | |||
| // field should be scalar type | |||
| if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; | |||
| return {FAILED, ""}; | |||
| } | |||
| if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { | |||
| auto schema_field_options = schema[field]; | |||
| if (schema_field_options.find("shape") == schema_field_options.end()) { | |||
| return {SUCCESS, input[field].dump()}; | |||
| } else { | |||
| // field with shape option | |||
| MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; | |||
| return {FAILED, ""}; | |||
| } | |||
| schema = schema.at(field_name[i]); | |||
| } | |||
| // Return vector of one field data (not array of objects) | |||
| return std::vector<std::string>{schema.dump()}; | |||
| // the field type is string in here | |||
| return {SUCCESS, input[field].get<std::string>()}; | |||
| } | |||
| std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { | |||
| @@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||
| const auto &place_holder = std::get<0>(field); | |||
| const auto &field_type = std::get<1>(field); | |||
| const auto &field_value = std::get<2>(field); | |||
| int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); | |||
| if (field_type == "INTEGER") { | |||
| if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) { | |||
| @@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s | |||
| if (field.first >= schema_detail.size()) { | |||
| return {FAILED, {}}; | |||
| } | |||
| auto field_value = GetField(field.second, schema_detail[field.first]); | |||
| auto field_value = GetValueByField(field.second, schema_detail[field.first]); | |||
| if (field_value.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get value from json by field name failed"; | |||
| return {FAILED, {}}; | |||
| } | |||
| auto result = shard_header_.GetSchemaByID(field.first); | |||
| if (result.second != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); | |||
| auto ret = GenerateFieldName(field); | |||
| if (ret.first != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| fields.emplace_back(ret.second, field_type, field_value[0]); | |||
| fields.emplace_back(ret.second, field_type, field_value.second); | |||
| } | |||
| return {SUCCESS, std::move(fields)}; | |||
| } | |||
| @@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| template <class Type> | |||
| // convert the string to exactly number type (int32_t/int64_t/float/double) | |||
| Type StringToNum(const std::string &str) { | |||
| std::istringstream iss(str); | |||
| Type num; | |||
| iss >> num; | |||
| return num; | |||
| } | |||
| ShardReader::ShardReader() { | |||
| task_id_ = 0; | |||
| deliver_id_ = 0; | |||
| @@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str | |||
| } | |||
| column_values[shard_id].emplace_back(tmp); | |||
| } else { | |||
| string json_str = "{"; | |||
| json construct_json; | |||
| for (unsigned int j = 0; j < columns.size(); ++j) { | |||
| // construct the string json "f1": value | |||
| json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3]; | |||
| if (j < columns.size() - 1) { | |||
| json_str += ","; | |||
| // construct json "f1": value | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||
| // convert the string to base type by schema | |||
| if (schema[columns[j]]["type"] == "int32") { | |||
| construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]); | |||
| } else if (schema[columns[j]]["type"] == "int64") { | |||
| construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]); | |||
| } else if (schema[columns[j]]["type"] == "float32") { | |||
| construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]); | |||
| } else if (schema[columns[j]]["type"] == "float64") { | |||
| construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]); | |||
| } else { | |||
| construct_json[columns[j]] = std::string(labels[i][j + 3]); | |||
| } | |||
| } | |||
| json_str += "}"; | |||
| column_values[shard_id].emplace_back(json::parse(json_str)); | |||
| column_values[shard_id].emplace_back(construct_json); | |||
| } | |||
| } | |||
| @@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int | |||
| // whether use index search | |||
| if (!criteria.first.empty()) { | |||
| sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema(); | |||
| // not number field should add '' in sql | |||
| if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { | |||
| sql += | |||
| " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; | |||
| } else { | |||
| sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + | |||
| criteria.second + "'"; | |||
| } | |||
| } | |||
| sql += ";"; | |||
| std::vector<std::vector<std::string>> image_offsets; | |||
| @@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int | |||
| std::vector<json> ret; | |||
| for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); | |||
| for (unsigned int i = 0; i < labels.size(); ++i) { | |||
| string json_str = "{"; | |||
| json construct_json; | |||
| for (unsigned int j = 0; j < columns.size(); ++j) { | |||
| // construct string json "f1": value | |||
| json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j]; | |||
| if (j < columns.size() - 1) { | |||
| json_str += ","; | |||
| // construct json "f1": value | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||
| // convert the string to base type by schema | |||
| if (schema[columns[j]]["type"] == "int32") { | |||
| construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]); | |||
| } else if (schema[columns[j]]["type"] == "int64") { | |||
| construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]); | |||
| } else if (schema[columns[j]]["type"] == "float32") { | |||
| construct_json[columns[j]] = StringToNum<float>(labels[i][j]); | |||
| } else if (schema[columns[j]]["type"] == "float64") { | |||
| construct_json[columns[j]] = StringToNum<double>(labels[i][j]); | |||
| } else { | |||
| construct_json[columns[j]] = std::string(labels[i][j]); | |||
| } | |||
| } | |||
| json_str += "}"; | |||
| ret[i] = json::parse(json_str); | |||
| ret[i] = construct_json; | |||
| } | |||
| return {SUCCESS, ret}; | |||
| } | |||
| @@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS | |||
| MS_LOG(ERROR) << "Get category info"; | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| // category_name to category_id | |||
| int64_t category_id = -1; | |||
| for (const auto &categories : ret.second) { | |||
| if (std::get<1>(categories) == category_name) { | |||
| auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page); | |||
| return {SUCCESS, result.second}; | |||
| std::string categories_name = std::get<1>(categories); | |||
| if (categories_name == category_name) { | |||
| category_id = std::get<0>(categories); | |||
| break; | |||
| } | |||
| } | |||
| return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| if (category_id == -1) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| return ReadAllAtPageById(category_id, page_no, n_rows_of_page); | |||
| } | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( | |||
| @@ -133,15 +133,15 @@ class MindPage: | |||
| Raises: | |||
| ParamValueError: If any parameter is invalid. | |||
| MRMFetchDataError: If failed to read by category id. | |||
| MRMFetchDataError: If failed to fetch data by category. | |||
| MRMUnsupportedSchemaError: If schema is invalid. | |||
| """ | |||
| if category_id < 0: | |||
| raise ParamValueError("Category id should be greater than 0.") | |||
| if page < 0: | |||
| raise ParamValueError("Page should be greater than 0.") | |||
| if num_row < 0: | |||
| raise ParamValueError("num_row should be greater than 0.") | |||
| if not isinstance(category_id, int) or category_id < 0: | |||
| raise ParamValueError("Category id should be int and greater than or equal to 0.") | |||
| if not isinstance(page, int) or page < 0: | |||
| raise ParamValueError("Page should be int and greater than or equal to 0.") | |||
| if not isinstance(num_row, int) or num_row <= 0: | |||
| raise ParamValueError("num_row should be int and greater than 0.") | |||
| return self._segment.read_at_page_by_id(category_id, page, num_row) | |||
| def read_at_page_by_name(self, category_name, page, num_row): | |||
| @@ -157,8 +157,10 @@ class MindPage: | |||
| Returns: | |||
| str, read at page. | |||
| """ | |||
| if page < 0: | |||
| raise ParamValueError("Page should be greater than 0.") | |||
| if num_row < 0: | |||
| raise ParamValueError("num_row should be greater than 0.") | |||
| if not isinstance(category_name, str): | |||
| raise ParamValueError("Category name should be str.") | |||
| if not isinstance(page, int) or page < 0: | |||
| raise ParamValueError("Page should be int and greater than or equal to 0.") | |||
| if not isinstance(num_row, int) or num_row <= 0: | |||
| raise ParamValueError("num_row should be int and greater than 0.") | |||
| return self._segment.read_at_page_by_name(category_name, page, num_row) | |||
| @@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common { | |||
| TestShardIndexGenerator() {} | |||
| }; | |||
| /* | |||
| TEST_F(TestShardIndexGenerator, GetField) { | |||
| MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field"); | |||
| @@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) { | |||
| } | |||
| } | |||
| } | |||
| */ | |||
| TEST_F(TestShardIndexGenerator, TakeFieldType) { | |||
| MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type"); | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test mindrecord base""" | |||
| import numpy as np | |||
| import os | |||
| import uuid | |||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS | |||
| @@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord" | |||
| CV3_FILE_NAME = "./imagenet_append.mindrecord" | |||
| NLP_FILE_NAME = "./aclImdb.mindrecord" | |||
| def test_write_read_process(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), | |||
| "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), | |||
| "data": bytes("image bytes abc", encoding='UTF-8')}, | |||
| {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), | |||
| "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), | |||
| "data": bytes("image bytes def", encoding='UTF-8')}, | |||
| {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), | |||
| "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), | |||
| "data": bytes("image bytes ghi", encoding='UTF-8')}, | |||
| {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), | |||
| "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), | |||
| "data": bytes("image bytes jkl", encoding='UTF-8')}, | |||
| {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), | |||
| "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), | |||
| "data": bytes("image bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), | |||
| "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), | |||
| "data": bytes("image bytes pqr", encoding='UTF-8')} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "label": {"type": "int32"}, | |||
| "score": {"type": "float64"}, | |||
| "mask": {"type": "int64", "shape": [-1]}, | |||
| "segments": {"type": "float32", "shape": [2, 2]}, | |||
| "data": {"type": "bytes"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| reader = FileReader(mindrecord_file_name) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 6 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_write_read_process_with_define_index_field(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64), | |||
| "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32), | |||
| "data": bytes("image bytes abc", encoding='UTF-8')}, | |||
| {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64), | |||
| "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32), | |||
| "data": bytes("image bytes def", encoding='UTF-8')}, | |||
| {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64), | |||
| "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32), | |||
| "data": bytes("image bytes ghi", encoding='UTF-8')}, | |||
| {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64), | |||
| "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32), | |||
| "data": bytes("image bytes jkl", encoding='UTF-8')}, | |||
| {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64), | |||
| "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32), | |||
| "data": bytes("image bytes mno", encoding='UTF-8')}, | |||
| {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64), | |||
| "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32), | |||
| "data": bytes("image bytes pqr", encoding='UTF-8')} | |||
| ] | |||
| writer = FileWriter(mindrecord_file_name) | |||
| schema = {"file_name": {"type": "string"}, | |||
| "label": {"type": "int32"}, | |||
| "score": {"type": "float64"}, | |||
| "mask": {"type": "int64", "shape": [-1]}, | |||
| "segments": {"type": "float32", "shape": [2, 2]}, | |||
| "data": {"type": "bytes"}} | |||
| writer.add_schema(schema, "data is so cool") | |||
| writer.add_index(["label"]) | |||
| writer.write_raw_data(data) | |||
| writer.commit() | |||
| reader = FileReader(mindrecord_file_name) | |||
| count = 0 | |||
| for index, x in enumerate(reader.get_next()): | |||
| assert len(x) == 6 | |||
| for field in x: | |||
| if isinstance(x[field], np.ndarray): | |||
| assert (x[field] == data[count][field]).all() | |||
| else: | |||
| assert x[field] == data[count][field] | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| assert count == 6 | |||
| reader.close() | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| def test_cv_file_writer_tutorial(): | |||
| """tutorial for cv dataset writer.""" | |||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | |||
| @@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial(): | |||
| assert len(row1[0]) == 3 | |||
| assert row1[0]['label'] == 822 | |||
| def test_cv_page_reader_tutorial_by_file_name(): | |||
| """tutorial for cv page reader.""" | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| fields = reader.get_category_fields() | |||
| assert fields == ['file_name', 'label'],\ | |||
| 'failed on getting candidate category fields.' | |||
| ret = reader.set_category_field("file_name") | |||
| assert ret == SUCCESS, 'failed on setting category field.' | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| row = reader.read_at_page_by_id(0, 0, 1) | |||
| assert len(row) == 1 | |||
| assert len(row[0]) == 3 | |||
| assert row[0]['label'] == 490 | |||
| row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1) | |||
| assert len(row1) == 1 | |||
| assert len(row1[0]) == 3 | |||
| assert row1[0]['label'] == 13 | |||
| def test_cv_page_reader_tutorial_new_api(): | |||
| """tutorial for cv page reader.""" | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| fields = reader.candidate_fields | |||
| assert fields == ['file_name', 'label'],\ | |||
| 'failed on getting candidate category fields.' | |||
| reader.category_field = "file_name" | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| row = reader.read_at_page_by_id(0, 0, 1) | |||
| assert len(row) == 1 | |||
| assert len(row[0]) == 3 | |||
| assert row[0]['label'] == 490 | |||
| row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1) | |||
| assert len(row1) == 1 | |||
| assert len(row1[0]) == 3 | |||
| assert row1[0]['label'] == 13 | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| @@ -15,8 +15,9 @@ | |||
| """test mindrecord exception""" | |||
| import os | |||
| import pytest | |||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage | |||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError | |||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS | |||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \ | |||
| MRMFetchDataError | |||
| from mindspore import log as logger | |||
| from utils import get_data | |||
| @@ -286,3 +287,67 @@ def test_add_index_without_add_schema(): | |||
| fw = FileWriter(CV_FILE_NAME) | |||
| fw.add_index(["label"]) | |||
| assert 'Failed to get meta info' in str(err.value) | |||
| def test_mindpage_pageno_pagesize_not_int(): | |||
| """test page reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| fields = reader.get_category_fields() | |||
| assert fields == ['file_name', 'label'],\ | |||
| 'failed on getting candidate category fields.' | |||
| ret = reader.set_category_field("label") | |||
| assert ret == SUCCESS, 'failed on setting category field.' | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| with pytest.raises(ParamValueError) as err: | |||
| reader.read_at_page_by_id(0, "0", 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| reader.read_at_page_by_id(0, 0, "b") | |||
| with pytest.raises(ParamValueError) as err: | |||
| reader.read_at_page_by_name("822", "e", 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| reader.read_at_page_by_name("822", 0, "qwer") | |||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | |||
| reader.read_at_page_by_id(99999, 0, 1) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) | |||
| os.remove("{}.db".format(x)) | |||
| def test_mindpage_filename_not_exist(): | |||
| """test page reader when some partition does not exist.""" | |||
| create_cv_mindrecord(4) | |||
| reader = MindPage(CV_FILE_NAME + "0") | |||
| fields = reader.get_category_fields() | |||
| assert fields == ['file_name', 'label'],\ | |||
| 'failed on getting candidate category fields.' | |||
| ret = reader.set_category_field("file_name") | |||
| assert ret == SUCCESS, 'failed on setting category field.' | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| with pytest.raises(MRMFetchDataError) as err: | |||
| reader.read_at_page_by_id(9999, 0, 1) | |||
| with pytest.raises(MRMFetchDataError) as err: | |||
| reader.read_at_page_by_name("abc.jpg", 0, 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| reader.read_at_page_by_name(1, 0, 1) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) | |||
| os.remove("{}.db".format(x)) | |||