Merge pull request !186 from guozhijian/fix_mindpage_error
@@ -33,6 +33,7 @@
 #include <map>
 #include <random>
 #include <set>
+#include <sstream>
 #include <string>
 #include <thread>
 #include <unordered_map>
@@ -117,6 +118,12 @@ const char kPoint = '.';
 // field type used by check schema validation
 const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};
+// field types that can be searched
+const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
+// numeric field types
+const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
 /// \brief split a string using a character
 /// \param[in] field target string
 /// \param[in] separator a character for splitting
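The two new sets drive the checks introduced below: kScalarFieldTypeSet decides whether a field can back an index lookup at all, and kNumberFieldTypeSet decides whether its value may be embedded in SQL unquoted. A minimal sketch of that gating, assuming only the two sets above (IsSearchable and IsNumeric are illustrative helpers, not part of the patch):

#include <iostream>
#include <set>
#include <string>

const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};

// bytes (and any field with a shape) cannot be searched; strings can be
// searched but need quoting in SQL; numbers are embedded as-is.
bool IsSearchable(const std::string &type) { return kScalarFieldTypeSet.count(type) > 0; }
bool IsNumeric(const std::string &type) { return kNumberFieldTypeSet.count(type) > 0; }

int main() {
  std::cout << IsSearchable("bytes") << IsNumeric("string") << IsNumeric("float64") << std::endl;  // prints 001
}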
@@ -42,11 +42,11 @@ class ShardIndexGenerator {
   ~ShardIndexGenerator() {}

-  /// \brief fetch value in json by field path
-  /// \param[in] field_path
-  /// \param[in] schema
-  /// \return the vector of value
-  static std::vector<std::string> GetField(const std::string &field_path, json schema);
+  /// \brief fetch value in json by field name
+  /// \param[in] field
+  /// \param[in] input
+  /// \return pair<MSRStatus, value>
+  std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);

   /// \brief fetch field type in schema n by field path
   /// \param[in] field_path
@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
 MSRStatus ShardIndexGenerator::Build() {
   ShardHeader header = ShardHeader();
   if (header.Build(file_path_) != SUCCESS) {
-    MS_LOG(ERROR) << "Build shard schema failed";
+    MS_LOG(ERROR) << "Build shard schema failed.";
     return FAILED;
   }
   shard_header_ = header;
@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
   return SUCCESS;
 }

-std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) {
-  std::vector<std::string> field_name = StringSplit(field_path, kPoint);
-  std::vector<std::string> res;
-  if (schema.empty()) {
-    res.emplace_back("null");
-    return res;
+std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
+  if (field.empty()) {
+    MS_LOG(ERROR) << "The input field is None.";
+    return {FAILED, ""};
   }
-  for (uint64_t i = 0; i < field_name.size(); i++) {
-    // Check if field is part of an array of objects
-    auto &child = schema.at(field_name[i]);
-    if (child.is_array() && !child.empty() && child[0].is_object()) {
-      schema = schema[field_name[i]];
-      std::string new_field_path;
-      for (uint64_t j = i + 1; j < field_name.size(); j++) {
-        if (j > i + 1) new_field_path += '.';
-        new_field_path += field_name[j];
-      }
-      // Return multiple field data since multiple objects in array
-      for (auto &single_schema : schema) {
-        auto child_res = GetField(new_field_path, single_schema);
-        res.insert(res.end(), child_res.begin(), child_res.end());
-      }
-      return res;
+  if (input.empty()) {
+    MS_LOG(ERROR) << "The input json is None.";
+    return {FAILED, ""};
+  }
+  // parameter input does not contain the field
+  if (input.find(field) == input.end()) {
+    MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
+    return {FAILED, ""};
+  }
+  // schema does not contain the field
+  auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
+  if (schema.find(field) == schema.end()) {
+    MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
+    return {FAILED, ""};
+  }
+  // the field must be a scalar type
+  if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
+    MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable.";
+    return {FAILED, ""};
+  }
+  if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
+    auto schema_field_options = schema[field];
+    if (schema_field_options.find("shape") == schema_field_options.end()) {
+      return {SUCCESS, input[field].dump()};
+    } else {
+      // a field with a shape option is an array, which is not retrievable
+      MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable.";
+      return {FAILED, ""};
     }
-    schema = schema.at(field_name[i]);
   }
-  // Return vector of one field data (not array of objects)
-  return std::vector<std::string>{schema.dump()};
+  // the field type is string here
+  return {SUCCESS, input[field].get<std::string>()};
 }

 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
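A point worth making explicit: GetValueByField serializes numeric values with json::dump() but returns string values via get<std::string>(), so the JSON quotation marks never leak into the index value. A self-contained sketch of the difference, assuming nlohmann::json (the json type MindRecord uses); the row contents are illustrative:

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  json row = {{"label", 43}, {"file_name", "001.jpg"}};
  std::cout << row["label"].dump() << std::endl;                  // 43
  std::cout << row["file_name"].dump() << std::endl;              // "001.jpg" (quotes included)
  std::cout << row["file_name"].get<std::string>() << std::endl;  // 001.jpg
}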
@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
     const auto &place_holder = std::get<0>(field);
     const auto &field_type = std::get<1>(field);
     const auto &field_value = std::get<2>(field);
+
     int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
     if (field_type == "INTEGER") {
       if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
     if (field.first >= schema_detail.size()) {
       return {FAILED, {}};
     }
-    auto field_value = GetField(field.second, schema_detail[field.first]);
+    auto field_value = GetValueByField(field.second, schema_detail[field.first]);
+    if (field_value.first != SUCCESS) {
+      MS_LOG(ERROR) << "Get value from json by field name failed.";
+      return {FAILED, {}};
+    }
     auto result = shard_header_.GetSchemaByID(field.first);
     if (result.second != SUCCESS) {
       return {FAILED, {}};
     }
     std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
     auto ret = GenerateFieldName(field);
     if (ret.first != SUCCESS) {
       return {FAILED, {}};
     }
-    fields.emplace_back(ret.second, field_type, field_value[0]);
+    fields.emplace_back(ret.second, field_type, field_value.second);
   }
   return {SUCCESS, std::move(fields)};
 }
@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
 namespace mindspore {
 namespace mindrecord {
+// convert a string to the exact numeric type (int32_t/int64_t/float/double)
+template <class Type>
+Type StringToNum(const std::string &str) {
+  std::istringstream iss(str);
+  Type num;
+  iss >> num;
+  return num;
+}
+
 ShardReader::ShardReader() {
   task_id_ = 0;
   deliver_id_ = 0;
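The helper relies on std::istringstream extraction, which stops at the first character that cannot be part of the number and (since C++11) writes 0 on a failed extraction. A quick sketch of that behavior, with illustrative inputs:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

template <class Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  Type num;
  iss >> num;
  return num;
}

int main() {
  std::cout << StringToNum<int32_t>("43") << std::endl;     // 43
  std::cout << StringToNum<double>("0.8") << std::endl;     // 0.8
  std::cout << StringToNum<int64_t>("12abc") << std::endl;  // 12: extraction stops at 'a'
  std::cout << StringToNum<int32_t>("abc") << std::endl;    // 0: failed extraction zeroes num
}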
@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
       }
       column_values[shard_id].emplace_back(tmp);
     } else {
-      string json_str = "{";
+      json construct_json;
       for (unsigned int j = 0; j < columns.size(); ++j) {
-        // construct the string json "f1": value
-        json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3];
-        if (j < columns.size() - 1) {
-          json_str += ",";
+        // construct json "f1": value
+        auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+        // convert the string value to its base type according to the schema
+        if (schema[columns[j]]["type"] == "int32") {
+          construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
+        } else if (schema[columns[j]]["type"] == "int64") {
+          construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
+        } else if (schema[columns[j]]["type"] == "float32") {
+          construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
+        } else if (schema[columns[j]]["type"] == "float64") {
+          construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
+        } else {
+          construct_json[columns[j]] = std::string(labels[i][j + 3]);
         }
       }
-      json_str += "}";
-      column_values[shard_id].emplace_back(json::parse(json_str));
+      column_values[shard_id].emplace_back(construct_json);
     }
   }
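This is the core of the reader fix. The old code spliced raw SQLite column strings into a JSON literal, so a string column produced text like {"file_name":001.jpg}, which json::parse rejects; building the object field by field is well-formed by construction and preserves numeric types. A minimal before/after sketch, assuming nlohmann::json and illustrative values:

#include <cstdint>
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  // Before: a raw string value arrives unquoted, so the literal is invalid JSON.
  try {
    json::parse(R"({"file_name":001.jpg,"label":43})");
  } catch (const json::parse_error &e) {
    std::cout << "parse failed: " << e.what() << std::endl;
  }

  // After: typed assignment; dump() always emits well-formed JSON.
  json construct_json;
  construct_json["file_name"] = std::string("001.jpg");
  construct_json["label"] = int32_t{43};
  std::cout << construct_json.dump() << std::endl;  // {"file_name":"001.jpg","label":43}
}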
@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
   // whether to use index search
   if (!criteria.first.empty()) {
-    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
+    auto schema = shard_header_->get_schemas()[0]->GetSchema();
+    // non-numeric fields must be quoted with '' in the SQL statement
+    if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
+      sql +=
+        " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
+    } else {
+      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
+             criteria.second + "'";
+    }
   }
   sql += ";";
   std::vector<std::vector<std::string>> image_offsets;
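For illustration, with a hypothetical numeric criterion ("label", "43") and a string criterion ("file_name", "001.jpg"), both at schema id 0, the two branches append, respectively:

    ... AND label_0 = 43;
    ... AND file_name_0 = '001.jpg';

Without the quotes, SQLite tries to parse 001.jpg as an expression and errors out, which is why index lookups on string-valued fields failed before this change.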
@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
   std::vector<json> ret;
   for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
   for (unsigned int i = 0; i < labels.size(); ++i) {
-    string json_str = "{";
+    json construct_json;
     for (unsigned int j = 0; j < columns.size(); ++j) {
-      // construct string json "f1": value
-      json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j];
-      if (j < columns.size() - 1) {
-        json_str += ",";
+      // construct json "f1": value
+      auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+      // convert the string value to its base type according to the schema
+      if (schema[columns[j]]["type"] == "int32") {
+        construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "int64") {
+        construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "float32") {
+        construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "float64") {
+        construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
+      } else {
+        construct_json[columns[j]] = std::string(labels[i][j]);
      }
     }
-    json_str += "}";
-    ret[i] = json::parse(json_str);
+    ret[i] = construct_json;
   }
   return {SUCCESS, ret};
 }
@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
     MS_LOG(ERROR) << "Get category info";
     return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
   }
+  // map category_name to category_id
+  int64_t category_id = -1;
   for (const auto &categories : ret.second) {
-    if (std::get<1>(categories) == category_name) {
-      auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
-      return {SUCCESS, result.second};
+    std::string categories_name = std::get<1>(categories);
+    if (categories_name == category_name) {
+      category_id = std::get<0>(categories);
+      break;
     }
   }
-  return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
+  if (category_id == -1) {
+    return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
+  }
+  return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
 }

 std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
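The rewrite splits ReadAtPageByName into name-to-id resolution followed by the page read, so an unknown category now returns FAILED (surfacing as MRMFetchDataError on the Python side) instead of SUCCESS with an empty payload. The lookup pattern in a self-contained sketch, with illustrative category data:

#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

int main() {
  // (category_id, category_name, row_count), as in the category-info result.
  std::vector<std::tuple<int64_t, std::string, int64_t>> categories = {{0, "cat", 12}, {1, "dog", 7}};

  const std::string category_name = "bird";
  int64_t category_id = -1;
  for (const auto &c : categories) {
    if (std::get<1>(c) == category_name) {
      category_id = std::get<0>(c);
      break;
    }
  }
  if (category_id == -1) {
    std::cout << "FAILED: unknown category" << std::endl;  // previously: SUCCESS with an empty result
    return 1;
  }
  // otherwise proceed: ReadAllAtPageById(category_id, page_no, n_rows_of_page)
  return 0;
}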
@@ -133,15 +133,15 @@ class MindPage:
         Raises:
             ParamValueError: If any parameter is invalid.
-            MRMFetchDataError: If failed to read by category id.
+            MRMFetchDataError: If failed to fetch data by category.
             MRMUnsupportedSchemaError: If schema is invalid.
         """
-        if category_id < 0:
-            raise ParamValueError("Category id should be greater than 0.")
-        if page < 0:
-            raise ParamValueError("Page should be greater than 0.")
-        if num_row < 0:
-            raise ParamValueError("num_row should be greater than 0.")
+        if not isinstance(category_id, int) or category_id < 0:
+            raise ParamValueError("Category id should be int and greater than or equal to 0.")
+        if not isinstance(page, int) or page < 0:
+            raise ParamValueError("Page should be int and greater than or equal to 0.")
+        if not isinstance(num_row, int) or num_row <= 0:
+            raise ParamValueError("num_row should be int and greater than 0.")
         return self._segment.read_at_page_by_id(category_id, page, num_row)

     def read_at_page_by_name(self, category_name, page, num_row):
@@ -157,8 +157,10 @@ class MindPage:
         Returns:
             str, read at page.
         """
-        if page < 0:
-            raise ParamValueError("Page should be greater than 0.")
-        if num_row < 0:
-            raise ParamValueError("num_row should be greater than 0.")
+        if not isinstance(category_name, str):
+            raise ParamValueError("Category name should be str.")
+        if not isinstance(page, int) or page < 0:
+            raise ParamValueError("Page should be int and greater than or equal to 0.")
+        if not isinstance(num_row, int) or num_row <= 0:
+            raise ParamValueError("num_row should be int and greater than 0.")
         return self._segment.read_at_page_by_name(category_name, page, num_row)
@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
   TestShardIndexGenerator() {}
 };

+/*
 TEST_F(TestShardIndexGenerator, GetField) {
   MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
     }
   }
 }
+*/

 TEST_F(TestShardIndexGenerator, TakeFieldType) {
   MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """test mindrecord base"""
+import numpy as np
 import os
 import uuid
 from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
 CV3_FILE_NAME = "./imagenet_append.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"

+def test_write_read_process():
+    mindrecord_file_name = "test.mindrecord"
+    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
+             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
+             "data": bytes("image bytes abc", encoding='UTF-8')},
+            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
+             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
+             "data": bytes("image bytes def", encoding='UTF-8')},
+            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
+             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
+             "data": bytes("image bytes ghi", encoding='UTF-8')},
+            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
+             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
+             "data": bytes("image bytes jkl", encoding='UTF-8')},
+            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
+             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
+             "data": bytes("image bytes mno", encoding='UTF-8')},
+            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
+             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
+             "data": bytes("image bytes pqr", encoding='UTF-8')}
+            ]
+    writer = FileWriter(mindrecord_file_name)
+    schema = {"file_name": {"type": "string"},
+              "label": {"type": "int32"},
+              "score": {"type": "float64"},
+              "mask": {"type": "int64", "shape": [-1]},
+              "segments": {"type": "float32", "shape": [2, 2]},
+              "data": {"type": "bytes"}}
+    writer.add_schema(schema, "data is so cool")
+    writer.write_raw_data(data)
+    writer.commit()
+
+    reader = FileReader(mindrecord_file_name)
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 6
+        for field in x:
+            if isinstance(x[field], np.ndarray):
+                assert (x[field] == data[count][field]).all()
+            else:
+                assert x[field] == data[count][field]
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 6
+    reader.close()
+
+    os.remove("{}".format(mindrecord_file_name))
+    os.remove("{}.db".format(mindrecord_file_name))
+
+def test_write_read_process_with_define_index_field():
+    mindrecord_file_name = "test.mindrecord"
+    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
+             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
+             "data": bytes("image bytes abc", encoding='UTF-8')},
+            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
+             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
+             "data": bytes("image bytes def", encoding='UTF-8')},
+            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
+             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
+             "data": bytes("image bytes ghi", encoding='UTF-8')},
+            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
+             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
+             "data": bytes("image bytes jkl", encoding='UTF-8')},
+            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
+             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
+             "data": bytes("image bytes mno", encoding='UTF-8')},
+            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
+             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
+             "data": bytes("image bytes pqr", encoding='UTF-8')}
+            ]
+    writer = FileWriter(mindrecord_file_name)
+    schema = {"file_name": {"type": "string"},
+              "label": {"type": "int32"},
+              "score": {"type": "float64"},
+              "mask": {"type": "int64", "shape": [-1]},
+              "segments": {"type": "float32", "shape": [2, 2]},
+              "data": {"type": "bytes"}}
+    writer.add_schema(schema, "data is so cool")
+    writer.add_index(["label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+    reader = FileReader(mindrecord_file_name)
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 6
+        for field in x:
+            if isinstance(x[field], np.ndarray):
+                assert (x[field] == data[count][field]).all()
+            else:
+                assert x[field] == data[count][field]
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 6
+    reader.close()
+
+    os.remove("{}".format(mindrecord_file_name))
+    os.remove("{}.db".format(mindrecord_file_name))
+
 def test_cv_file_writer_tutorial():
     """tutorial for cv dataset writer."""
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
     assert len(row1[0]) == 3
     assert row1[0]['label'] == 822

+def test_cv_page_reader_tutorial_by_file_name():
+    """tutorial for cv page reader searched by file name."""
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'], \
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("file_name")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    row = reader.read_at_page_by_id(0, 0, 1)
+    assert len(row) == 1
+    assert len(row[0]) == 3
+    assert row[0]['label'] == 490
+
+    row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
+    assert len(row1) == 1
+    assert len(row1[0]) == 3
+    assert row1[0]['label'] == 13
+
+def test_cv_page_reader_tutorial_new_api():
+    """tutorial for cv page reader using the property-style API."""
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.candidate_fields
+    assert fields == ['file_name', 'label'], \
+        'failed on getting candidate category fields.'
+
+    reader.category_field = "file_name"
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    row = reader.read_at_page_by_id(0, 0, 1)
+    assert len(row) == 1
+    assert len(row[0]) == 3
+    assert row[0]['label'] == 490
+
+    row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
+    assert len(row1) == 1
+    assert len(row1[0]) == 3
+    assert row1[0]['label'] == 13
+
     paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
              for x in range(FILES_NUM)]
     for x in paths:
@@ -15,8 +15,9 @@
 """test mindrecord exception"""
 import os
 import pytest
-from mindspore.mindrecord import FileWriter, FileReader, MindPage
-from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
+from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
+from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
+    MRMFetchDataError
 from mindspore import log as logger
 from utils import get_data
@@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
         fw = FileWriter(CV_FILE_NAME)
         fw.add_index(["label"])
     assert 'Failed to get meta info' in str(err.value)
+
+def test_mindpage_pageno_pagesize_not_int():
+    """test page reader when page number or page size is not int."""
+    create_cv_mindrecord(4)
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'], \
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("label")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    with pytest.raises(ParamValueError):
+        reader.read_at_page_by_id(0, "0", 1)
+
+    with pytest.raises(ParamValueError):
+        reader.read_at_page_by_id(0, 0, "b")
+
+    with pytest.raises(ParamValueError):
+        reader.read_at_page_by_name("822", "e", 1)
+
+    with pytest.raises(ParamValueError):
+        reader.read_at_page_by_name("822", 0, "qwer")
+
+    with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
+        reader.read_at_page_by_id(99999, 0, 1)
+
+    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))
+
+def test_mindpage_filename_not_exist():
+    """test page reader when the category name or id does not exist."""
+    create_cv_mindrecord(4)
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'], \
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("file_name")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    with pytest.raises(MRMFetchDataError):
+        reader.read_at_page_by_id(9999, 0, 1)
+
+    with pytest.raises(MRMFetchDataError):
+        reader.read_at_page_by_name("abc.jpg", 0, 1)
+
+    with pytest.raises(ParamValueError):
+        reader.read_at_page_by_name(1, 0, 1)
+
+    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))