Merge pull request !3644 from liyong126/fix_mindrecord_bug
@@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) {
   (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
   (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
   (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
+  (*m).attr("MAX_FILE_COUNT") = kMaxFileCount;
   (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
   (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
 }
@@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8;
 const uint64_t kMinFileSize = kInt64Len;
 const int kMinShardCount = 1;
-const int kMaxShardCount = 1000;
+const int kMaxShardCount = 1000;  // write
+const int kMaxFileCount = 4096;   // read
 const int kMinConsumerCount = 1;
 const int kMaxConsumerCount = 128;
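The two constants guard different paths: kMaxShardCount caps how many shard files a single FileWriter may produce (write side), while the new kMaxFileCount caps how many mindrecord files a reader may open at once (read side). A minimal write-side sketch, assuming the public mindspore.mindrecord.FileWriter API; the exact error raised for an oversized shard_num is an assumption:

    # Sketch only: write-side limit. shard_num above MAX_SHARD_COUNT (1000) is expected
    # to be rejected when the FileWriter is constructed.
    from mindspore.mindrecord import FileWriter

    writer = FileWriter(file_name="demo.mindrecord", shard_num=4)   # any value in [1, 1000] is accepted
    writer.add_schema({"data": {"type": "int32"}}, "demo_schema")
    writer.write_raw_data([{"data": i} for i in range(16)])
    writer.commit()

    # FileWriter(file_name="demo.mindrecord", shard_num=1001)       # expected to fail: exceeds MAX_SHARD_COUNT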
@@ -152,7 +152,7 @@ class ShardHeader {
   MSRStatus CheckIndexField(const std::string &field, const json &schema);
-  void ParsePage(const json &page, int shard_index, bool load_dataset);
+  MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset);
   MSRStatus ParseStatistics(const json &statistics);
@@ -252,7 +252,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
   if (shard_count <= 0) {
     return row_group_summary;
   }
-  if (shard_count <= kMaxShardCount) {
+  if (shard_count <= kMaxFileCount) {
     for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
       // return -1 when page's size equals to 0.
       auto last_page_id = shard_header_->GetLastPageId(shard_id);
@@ -1054,7 +1054,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
   }
   auto offsets = std::get<1>(ret);
   auto local_columns = std::get<2>(ret);
-  if (shard_count_ <= kMaxShardCount) {
+  if (shard_count_ <= kMaxFileCount) {
     for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
       for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
         tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
@@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
       header_size_ = header["header_size"].get<uint64_t>();
       page_size_ = header["page_size"].get<uint64_t>();
     }
-    ParsePage(header["page"], shard_index, load_dataset);
+    if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
+      return FAILED;
+    }
     shard_index++;
   }
   return SUCCESS;
@@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
   return SUCCESS;
 }
-void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
+MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
   // set shard_index when load_dataset is false
-  if (pages_.empty() && shard_count_ <= kMaxShardCount) {
+  if (shard_count_ > kMaxFileCount) {
+    MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount;
+    return FAILED;
+  }
+  if (pages_.empty() && shard_count_ <= kMaxFileCount) {
     pages_.resize(shard_count_);
   }
   for (auto &page : pages) {
     int page_id = page["page_id"];
     int shard_id = page["shard_id"];
@@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase
       pages_[shard_index].push_back(std::move(parsed_page));
     }
   }
+  return SUCCESS;
 }
 MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
@@ -715,7 +723,9 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
   std::string line;
   while (std::getline(page_in_handle, line)) {
-    ParsePage(json::parse(line), -1, true);
+    if (SUCCESS != ParsePage(json::parse(line), -1, true)) {
+      return FAILED;
+    }
   }
   page_in_handle.close();
@@ -1054,45 +1054,45 @@ class Dataset:
         * - type in 'dataset'
           - type in 'mindrecord'
           - detail
-        * - DE_BOOL
+        * - bool
           - None
           - Not support
-        * - DE_INT8
+        * - int8
           - int32
           -
-        * - DE_UINT8
+        * - uint8
          - bytes(1D uint8)
           - Drop dimension
-        * - DE_INT16
+        * - int16
           - int32
           -
-        * - DE_UINT16
+        * - uint16
           - int32
           -
-        * - DE_INT32
+        * - int32
           - int32
           -
-        * - DE_UINT32
+        * - uint32
           - int64
           -
-        * - DE_INT64
+        * - int64
           - int64
           -
-        * - DE_UINT64
+        * - uint64
           - None
           - Not support
-        * - DE_FLOAT16
-          - Not support
+        * - float16
+          - float32
           -
-        * - DE_FLOAT32
+        * - float32
           - float32
           -
-        * - DE_FLOAT64
+        * - float64
           - float64
           -
-        * - DE_STRING
+        * - string
           - string
-          - Not support multi-dimensional DE_STRING
+          - Not support multi-dimensional string
     Note:
         1. To save the samples in order, should set dataset's shuffle false and num_files 1.
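As a concrete reading of the mapping table above, here is a minimal sketch, assuming the Dataset.save API this docstring documents, a GeneratorDataset source, and the default file_type='mindrecord': an int8 column in the pipeline is stored as int32 in the resulting mindrecord file.

    # Sketch only: illustrates the int8 -> int32 mapping from the table above.
    import numpy as np
    import mindspore.dataset as ds

    def gen():
        for i in range(3):
            yield (np.array([i], dtype=np.int8),)

    # Per the note above: shuffle disabled and a single output file to keep sample order.
    dataset = ds.GeneratorDataset(gen, column_names=["col"], shuffle=False)
    dataset.save("demo_save.mindrecord", num_files=1)

    reloaded = ds.MindDataset(dataset_file="demo_save.mindrecord")
    for row in reloaded.create_dict_iterator():
        print(row["col"].dtype)   # expected: int32, per the mapping table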
@@ -278,6 +278,8 @@ def check_minddataset(method):
         dataset_file = param_dict.get('dataset_file')
         if isinstance(dataset_file, list):
+            if len(dataset_file) > 4096:
+                raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
             for f in dataset_file:
                 check_file(f)
         else:
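Read-side counterpart of the write-side sketch earlier: with this validator change, passing more than 4096 files to MindDataset fails fast at parameter-check time, before any per-file existence check runs. A sketch of the expected behavior, reusing the error text added above:

    # Sketch only: the new length check rejects oversized dataset_file lists up front.
    import mindspore.dataset as ds

    too_many = ["part_{}.mindrecord".format(i) for i in range(4097)]
    try:
        ds.MindDataset(dataset_file=too_many)
    except ValueError as err:
        print(err)   # length of dataset_file should less than or equal to 4096.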