| @@ -21,6 +21,7 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| #include <tuple> | #include <tuple> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/mindrecord/include/common/shard_utils.h" | #include "minddata/mindrecord/include/common/shard_utils.h" | ||||
| @@ -38,10 +39,16 @@ class ShardTask { | |||||
| void MakePerm(); | void MakePerm(); | ||||
| void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset, | |||||
| const json &label); | |||||
| inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset, | |||||
| const json &label); | |||||
| void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task); | |||||
| inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||||
| const std::vector<uint64_t> &offset, const json &label); | |||||
| inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task); | |||||
| inline void InsertTask(const uint32_t &i, | |||||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task); | |||||
| void PopBack(); | void PopBack(); | ||||
| @@ -56,12 +63,41 @@ class ShardTask { | |||||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | ||||
| int64_t num_samples); | int64_t num_samples); | ||||
| inline void ResizeTask(const uint32_t &size); | |||||
| uint32_t categories; | uint32_t categories; | ||||
| std::vector<int> permutation_; | std::vector<int> permutation_; | ||||
| std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_; | std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_; | ||||
| }; | }; | ||||
| // Appends one task, assembled from its individual fields, to the end of task_list_. | |||||
| // The stored entry is the tuple (task_type, (shard_id, group_id), offset, label). | |||||
| inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, | |||||
|                                   const std::vector<uint64_t> &offset, const json &label) { | |||||
|   // Trace the incoming task together with the list size before the append. | |||||
|   MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id | |||||
|                 << ", group_id: " << group_id << ", label: " << label.dump() | |||||
|                 << ", size of task_list_: " << task_list_.size() << "."; | |||||
|   const auto shard_and_group = std::make_tuple(shard_id, group_id); | |||||
|   task_list_.emplace_back(task_type, shard_and_group, offset, label); | |||||
| } | |||||
| // Overwrites the entry at position i of task_list_ with a task assembled from its fields. | |||||
| // task_list_ is indexed directly without a range check, so i has to be below the current | |||||
| // size of the list (ResizeTask can be used to size it beforehand). | |||||
| inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, | |||||
|                                   const std::vector<uint64_t> &offset, const json &label) { | |||||
|   auto &slot = task_list_[i]; | |||||
|   slot = std::make_tuple(task_type, std::make_tuple(shard_id, group_id), offset, label); | |||||
| } | |||||
| // Appends an already assembled task tuple to the end of task_list_. | |||||
| // The tuple is taken by value and moved into the list. | |||||
| inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) { | |||||
|   // Borrow the parts that are logged; they are only read before the tuple is moved away. | |||||
|   const auto &shard_and_group = std::get<1>(task); | |||||
|   const auto &label = std::get<3>(task); | |||||
|   MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(shard_and_group) | |||||
|                 << ", group_id: " << std::get<1>(shard_and_group) | |||||
|                 << ", label: " << label.dump() | |||||
|                 << ", size of task_list_: " << task_list_.size() << "."; | |||||
|   task_list_.emplace_back(std::move(task)); | |||||
| } | |||||
| // Overwrites the entry at position i of task_list_ with an already assembled task tuple. | |||||
| // The tuple is moved into place. task_list_ is indexed directly without a range check, | |||||
| // so i has to be below the current size of the list. | |||||
| inline void ShardTask::InsertTask(const uint32_t &i, | |||||
|                                   std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) { | |||||
|   auto &slot = task_list_[i]; | |||||
|   slot = std::move(task); | |||||
| } | |||||
| // Resizes task_list_ so that it holds exactly `size` entries. | |||||
| inline void ShardTask::ResizeTask(const uint32_t &size) { | |||||
|   task_list_.resize(size); | |||||
| } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,6 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <algorithm> | |||||
| #include <thread> | |||||
| #include "minddata/mindrecord/include/shard_distributed_sample.h" | #include "minddata/mindrecord/include/shard_distributed_sample.h" | ||||
| #include "minddata/mindrecord/include/shard_reader.h" | #include "minddata/mindrecord/include/shard_reader.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| @@ -1036,15 +1039,37 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i | |||||
| if (std::get<0>(ret) != SUCCESS) { | if (std::get<0>(ret) != SUCCESS) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto offsets = std::get<1>(ret); | |||||
| auto local_columns = std::get<2>(ret); | |||||
| auto &offsets = std::get<1>(ret); | |||||
| auto &local_columns = std::get<2>(ret); | |||||
| if (shard_count_ <= kMaxFileCount) { | if (shard_count_ <= kMaxFileCount) { | ||||
| int sample_count = 0; | |||||
| for (int shard_id = 0; shard_id < shard_count_; shard_id++) { | for (int shard_id = 0; shard_id < shard_count_; shard_id++) { | ||||
| for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { | |||||
| tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], | |||||
| std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]}, | |||||
| local_columns[shard_id][i]); | |||||
| } | |||||
| sample_count += offsets[shard_id].size(); | |||||
| } | |||||
| MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; | |||||
| // Init the tasks_ size | |||||
| tasks_.ResizeTask(sample_count); | |||||
| // Init the task threads; using a ThreadPool might be better | |||||
| std::vector<std::thread> init_tasks_thread(shard_count_); | |||||
| uint32_t current_offset = 0; | |||||
| for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { | |||||
| init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() { | |||||
| auto offset = current_offset; | |||||
| for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { | |||||
| tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], | |||||
| std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]}, | |||||
| local_columns[shard_id][i]); | |||||
| offset++; | |||||
| } | |||||
| }); | |||||
| current_offset += offsets[shard_id].size(); | |||||
| } | |||||
| for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { | |||||
| init_tasks_thread[shard_id].join(); | |||||
| } | } | ||||
| } else { | } else { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -44,21 +44,6 @@ void ShardTask::MakePerm() { | |||||
| } | } | ||||
| } | } | ||||
| // Appends one task, assembled from its individual fields, to the end of task_list_. | |||||
| // The stored entry is the tuple (task_type, (shard_id, group_id), offset, label). | |||||
| void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, | |||||
|                            const std::vector<uint64_t> &offset, const json &label) { | |||||
|   // Trace the incoming task together with the list size before the append. | |||||
|   MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id | |||||
|                 << ", group_id: " << group_id << ", label: " << label.dump() | |||||
|                 << ", size of task_list_: " << task_list_.size() << "."; | |||||
|   const auto shard_and_group = std::make_tuple(shard_id, group_id); | |||||
|   task_list_.emplace_back(task_type, shard_and_group, offset, label); | |||||
| } | |||||
| // Appends an already assembled task tuple to the end of task_list_. | |||||
| // The tuple is taken by value and moved into the list. | |||||
| void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) { | |||||
|   // Borrow the parts that are logged; they are only read before the tuple is moved away. | |||||
|   const auto &shard_and_group = std::get<1>(task); | |||||
|   const auto &label = std::get<3>(task); | |||||
|   MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(shard_and_group) | |||||
|                 << ", group_id: " << std::get<1>(shard_and_group) | |||||
|                 << ", label: " << label.dump() | |||||
|                 << ", size of task_list_: " << task_list_.size() << "."; | |||||
|   task_list_.emplace_back(std::move(task)); | |||||
| } | |||||
| // Removes the last entry of task_list_. | // Removes the last entry of task_list_. | ||||
| void ShardTask::PopBack() { | void ShardTask::PopBack() { | ||||
|   task_list_.pop_back(); |   task_list_.pop_back(); | ||||
| } | } | ||||
| // Returns the number of entries in task_list_, converted to uint32_t. | // Returns the number of entries in task_list_, converted to uint32_t. | ||||
| uint32_t ShardTask::Size() const { | uint32_t ShardTask::Size() const { | ||||
|   const auto task_count = task_list_.size(); |   const auto task_count = task_list_.size(); | ||||
|   return static_cast<uint32_t>(task_count); |   return static_cast<uint32_t>(task_count); | ||||
| } | } | ||||