@@ -25,7 +25,7 @@
 #include "minddata/dataset/include/constants.h"
 #include "minddata/dataset/core/global_context.h"
-#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/util/log_adapter.h"
@@ -67,9 +67,13 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
   if (build_num_padded_ > 0) {
     sample_json = ToJson(build_sample_);
   }
-  new_mind_record_op = std::make_shared<MindRecordOp>(
-    build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, build_op_connector_queue_size_,
-    build_columns_to_load_, build_operators_, build_num_padded_, sample_json, build_sample_bytes_);
+  std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
+  new_mind_record_op =
+    std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_,
+                                   build_op_connector_queue_size_, build_columns_to_load_, build_operators_,
+                                   build_num_padded_, sample_json, build_sample_bytes_, std::move(shard_reader));
   RETURN_IF_NOT_OK(new_mind_record_op->Init());
   *ptr = std::move(new_mind_record_op);
@@ -110,8 +114,10 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) {
 MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
                            int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
                            const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
-                           const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes)
-    : MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::make_shared<SequentialSamplerRT>(0, 0)),
+                           const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes,
+                           std::unique_ptr<ShardReader> shard_reader)
+    : MappableLeafOp(num_mind_record_workers, op_connector_queue_size,
+                     std::make_shared<MindRecordSamplerRT>(shard_reader.get())),
       dataset_file_(dataset_file),
       load_dataset_(load_dataset),
       columns_to_load_(columns_to_load),
@@ -120,7 +126,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
       ended_worker_(0),
       num_padded_(num_padded),
       sample_json_(sample_json),
-      sample_bytes_(sample_bytes) {
+      sample_bytes_(sample_bytes),
+      shard_reader_(std::move(shard_reader)) {
   io_block_queues_.Init(num_workers_, op_connector_queue_size);
   epoch_sync_flag_ = true;  // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all
                             // tasks are consumed by the worker threads would cause problem.
@@ -128,7 +135,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
 // Private helper method to encapsulate some common construction/reset tasks
 Status MindRecordOp::Init() {
-  shard_reader_ = std::make_unique<ShardReader>();
   auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
                                 num_padded_);
@@ -363,9 +369,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
 Status MindRecordOp::Reset() {
   MS_LOG(DEBUG) << Name() << " performing a self-reset.";
   RETURN_IF_NOT_OK(MappableLeafOp::Reset());  // Call our super class reset first.
-  shard_reader_->ShuffleTask();
   return Status::OK();
 }
@@ -140,7 +140,8 @@ class MindRecordOp : public MappableLeafOp {
   MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
                int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
                const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_,
-               const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_);
+               const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_,
+               std::unique_ptr<ShardReader> shard_reader);
 
   // Destructor
   ~MindRecordOp() override;
@@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
     subset_random_sampler.cc
     subset_sampler.cc
     weighted_random_sampler.cc
+    mind_record_sampler.cc
     )
 
 if(ENABLE_PYTHON)
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+
+#include "minddata/mindrecord/include/shard_reader.h"
+
+namespace mindspore {
+namespace dataset {
+MindRecordSamplerRT::MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, int64_t samples_per_tensor)
+    : SamplerRT(0, samples_per_tensor), next_id_(0), shard_reader_(shard_reader) {}
+
+Status MindRecordSamplerRT::GetNextSample(TensorRow *out) {
+  if (next_id_ > num_samples_) {
+    RETURN_STATUS_UNEXPECTED("MindRecordSampler Internal Error");
+  } else if (next_id_ == num_samples_) {
+    (*out) = TensorRow(TensorRow::kFlagEOE);
+  } else {
+    std::shared_ptr<Tensor> sampleIdsTensor;
+    int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_);
+    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIdsTensor, last_id - next_id_));
+    auto id_ptr = sampleIdsTensor->begin<int64_t>();
+    for (int64_t i = 0; i < (last_id - next_id_); i++) {
+      *(id_ptr + static_cast<ptrdiff_t>(i)) = (*sample_ids_)[i];
+    }
+    next_id_ = last_id;
+
+    (*out) = {sampleIdsTensor};
+  }
+  return Status::OK();
+}
+
+Status MindRecordSamplerRT::InitSampler() {
+  sample_ids_ = shard_reader_->GetSampleIds();
+  if (!sample_ids_) {
+    // Note, sample_ids_.empty() is okay and will just give no sample ids.
+    RETURN_STATUS_UNEXPECTED("ShardReader did not provide a valid sample ids vector via MindRecordSamplerRT");
+  }
+
+  // Usually, the num samples is given from the user interface. In our case, that data is in mindrecord.
+  // Mindrecord already created the sample ids at this point, so the num samples is the size of the sampled id list.
+  num_samples_ = sample_ids_->size();
+  return Status::OK();
+}
+
+Status MindRecordSamplerRT::ResetSampler() {
+  // drive the shard reader reshuffle tasks to redo the sampling for another epoch
+  next_id_ = 0;
+  shard_reader_->ShuffleTask();
+  return Status::OK();
+}
+
+void MindRecordSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
+  out << "\nSampler: MindRecordSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    SamplerRT::SamplerPrint(out, show_all);
+    // Then add our own info if any
+  }
+}
+
+Status MindRecordSamplerRT::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["sampler_name"] = "MindRecordSampler";
+  *out_json = args;
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
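Editor's note: the sketch below is not part of the diff. It is a minimal illustration of how the new MindRecordSamplerRT above is driven, assuming the ShardReader has already been Open()'ed (as MindRecordOp::Init() does before the sampler is initialized) and that TensorRow's usual eoe() flag check is available; the helper name DrainOneEpoch is hypothetical.

// Hypothetical helper, for illustration only.
Status DrainOneEpoch(mindrecord::ShardReader *shard_reader) {
  MindRecordSamplerRT sampler(shard_reader);  // non-owning pointer, exactly as the constructor above takes it
  RETURN_IF_NOT_OK(sampler.InitSampler());    // caches the reader's sampled-id vector and its size
  TensorRow row;
  RETURN_IF_NOT_OK(sampler.GetNextSample(&row));
  while (!row.eoe()) {  // an EOE row is emitted once next_id_ reaches num_samples_
    // row[0] is a 1-D int64 tensor holding up to samples_per_tensor_ sample ids for the worker threads
    RETURN_IF_NOT_OK(sampler.GetNextSample(&row));
  }
  return sampler.ResetSampler();  // reshuffles the reader's tasks so the next epoch is re-sampled
}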
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/mindrecord/include/shard_reader.h"
+
+namespace mindspore {
+namespace dataset {
+class MindRecordSamplerRT : public SamplerRT {
+ public:
+  // Constructor
+  // @param shard_reader - shard_reader
+  // @param int64_t samples_per_tensor - Num of Sampler Ids to fetch via 1 GetNextSample call
+  MindRecordSamplerRT(mindrecord::ShardReader *shard_reader,
+                      int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
+
+  // Destructor.
+  ~MindRecordSamplerRT() = default;
+
+  // Op calls this to get next set of sampleIds
+  // @param out - Tensor of sample ids to be returned to caller
+  // @return Status The status code returned
+  Status GetNextSample(TensorRow *out) override;
+
+  // meant to be called by base class or python
+  Status InitSampler() override;
+
+  // for next epoch of sampleIds
+  // @return Status The status code returned
+  Status ResetSampler() override;
+
+  void SamplerPrint(std::ostream &out, bool show_all) const override;
+
+  /// \brief Get the arguments of node
+  /// \param[out] out_json JSON string of all attributes
+  /// \return Status of the function
+  Status to_json(nlohmann::json *out_json) override;
+
+ private:
+  mindrecord::ShardReader *shard_reader_;  // back pointer to the shard reader
+  const std::vector<int> *sample_ids_;     // read-only back pointer into mind record sampler ids
+  int64_t next_id_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/util/status.h"
@@ -155,17 +156,19 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
   RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
 
   std::shared_ptr<MindRecordOp> mindrecord_op;
+  std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
   // If pass a string to MindData(), it will be treated as a pattern to search for matched files,
   // else if pass a vector to MindData(), it will be treated as specified files to be read
   if (search_for_pattern_) {
     std::vector<std::string> dataset_file_vec_ = {dataset_file_};
-    mindrecord_op =
-      std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_,
-                                     columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
+    mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_,
+                                                   connector_que_size_, columns_list_, operators_, num_padded_,
+                                                   padded_sample_, sample_bytes_, std::move(shard_reader));
   } else {
-    mindrecord_op =
-      std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_, connector_que_size_,
-                                     columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
+    mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_,
+                                                   connector_que_size_, columns_list_, operators_, num_padded_,
+                                                   padded_sample_, sample_bytes_, std::move(shard_reader));
   }
 
   RETURN_IF_NOT_OK(mindrecord_op->Init());
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato
   bool GetReplacement() const { return replacement_; }
 
-  MSRStatus Execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,15 +39,15 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha
   ~ShardDistributedSample() override{};
 
-  MSRStatus PreExecute(ShardTask &tasks) override;
+  MSRStatus PreExecute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
 
  private:
   bool shuffle_;
   int no_of_padded_samples_;
-  bool first_epoch_;  // check (num_sample + num_padded) % num_shards == 0 in first epoch
-  ShardTask task_;    // maintain the input tasks in first epoch
+  bool first_epoch_;    // check (num_sample + num_padded) % num_shards == 0 in first epoch
+  ShardTaskList task_;  // maintain the input tasks in first epoch
 };
 }  // namespace mindrecord
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
 
 #include <memory>
-#include "minddata/mindrecord/include/shard_task.h"
+#include "minddata/mindrecord/include/shard_task_list.h"
 
 namespace mindspore {
 namespace mindrecord {
@@ -26,7 +26,7 @@ class __attribute__((visibility("default"))) ShardOperator {
  public:
   virtual ~ShardOperator() = default;
 
-  MSRStatus operator()(ShardTask &tasks) {
+  MSRStatus operator()(ShardTaskList &tasks) {
     if (SUCCESS != this->PreExecute(tasks)) {
       return FAILED;
     }
@@ -47,11 +47,11 @@ class __attribute__((visibility("default"))) ShardOperator {
   virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
 
-  virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
+  virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; }
 
-  virtual MSRStatus Execute(ShardTask &tasks) = 0;
+  virtual MSRStatus Execute(ShardTaskList &tasks) = 0;
 
-  virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
+  virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; }
 
   virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor
   ~ShardPkSample() override{};
 
-  MSRStatus SufExecute(ShardTask &tasks) override;
+  MSRStatus SufExecute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples() const { return num_samples_; }
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -210,6 +210,9 @@ class API_PUBLIC ShardReader {
   /// \brief get the size of blob data
   MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
 
+  /// \brief get a read-only ptr to the sampled ids for this epoch
+  const std::vector<int> *GetSampleIds();
+
 protected:
   /// \brief sqlite call back function
   static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
@@ -322,7 +325,7 @@ class API_PUBLIC ShardReader {
   std::vector<std::string> selected_columns_;              // columns which will be read
   std::map<string, uint64_t> column_schema_id_;            // column-schema map
   std::vector<std::shared_ptr<ShardOperator>> operators_;  // data operators, including shuffle, sample and category
-  ShardTask tasks_;                                        // shard task
+  ShardTaskList tasks_;                                    // shard task list
   std::mutex shard_locker_;                                // locker of shard
 
   // flags
@@ -339,7 +342,7 @@ class API_PUBLIC ShardReader {
   std::mutex mtx_delivery_;              // locker for delivery
   std::condition_variable cv_delivery_;  // conditional variable for delivery
   std::condition_variable cv_iterator_;  // conditional variable for iterator
-  std::atomic<int> task_id_;             // task ID which is working
+  std::atomic<int> sample_id_position_;  // index into the sample ids vector for the current sample id
   std::atomic<int> deliver_id_;          // delivery ID which is picked up by iterator
 
   // map of delivery
   std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator
   ~ShardSample() override{};
 
-  MSRStatus Execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTaskList &tasks) override;
 
-  MSRStatus UpdateTasks(ShardTask &tasks, int taking);
+  MSRStatus UpdateTasks(ShardTaskList &tasks, int taking);
 
-  MSRStatus SufExecute(ShardTask &tasks) override;
+  MSRStatus SufExecute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar
   ~ShardSequentialSample() override{};
 
-  MSRStatus Execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,11 +31,14 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator
   ~ShardShuffle() override{};
 
-  MSRStatus Execute(ShardTask &tasks) override;
+  MSRStatus Execute(ShardTaskList &tasks) override;
 
   int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
 
  private:
+  // Private helper function
+  MSRStatus CategoryShuffle(ShardTaskList &tasks);
+
   uint32_t shuffle_seed_;
   int64_t no_of_samples_;
   bool replacement_;
@@ -1,109 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
-#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
-
-#include <algorithm>
-#include <iostream>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <vector>
-#include "minddata/mindrecord/include/common/shard_utils.h"
-
-namespace mindspore {
-namespace mindrecord {
-class __attribute__((visibility("default"))) ShardTask {
- public:
-  ShardTask();
-
-  ShardTask(const ShardTask &task);  // copy construction
-
-  ShardTask &operator=(const ShardTask &task);  // assignment operator
-
-  ~ShardTask() = default;
-
-  void MakePerm();
-
-  inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
-                         const json &label);
-
-  inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
-                         const std::vector<uint64_t> &offset, const json &label);
-
-  inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
-
-  inline void InsertTask(const uint32_t &i,
-                         std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
-
-  void PopBack();
-
-  uint32_t Size() const;
-
-  uint32_t SizeOfRows() const;
-
-  std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);
-
-  std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
-
-  static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
-                           int64_t num_samples);
-
-  inline void ResizeTask(const uint32_t &size);
-
-  uint32_t categories;
-
-  // The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n]
-  std::vector<int> permutation_;
-
-  // The data struct is as below:
-  // 1. TaskType: kCommonTask / kPaddedTask
-  // 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
-  // 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
-  std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
-};
-
-inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
-                                  const json &label) {
-  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
-                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
-  task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
-}
-
-inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
-                                  const std::vector<uint64_t> &offset, const json &label) {
-  task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
-}
-
-inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
-  MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
-                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
-                << ", size of task_list_: " << task_list_.size() << ".";
-  task_list_.push_back(std::move(task));
-}
-
-inline void ShardTask::InsertTask(const uint32_t &i,
-                                  std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
-  task_list_[i] = std::move(task);
-}
-
-inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
-}  // namespace mindrecord
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
@@ -0,0 +1,132 @@
+/**
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
+#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
+
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+#include "minddata/mindrecord/include/common/shard_utils.h"
+
+namespace mindspore {
+namespace mindrecord {
+// The data struct is as below:
+// 1. TaskType: kCommonTask / kPaddedTask
+// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
+// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
+using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
+
+class __attribute__((visibility("default"))) ShardTaskList {
+ public:
+  ShardTaskList();
+
+  ShardTaskList(const ShardTaskList &task);  // copy construction
+
+  ShardTaskList &operator=(const ShardTaskList &task);  // assignment operator
+
+  ~ShardTaskList() = default;
+
+  void InitSampleIds();
+
+  static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);
+
+  // Assigns the task based on task id
+  inline void AssignTask(ShardTaskList &sourceTasks, size_t id);
+
+  inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
+                         const json &label);
+
+  inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
+                         const std::vector<uint64_t> &offset, const json &label);
+
+  inline void InsertTask(ShardTask task);
+
+  inline void InsertTask(const uint32_t &i, ShardTask task);
+
+  void MakePerm();
+
+  inline void InsertSampleId(int id);
+
+  void PopBack();
+
+  uint32_t Size() const;
+
+  uint32_t SizeOfRows() const;
+
+  ShardTask &GetTaskByID(size_t id);
+
+  ShardTask &GetRandomTask();
+
+  int GetTaskSampleByID(size_t id);
+
+  int GetRandomTaskID();
+
+  static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
+                               int64_t num_samples);
+
+  inline void ResizeTask(const uint32_t &size);
+
+  uint32_t categories;
+
+  std::vector<int> permutation_;      // A list of ints used for shuffling sample ids
+
+  std::vector<int> sample_ids_;       // The list of actual ids that were sampled
+
+  std::vector<ShardTask> task_list_;  // The full list of tasks
+};
+
+inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) {
+  // Insert the sample id from the source into ourself by indexing at id position.
+  // Important: The task list itself does not change.
+  int sample_id = sourceTasks.GetTaskSampleByID(id);
+  MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
+  sample_ids_.push_back(sample_id);
+}
+
+inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
+                                      const std::vector<uint64_t> &offset, const json &label) {
+  MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
+                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
+  task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
+}
+
+inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
+                                      const std::vector<uint64_t> &offset, const json &label) {
+  MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
+                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
+  task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
+}
+
+inline void ShardTaskList::InsertTask(ShardTask task) {
+  MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
+                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
+                << ", size of task_list_: " << task_list_.size() << ".";
+  task_list_.push_back(std::move(task));
+}
+
+inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); }
+
+inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
+}  // namespace mindrecord
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
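Editor's note: the sketch below is not part of the diff. It illustrates the intent of the new ShardTaskList API under two stated assumptions: InitSampleIds() seeds sample_ids_ with the identity mapping 0..Size()-1, and TaskListSwap() plays the role of the std::swap it replaces in ShardSample::UpdateTasks further down, i.e. the receiver adopts the newly assigned sample ids while task_list_ itself is left untouched. TaskType::kCommonTask and the helper name SubsetExample are taken from the comments above / hypothetical.

// Illustration only: subsample a ShardTaskList without copying any tasks (inside namespace mindspore::mindrecord).
MSRStatus SubsetExample() {
  ShardTaskList tasks;
  tasks.InsertTask(TaskType::kCommonTask, /*shard_id=*/0, /*group_id=*/0, /*offset=*/{0, 128}, json::object());
  tasks.InsertTask(TaskType::kCommonTask, /*shard_id=*/0, /*group_id=*/1, /*offset=*/{128, 256}, json::object());
  tasks.InitSampleIds();  // assumed: sample_ids_ becomes {0, 1}

  ShardTaskList picked;
  picked.AssignTask(tasks, 1);                 // records sample id 1; the source task_list_ is not modified
  ShardTaskList::TaskListSwap(tasks, picked);  // tasks now iterates only the picked sample ids
  return SUCCESS;
}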
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ ShardReader::ShardReader()
       num_padded_(0),
       num_rows_(0),
       total_blob_size_(0),
-      task_id_(0),
+      sample_id_position_(0),
       deliver_id_(0),
       lazy_load_(false),
       shard_sample_count_() {}
@@ -1088,9 +1088,8 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
       i++;
     }
   }
-  // Generate task list, a task will create a batch
-  std::vector<ShardTask> categoryTasks(categories.size());
+  // Generate a vector of task lists. Each category has a list of tasks.
+  std::vector<ShardTaskList> categoryTasks(categories.size());
   for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
     int category_index = 0;
     for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
@@ -1122,7 +1121,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
       }
     }
   }
-  tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
+  tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
+
+  tasks_.InitSampleIds();
   if (SUCCESS != (*category_op)(tasks_)) {
     return FAILED;
   }
@@ -1246,6 +1247,10 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
     }
   }
 
+  MS_LOG(DEBUG) << "Created initial list of tasks. There are " << tasks_.Size() << " to start with before sampling.";
+  tasks_.InitSampleIds();
+
   for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
     const auto &op = operators[operator_no];
     if (std::dynamic_pointer_cast<ShardCategory>(op)) continue;
@@ -1256,7 +1261,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
   if (tasks_.permutation_.empty()) tasks_.MakePerm();
 
   num_rows_ = tasks_.Size();
-  MS_LOG(INFO) << "Total rows is " << num_rows_;
+  MS_LOG(INFO) << "Total rows is " << num_rows_
+               << " and total amount sampled initially is: " << tasks_.sample_ids_.size();
 
   return SUCCESS;
 }
@@ -1272,9 +1279,9 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
   uint32_t blob_start = 0;
   uint32_t blob_end = 0;
   json var_fields;
 
   // Pick up task from task list
-  auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
+  ShardTask task;
+  task = tasks_.GetTaskByID(task_id);
 
   // check task type
   auto task_type = std::get<0>(task);
@@ -1354,16 +1361,16 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
   // Loop forever
   for (;;) {
-    int task_id = 0;
+    int sample_id_pos = 0;
 
     // Get next task ID
-    task_id = task_id_++;
+    sample_id_pos = sample_id_position_++;
 
     // All tasks are done
-    if (task_id >= static_cast<int>(tasks_.Size())) {
+    if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
       return FAILED;
     }
-    const auto &ret = ConsumerOneTask(task_id, consumer_id);
+    const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id);
     if (SUCCESS != ret.first) {
       return FAILED;
     }
@@ -1372,11 +1379,13 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
     // otherwise, set batch data in map
     {
       std::unique_lock<std::mutex> lck(mtx_delivery_);
-      cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; });
+      cv_delivery_.wait(lck,
+                        [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
       if (interrupt_) {
         return SUCCESS;
       }
-      delivery_map_[task_id] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
+      delivery_map_[sample_id_pos] =
+        std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
     }
     cv_iterator_.notify_one();
   }
@@ -1386,7 +1395,7 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
   if (interrupt_) {
     return std::vector<std::tuple<std::vector<uint8_t>, json>>();
   }
-  if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
+  if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) {
     return std::vector<std::tuple<std::vector<uint8_t>, json>>();
   }
@@ -1458,7 +1467,7 @@ std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> Sha
 void ShardReader::Reset() {
   {
     std::lock_guard<std::mutex> lck(mtx_delivery_);
-    task_id_ = 0;
+    sample_id_position_ = 0;
     deliver_id_ = 0;
   }
   cv_delivery_.notify_all();
@@ -1486,5 +1495,10 @@ void ShardReader::ShuffleTask() {
   if (tasks_.permutation_.empty()) tasks_.MakePerm();
 }
 
+const std::vector<int> *ShardReader::GetSampleIds() {
+  // return const reference to private sample id list.
+  return &(this->tasks_.sample_ids_);
+}
+
 }  // namespace mindrecord
 }  // namespace mindspore
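Editor's note: the sketch below is not part of the diff. It summarizes the per-epoch flow these ShardReader changes assume: consumer threads now advance sample_id_position_ over tasks_.sample_ids_, and the reshuffle between epochs is triggered through MindRecordSamplerRT::ResetSampler() rather than by MindRecordOp::Reset(). The helper name EpochFlowSketch is hypothetical.

// Illustration only: how a caller is expected to use the renamed/added pieces between epochs.
void EpochFlowSketch(ShardReader *reader) {
  // reader->Open(...) and the usual launch path are assumed to have run already, as in MindRecordOp::Init().
  const std::vector<int> *ids = reader->GetSampleIds();  // read-only view of tasks_.sample_ids_
  // Consumer threads walk sample_id_position_ over *ids and call ConsumerOneTask(ids->at(pos), consumer_id).
  reader->Reset();        // rewinds sample_id_position_ and deliver_id_
  reader->ShuffleTask();  // re-runs the shuffle operators; *ids then reflects the new epoch's sampling
}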
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
       num_categories_(num_categories),
       replacement_(replacement) {}
 
-MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
+MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
 
 int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (dataset_size == 0) return dataset_size;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -55,7 +55,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
   return 0;
 }
 
-MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
+MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
   auto total_no = tasks.Size();
   if (no_of_padded_samples_ > 0 && first_epoch_) {
     if (total_no % denominator_ != 0) {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);  // do shuffle and replacement
 }
 
-MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
+MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) {
   if (shuffle_ == true) {
     if (SUCCESS != (*shuffle_op_)(tasks)) {
       return FAILED;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -80,21 +80,21 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   return 0;
 }
 
-MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
+MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
   if (tasks.permutation_.empty()) {
-    ShardTask new_tasks;
-    int total_no = static_cast<int>(tasks.Size());
+    ShardTaskList new_tasks;
+    int total_no = static_cast<int>(tasks.sample_ids_.size());
     if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
       for (int i = 0; i < indices_.size(); ++i) {
        int index = ((indices_[i] % total_no) + total_no) % total_no;
-        new_tasks.InsertTask(tasks.GetTaskByID(index));  // different mod result between c and python
+        new_tasks.AssignTask(tasks, index);  // different mod result between c and python
      }
     } else {
       int count = 0;
       if (nums_per_shard_.empty()) {
         for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
           if (no_of_samples_ != 0 && count == no_of_samples_) break;
-          new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));  // rounding up. if overflow, go back to start
+          new_tasks.AssignTask(tasks, i % total_no);  // rounding up. if overflow, go back to start
           count++;
         }
       } else {
@@ -102,33 +102,33 @@ MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
         size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
         for (; i < nums_per_shard_[partition_id_]; i++) {
           if (no_of_samples_ != 0 && count == no_of_samples_) break;
-          new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
+          new_tasks.AssignTask(tasks, i % total_no);
           count++;
         }
       }
     }
-    std::swap(tasks, new_tasks);
+    ShardTaskList::TaskListSwap(tasks, new_tasks);
   } else {
-    ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
+    ShardTaskList new_tasks;
+    if (taking > static_cast<int>(tasks.sample_ids_.size())) {
      return FAILED;
    }
     int total_no = static_cast<int>(tasks.permutation_.size());
     int count = 0;
     for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
       if (no_of_samples_ != 0 && count == no_of_samples_) break;
-      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
+      new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
       count++;
     }
-    std::swap(tasks, new_tasks);
+    ShardTaskList::TaskListSwap(tasks, new_tasks);
   }
   return SUCCESS;
 }
 
-MSRStatus ShardSample::Execute(ShardTask &tasks) {
+MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
   if (offset_ != -1) {
     int64_t old_v = 0;
-    int num_rows_ = static_cast<int>(tasks.Size());
+    int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
     for (int x = 0; x < denominator_; x++) {
       int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
       int remainder = (num_rows_ + offset_) % denominator_;
@@ -140,8 +140,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
     }
   }
   int no_of_categories = static_cast<int>(tasks.categories);
-  int total_no = static_cast<int>(tasks.Size());  // make sure task_size
+  int total_no = static_cast<int>(tasks.sample_ids_.size());
   int taking = 0;
   if (sampler_type_ == kCustomTopNSampler) {  // non sharding case constructor #1
     no_of_samples_ = std::min(no_of_samples_, total_no);
@@ -167,7 +166,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
   return UpdateTasks(tasks, taking);
 }
 
-MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
+MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) {
   if (sampler_type_ == kSubsetRandomSampler) {
     if (SUCCESS != (*shuffle_op_)(tasks)) {
       return FAILED;
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -38,9 +38,9 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c | |||||
| return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | ||||
| } | } | ||||
| MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||||
| int64_t total_no = static_cast<int64_t>(tasks.Size()); | |||||
| MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||||
| int64_t taking; | int64_t taking; | ||||
| int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); | |||||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | ||||
| taking = total_no; | taking = total_no; | ||||
| } else if (per_ > kEpsilon && per_ <= 1.0f) { | } else if (per_ > kEpsilon && per_ <= 1.0f) { | ||||
| @@ -50,22 +50,22 @@ MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { | |||||
| } | } | ||||
| if (tasks.permutation_.empty()) { | if (tasks.permutation_.empty()) { | ||||
| ShardTask new_tasks; | |||||
| ShardTaskList new_tasks; | |||||
| total_no = static_cast<int64_t>(tasks.Size()); | total_no = static_cast<int64_t>(tasks.Size()); | ||||
| for (size_t i = offset_; i < taking + offset_; ++i) { | for (size_t i = offset_; i < taking + offset_; ++i) { | ||||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); | |||||
| new_tasks.AssignTask(tasks, i % total_no); | |||||
| } | } | ||||
| std::swap(tasks, new_tasks); | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||||
| } else { // shuffled | } else { // shuffled | ||||
| ShardTask new_tasks; | |||||
| ShardTaskList new_tasks; | |||||
| if (taking > static_cast<int64_t>(tasks.permutation_.size())) { | if (taking > static_cast<int64_t>(tasks.permutation_.size())) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| total_no = static_cast<int64_t>(tasks.permutation_.size()); | total_no = static_cast<int64_t>(tasks.permutation_.size()); | ||||
| for (size_t i = offset_; i < taking + offset_; ++i) { | for (size_t i = offset_; i < taking + offset_; ++i) { | ||||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | |||||
| } | } | ||||
| std::swap(tasks, new_tasks); | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
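ShardSequentialSample picks `taking` either from an absolute sample count or from a fraction of the task list, then copies ids sequentially starting at the offset and wrapping modulo the total. The sketch below models both steps on a plain id vector; `ComputeTaking` and `TakeSequential` are illustrative helpers, not MindRecord functions.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr float kEpsilon = 1e-7f;

// Decide how many ids to take: everything when neither a count nor a
// percentage is set, a fraction of the total when a percentage is given,
// otherwise the explicit count capped at the total.
int64_t ComputeTaking(int64_t total_no, int64_t no_of_samples, float per) {
  if (no_of_samples == 0 && per >= -kEpsilon && per <= kEpsilon) return total_no;
  if (per > kEpsilon && per <= 1.0f) return static_cast<int64_t>(total_no * per);
  return std::min(no_of_samples, total_no);
}

// Copy `taking` ids sequentially, starting at `offset` and wrapping around.
std::vector<int> TakeSequential(const std::vector<int> &ids, int64_t taking, int64_t offset) {
  std::vector<int> out;
  const int64_t total_no = static_cast<int64_t>(ids.size());
  for (int64_t i = offset; i < taking + offset; ++i) out.push_back(ids[i % total_no]);
  return out;
}

int main() {
  const std::vector<int> ids = {0, 1, 2, 3, 4};
  const int64_t taking = ComputeTaking(static_cast<int64_t>(ids.size()), /*no_of_samples=*/0, /*per=*/0.6f);  // 3 of 5
  for (int id : TakeSequential(ids, taking, /*offset=*/3)) std::cout << id << ' ';  // prints: 3 4 0
  std::cout << '\n';
  return 0;
}
```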
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -42,7 +42,31 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | ||||
| } | } | ||||
| MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||||
| MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||||
| uint32_t individual_size = static_cast<uint32_t>(tasks.sample_ids_.size() / tasks.categories); | |||||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||||
| } | |||||
| tasks.permutation_.clear(); // TODO: consider replacing this clear() with setting a "permutation valid" flag to false | |||||
| for (uint32_t j = 0; j < individual_size; j++) { | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||||
| } | |||||
| } | |||||
| ShardTaskList new_tasks; | |||||
| for (size_t i = 0; i < individual_size; ++i) { | |||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||||
| } | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||||
| return SUCCESS; | |||||
| } | |||||
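CategoryShuffle shuffles every category's positions independently and then re-interleaves them with `new_permutations[i][j] * categories + i`, so each shuffle unit still contains one sample per category, matching the `(a1, b1, c1), (a2, b2, c2), ...` layout noted further down. A self-contained model of that interleaving on strings; the free function below is a stand-in for the member above, not a copy of it.

```cpp
#include <algorithm>
#include <iostream>
#include <random>
#include <string>
#include <vector>

// `tasks` is assumed to be laid out as interleaved units (a1, b1, c1),
// (a2, b2, c2), ...; each category's positions are shuffled on their own and
// then re-interleaved, so every unit of the result still holds one sample per
// category.
std::vector<std::string> CategoryShuffle(const std::vector<std::string> &tasks, int categories, unsigned seed) {
  const int units = static_cast<int>(tasks.size()) / categories;
  std::vector<std::vector<int>> perms(categories, std::vector<int>(units));
  for (int i = 0; i < categories; ++i) {
    for (int j = 0; j < units; ++j) perms[i][j] = j;
    std::shuffle(perms[i].begin(), perms[i].end(), std::default_random_engine(seed));
  }
  std::vector<std::string> out;
  for (int j = 0; j < units; ++j) {
    for (int i = 0; i < categories; ++i) {
      out.push_back(tasks[perms[i][j] * categories + i]);  // same index arithmetic as the permutation above
    }
  }
  return out;
}

int main() {
  const std::vector<std::string> tasks = {"a1", "b1", "c1", "a2", "b2", "c2", "a3", "b3", "c3"};
  for (const auto &t : CategoryShuffle(tasks, /*categories=*/3, /*seed=*/1)) std::cout << t << ' ';
  std::cout << '\n';  // units are reordered, but each unit still reads (ak, bk, ck)
  return 0;
}
```

Because every category's shuffle reuses the same seed, as in the member function above, whole units move together rather than mixing samples from different units.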
| MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||||
| if (reshuffle_each_epoch_) shuffle_seed_++; | if (reshuffle_each_epoch_) shuffle_seed_++; | ||||
| if (tasks.categories < 1) { | if (tasks.categories < 1) { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -52,43 +76,31 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||||
| tasks.MakePerm(); | tasks.MakePerm(); | ||||
| } | } | ||||
| if (replacement_ == true) { | if (replacement_ == true) { | ||||
| ShardTask new_tasks; | |||||
| if (no_of_samples_ == 0) { | |||||
| no_of_samples_ = static_cast<int>(tasks.Size()); | |||||
| } | |||||
| ShardTaskList new_tasks; | |||||
| if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); | |||||
| if (no_of_samples_ <= 0) { | if (no_of_samples_ <= 0) { | ||||
| MS_LOG(ERROR) << "no_of_samples needs to be positive."; | MS_LOG(ERROR) << "no_of_samples needs to be positive."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| new_tasks.task_list_.reserve(no_of_samples_); | new_tasks.task_list_.reserve(no_of_samples_); | ||||
| for (uint32_t i = 0; i < no_of_samples_; ++i) { | for (uint32_t i = 0; i < no_of_samples_; ++i) { | ||||
| new_tasks.InsertTask(tasks.GetRandomTask()); | |||||
| new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); | |||||
| } | } | ||||
| std::swap(tasks, new_tasks); | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||||
| } else { | } else { | ||||
| std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | ||||
| auto total_no = static_cast<int64_t>(tasks.Size()); | auto total_no = static_cast<int64_t>(tasks.Size()); | ||||
| if (no_of_samples_ > 0 && no_of_samples_ < total_no) { | |||||
| ShardTask new_tasks; | |||||
| for (size_t i = 0; i < no_of_samples_; ++i) { | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(i)); | |||||
| } | |||||
| std::swap(tasks, new_tasks); | |||||
| ShardTaskList new_tasks; | |||||
| size_t samples_to_assign = | |||||
| (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size(); | |||||
| for (size_t i = 0; i < samples_to_assign; ++i) { | |||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||||
| } | } | ||||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | ||||
| uint32_t individual_size = tasks.Size() / tasks.categories; | |||||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||||
| } | |||||
| tasks.permutation_.clear(); | |||||
| for (uint32_t j = 0; j < individual_size; j++) { | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||||
| } | |||||
| } | |||||
| return this->CategoryShuffle(tasks); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
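The two branches above either draw ids uniformly at random with replacement or shuffle the permutation once and assign a capped prefix of it. A standalone sketch of both policies, with hypothetical helper names:

```cpp
#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// Draw `n` sample ids with replacement: duplicates are possible.
std::vector<int> SampleWithReplacement(int total, int n, unsigned seed) {
  std::mt19937 gen(seed);
  std::uniform_int_distribution<> dis(0, total - 1);
  std::vector<int> out;
  out.reserve(n);
  for (int i = 0; i < n; ++i) out.push_back(dis(gen));
  return out;
}

// Shuffle all ids once, then keep either the capped prefix or all of them.
std::vector<int> SampleWithoutReplacement(int total, int no_of_samples, unsigned seed) {
  std::vector<int> permutation(total);
  for (int i = 0; i < total; ++i) permutation[i] = i;
  std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
  const int take = (no_of_samples > 0 && no_of_samples < total) ? no_of_samples : total;
  permutation.resize(take);
  return permutation;
}

int main() {
  for (int id : SampleWithReplacement(5, 7, 42)) std::cout << id << ' ';  // 7 draws from {0..4}, repeats allowed
  std::cout << '\n';
  for (int id : SampleWithoutReplacement(5, 3, 42)) std::cout << id << ' ';  // 3 distinct ids
  std::cout << '\n';
  return 0;
}
```

In the member function the permutation is built and shuffled up front by `MakePerm` plus `std::shuffle`; the sketch folds that into `SampleWithoutReplacement` for brevity.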
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| #include "minddata/mindrecord/include/shard_task.h" | |||||
| #include "minddata/mindrecord/include/shard_task_list.h" | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/mindrecord/include/common/shard_utils.h" | #include "minddata/mindrecord/include/common/shard_utils.h" | ||||
| @@ -25,55 +25,88 @@ using mindspore::MsLogLevel::DEBUG; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardTask::ShardTask() : categories(1) {} | |||||
| ShardTaskList::ShardTaskList() : categories(1) {} | |||||
| ShardTask::ShardTask(const ShardTask &other) | |||||
| : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {} | |||||
| ShardTaskList::ShardTaskList(const ShardTaskList &other) | |||||
| : categories(other.categories), | |||||
| permutation_(other.permutation_), | |||||
| sample_ids_(other.sample_ids_), | |||||
| task_list_(other.task_list_) {} | |||||
| ShardTask &ShardTask::operator=(const ShardTask &other) { | |||||
| ShardTask tmp(other); | |||||
| ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) { | |||||
| ShardTaskList tmp(other); | |||||
| std::swap(categories, tmp.categories); | std::swap(categories, tmp.categories); | ||||
| permutation_.swap(tmp.permutation_); | permutation_.swap(tmp.permutation_); | ||||
| sample_ids_.swap(tmp.sample_ids_); | |||||
| task_list_.swap(tmp.task_list_); | task_list_.swap(tmp.task_list_); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
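The assignment operator keeps the copy-and-swap idiom: copy-construct a temporary, swap every member into `*this`, and let the temporary carry the old state away, so a throwing copy leaves the destination intact; the new `sample_ids_` member simply joins the member-wise swap. A stripped-down illustration with a stand-in struct, not the real class:

```cpp
#include <utility>
#include <vector>

// Stand-in type showing copy-and-swap assignment: all copying happens while
// building `tmp`, so if that copy throws, *this is unchanged.
struct TaskListLike {
  int categories = 1;
  std::vector<int> permutation;
  std::vector<int> sample_ids;

  TaskListLike() = default;
  TaskListLike(const TaskListLike &) = default;

  TaskListLike &operator=(const TaskListLike &other) {
    TaskListLike tmp(other);            // may throw; *this is still intact
    std::swap(categories, tmp.categories);
    permutation.swap(tmp.permutation);  // member swaps do not throw
    sample_ids.swap(tmp.sample_ids);
    return *this;
  }
};

int main() {
  TaskListLike a, b;
  b.sample_ids = {0, 1, 2};
  a = b;                                              // a now owns a copy of b's members
  return static_cast<int>(a.sample_ids.size()) - 3;   // exits 0 when the copy landed
}
```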
| void ShardTask::MakePerm() { | |||||
| permutation_ = std::vector<int>(task_list_.size()); | |||||
| for (uint32_t i = 0; i < task_list_.size(); i++) { | |||||
| void ShardTaskList::InitSampleIds() { | |||||
| // No-op if sample ids already exist; do not clobber the previous list. | |||||
| if (sample_ids_.empty()) { | |||||
| sample_ids_ = std::vector<int>(task_list_.size()); | |||||
| for (size_t i = 0; i < task_list_.size(); i++) sample_ids_[i] = static_cast<int>(i); | |||||
| } | |||||
| } | |||||
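InitSampleIds lazily fills `sample_ids_` with the identity mapping 0..N-1 and deliberately does nothing when a sampler has already written ids, so a late call never clobbers an earlier selection. The same behaviour in a few lines on a plain vector, using `std::iota`:

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Fill `sample_ids` with 0..task_count-1 only if nothing has populated it yet.
void InitSampleIds(std::vector<int> &sample_ids, size_t task_count) {
  if (sample_ids.empty()) {
    sample_ids.resize(task_count);
    std::iota(sample_ids.begin(), sample_ids.end(), 0);
  }
}

int main() {
  std::vector<int> ids;
  InitSampleIds(ids, 4);
  assert((ids == std::vector<int>{0, 1, 2, 3}));
  InitSampleIds(ids, 10);   // second call is a no-op: the earlier ids survive
  assert(ids.size() == 4);
  return 0;
}
```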
| void ShardTaskList::MakePerm() { | |||||
| size_t perm_size = sample_ids_.size(); | |||||
| permutation_ = std::vector<int>(perm_size); | |||||
| for (uint32_t i = 0; i < perm_size; i++) { | |||||
| permutation_[i] = static_cast<int>(i); | permutation_[i] = static_cast<int>(i); | ||||
| } | } | ||||
| } | } | ||||
| void ShardTask::PopBack() { task_list_.pop_back(); } | |||||
| // Swap the new_tasks with orig_tasks | |||||
| void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) { | |||||
| // When swapping, orig_tasks may hold fields that must be preserved across the swap; swapping with a | |||||
| // new_tasks that lacks those fields would clobber or lose that data afterwards. | |||||
| // In particular, task_list_ must not be lost, so move it into new_tasks first. | |||||
| new_tasks.task_list_ = std::move(orig_tasks.task_list_); | |||||
| uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); } | |||||
| // Now it is safe to perform the swap. | |||||
| std::swap(orig_tasks, new_tasks); | |||||
| } | |||||
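Samplers build `new_tasks` with freshly assigned sample ids only, while the heavyweight row metadata stays in the original object's `task_list_`; TaskListSwap therefore moves that list into `new_tasks` before swapping, so the caller ends up with both the new ids and the preserved tasks. A toy model of the move-then-swap, with stand-in struct and field names:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// `rows` plays the role of task_list_ (expensive, must be kept);
// `sample_ids` plays the role of the sampler's freshly built selection.
struct ToyTaskList {
  std::vector<std::string> rows;
  std::vector<int> sample_ids;
};

// Same shape as ShardTaskList::TaskListSwap: move the preserved field into the
// replacement first, then swap the whole objects.
void TaskListSwap(ToyTaskList &orig, ToyTaskList &fresh) {
  fresh.rows = std::move(orig.rows);  // keep the row metadata alive across the swap
  std::swap(orig, fresh);
}

int main() {
  ToyTaskList orig{{"row-a", "row-b", "row-c"}, {0, 1, 2}};
  ToyTaskList fresh{{}, {2, 0}};      // a sampler selected ids 2 and 0 only
  TaskListSwap(orig, fresh);
  assert(orig.rows.size() == 3);                         // metadata preserved
  assert((orig.sample_ids == std::vector<int>{2, 0}));   // new selection adopted
  return 0;
}
```

Moving rather than copying keeps the swap cheap even when the task list is large, which is the point of switching the samplers from whole-task copies to sample-id assignment in the first place.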
| uint32_t ShardTask::SizeOfRows() const { | |||||
| void ShardTaskList::PopBack() { task_list_.pop_back(); } | |||||
| uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); } | |||||
| uint32_t ShardTaskList::SizeOfRows() const { | |||||
| if (task_list_.size() == 0) return static_cast<uint32_t>(0); | if (task_list_.size() == 0) return static_cast<uint32_t>(0); | ||||
| // 1 task is 1 page | // 1 task is 1 page | ||||
| auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) { | |||||
| return x + std::get<2>(y)[0]; | |||||
| }; | |||||
| auto sum_num_rows = [](int x, const ShardTask &y) { return x + std::get<2>(y)[0]; }; | |||||
| uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); | uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); | ||||
| return nRows; | return nRows; | ||||
| } | } | ||||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) { | |||||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { | |||||
| MS_ASSERT(id < task_list_.size()); | MS_ASSERT(id < task_list_.size()); | ||||
| return task_list_[id]; | return task_list_[id]; | ||||
| } | } | ||||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() { | |||||
| int ShardTaskList::GetTaskSampleByID(size_t id) { | |||||
| MS_ASSERT(id < sample_ids_.size()); | |||||
| return sample_ids_[id]; | |||||
| } | |||||
| int ShardTaskList::GetRandomTaskID() { | |||||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | |||||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||||
| return dis(gen); | |||||
| } | |||||
| ShardTask &ShardTaskList::GetRandomTask() { | |||||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | ||||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | ||||
| return task_list_[dis(gen)]; | return task_list_[dis(gen)]; | ||||
| } | } | ||||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | |||||
| int64_t num_samples) { | |||||
| ShardTask res; | |||||
| ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements, | |||||
| int64_t num_samples) { | |||||
| ShardTaskList res; | |||||
| if (category_tasks.empty()) return res; | if (category_tasks.empty()) return res; | ||||
| auto total_categories = category_tasks.size(); | auto total_categories = category_tasks.size(); | ||||
| res.categories = static_cast<uint32_t>(total_categories); | res.categories = static_cast<uint32_t>(total_categories); | ||||
| @@ -107,6 +140,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return res; | return res; | ||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||